【PYG】端到端训练,GNN中提取节点特征然后通过全连接层(FC)
在许多情况下,直接在GNN中提取节点特征然后通过全连接层(FC)进行预测的效果往往比先提取节点特征然后再单独用FC进行预测要好。这种效果差异背后的主要原因包括以下几点:
1. 端到端训练(End-to-End Training)
直接在GNN中提取节点特征并通过FC进行预测允许模型端到端地进行训练。在端到端训练中,所有网络参数(包括GNN层和FC层的参数)在每次反向传播中都可以得到更新。这使得模型能够联合优化特征提取和最终预测,从而在整个过程中学到更好的特征。
class GNNWithFC(nn.Module):
def __init__(self, in_channels, hidden_channels, gnn_out_channels, num_classes):
super(GNNWithFC, self).__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, gnn_out_channels)
self.fc = nn.Linear(gnn_out_channels, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
gnn_features = F.relu(x)
out = self.fc(gnn_features)
return out, gnn_features
# 创建模型并进行端到端训练
model = GNNWithFC(in_channels=num_node_features, hidden_channels=32, gnn_out_channels=64, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
model.train()
optimizer.zero_grad()
output, gnn_features = model(batch)
loss = criterion(output, target)
loss.backward()
optimizer.step()
2. 特征提取和分类的联合优化
在联合优化的情况下,GNN层在特征提取时不仅仅考虑到节点的本地信息,还可以通过梯度反传学习到那些对FC层分类效果最有帮助的特征。这种联合优化能够使得特征提取器(GNN)和分类器(FC)之间的协同作用更加显著。
3. 避免特征过拟合
如果先提取节点特征然后再单独用FC进行预测,中间存储的特征可能会引入过拟合的风险。特征提取器(GNN)可能会生成一些对当前任务无关紧要甚至有害的特征,但这些特征不会被及时优化掉,因为它们不在整个模型的反馈回路内。
4. 更好地利用图结构信息
GNN直接提取的特征和FC联合优化能够更好地捕捉图结构信息和节点特征之间的关系。如果先提取特征然后再用FC进行预测,图结构信息可能在提取特征的过程中被部分丢失,从而影响最终的预测效果。
5. 减少模型复杂度和计算开销
在单一模型内直接进行端到端训练可以简化模型结构,减少额外的数据存储和处理步骤,从而降低计算开销和存储需求。
对比:分阶段训练
如果分阶段训练,即先提取特征然后再进行分类,代码可能如下:
# 第一步:提取特征
gnn_model = GCN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
gnn_features = []
for graph in graphs:
gnn_features.append(gnn_model(graph.x, graph.edge_index))
gnn_features = torch.cat(gnn_features, dim=0)
# 第二步:用FC层进行分类
fc_model = nn.Linear(64, num_classes)
optimizer = torch.optim.Adam(fc_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
fc_model.train()
optimizer.zero_grad()
output = fc_model(gnn_features)
loss = criterion(output, target)
loss.backward()
optimizer.step()
总结
直接在GNN中提取节点特征然后通过FC进行预测的效果通常更好,因为它能够实现端到端训练,联合优化特征提取和分类,避免特征过拟合,更好地利用图结构信息,并简化模型复杂度和计算开销。这些优势使得直接在GNN中进行特征提取和分类的模型通常具有更好的性能和更高的稳定性。
原文地址:https://blog.csdn.net/xiong_xin/article/details/140235020
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!