自学内容网 自学内容网

【PYG】端到端训练,GNN中提取节点特征然后通过全连接层(FC)

在许多情况下,直接在GNN中提取节点特征然后通过全连接层(FC)进行预测的效果往往比先提取节点特征然后再单独用FC进行预测要好。这种效果差异背后的主要原因包括以下几点:

1. 端到端训练(End-to-End Training)

直接在GNN中提取节点特征并通过FC进行预测允许模型端到端地进行训练。在端到端训练中,所有网络参数(包括GNN层和FC层的参数)在每次反向传播中都可以得到更新。这使得模型能够联合优化特征提取和最终预测,从而在整个过程中学到更好的特征。

class GNNWithFC(nn.Module):
    def __init__(self, in_channels, hidden_channels, gnn_out_channels, num_classes):
        super(GNNWithFC, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, gnn_out_channels)
        self.fc = nn.Linear(gnn_out_channels, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        gnn_features = F.relu(x)
        out = self.fc(gnn_features)
        return out, gnn_features

# 创建模型并进行端到端训练
model = GNNWithFC(in_channels=num_node_features, hidden_channels=32, gnn_out_channels=64, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    output, gnn_features = model(batch)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

2. 特征提取和分类的联合优化

在联合优化的情况下,GNN层在特征提取时不仅仅考虑到节点的本地信息,还可以通过梯度反传学习到那些对FC层分类效果最有帮助的特征。这种联合优化能够使得特征提取器(GNN)和分类器(FC)之间的协同作用更加显著。

3. 避免特征过拟合

如果先提取节点特征然后再单独用FC进行预测,中间存储的特征可能会引入过拟合的风险。特征提取器(GNN)可能会生成一些对当前任务无关紧要甚至有害的特征,但这些特征不会被及时优化掉,因为它们不在整个模型的反馈回路内。

4. 更好地利用图结构信息

GNN直接提取的特征和FC联合优化能够更好地捕捉图结构信息和节点特征之间的关系。如果先提取特征然后再用FC进行预测,图结构信息可能在提取特征的过程中被部分丢失,从而影响最终的预测效果。

5. 减少模型复杂度和计算开销

在单一模型内直接进行端到端训练可以简化模型结构,减少额外的数据存储和处理步骤,从而降低计算开销和存储需求。

对比:分阶段训练

如果分阶段训练,即先提取特征然后再进行分类,代码可能如下:

# 第一步:提取特征
gnn_model = GCN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
gnn_features = []
for graph in graphs:
    gnn_features.append(gnn_model(graph.x, graph.edge_index))
gnn_features = torch.cat(gnn_features, dim=0)

# 第二步:用FC层进行分类
fc_model = nn.Linear(64, num_classes)
optimizer = torch.optim.Adam(fc_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    fc_model.train()
    optimizer.zero_grad()
    output = fc_model(gnn_features)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

总结

直接在GNN中提取节点特征然后通过FC进行预测的效果通常更好,因为它能够实现端到端训练,联合优化特征提取和分类,避免特征过拟合,更好地利用图结构信息,并简化模型复杂度和计算开销。这些优势使得直接在GNN中进行特征提取和分类的模型通常具有更好的性能和更高的稳定性。


原文地址:https://blog.csdn.net/xiong_xin/article/details/140235020

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!