从GCN到GAT:基于PyTorch Geometric的Cora论文分类实战与可视化分析
1. 从零开始理解Cora数据集第一次接触Cora数据集时我完全被那些论文引用关系搞晕了。这个数据集就像学术界的社交网络每篇论文都是一个人引用关系就是谁认识谁。具体来说Cora包含2708篇机器学习论文每篇论文被表示为节点引用关系就是连接这些节点的边。数据集里最有趣的是特征表示方式。每篇论文的特征是一个1433维的向量对应1433个关键词的one-hot编码。简单说就是如果论文包含某个关键词对应位置就是1否则是0。这种表示方法虽然简单粗暴但特别适合我们理解图神经网络如何处理非结构化数据。下载数据集时有个小技巧原始链接有时不太稳定我通常会先下载到本地再处理。解压后你会看到两个关键文件cora.cites记录论文间的引用关系cora.content包含论文特征和类别标签# 数据集路径示例建议放在项目根目录的data文件夹 path data/cora/ cites path cora.cites content path cora.content处理数据时最容易踩的坑是节点编号。原始论文ID是不连续的我们需要重新映射为0开始的连续索引。我专门写了个index_dict来做这个转换否则后面构建邻接矩阵时会出大问题。2. 图卷积网络(GCN)实战详解GCN的核心思想特别像人际关系的传播你的特征会受朋友影响朋友的特征又受他们朋友的影响。在代码实现时PyTorch Geometric的GCNConv层帮我们封装了所有复杂的数学运算。先来看网络结构设计。我习惯用两层GCN第一层将1433维特征降到16维第二层再降到类别数Cora是7类。中间加ReLU激活和Dropout防止过拟合class GCNNet(torch.nn.Module): def __init__(self, num_feature, num_label): super(GCNNet,self).__init__() self.GCN1 GCNConv(num_feature, 16) self.GCN2 GCNConv(16, num_label) self.dropout torch.nn.Dropout(p0.5) def forward(self, data): x, edge_index data.x, data.edge_index x self.GCN1(x, edge_index) x F.relu(x) x self.dropout(x) x self.GCN2(x, edge_index) return F.log_softmax(x, dim1)训练时有个重要细节Cora的标准划分是只用140个节点做训练。刚开始我觉得这么少数据肯定不行但实测发现GCN的泛化能力惊人。关键是要做好以下三点使用Adam优化器学习率设为0.01添加L2正则化(weight_decay5e-4)Dropout率设为0.5我的实验记录显示大约在50轮后验证集准确率就稳定在81%左右。这比传统ML方法高出至少20%充分展示了图结构的价值。3. 图注意力网络(GAT)实现技巧GAT就像是给GCN装上了智能眼镜让它能自动关注重要的邻居节点。我第一次看到GAT的多头注意力机制时感觉就像发现了新大陆。实现时要注意几个关键参数heads注意力头数常用8个concat是否拼接各头的输出dropout注意力系数的丢弃率class GATNet(torch.nn.Module): def __init__(self, num_feature, num_label): super(GATNet,self).__init__() self.GAT1 GATConv(num_feature, 8, heads8, concatTrue, dropout0.6) self.GAT2 GATConv(8*8, num_label, dropout0.6) def forward(self, data): x, edge_index data.x, data.edge_index x self.GAT1(x, edge_index) x F.relu(x) x self.GAT2(x, edge_index) return F.log_softmax(x, dim1)训练GAT比GCN更考验耐心。由于参数更多需要适当调大Dropout率我设为0.6否则很容易过拟合。另外发现学习率不宜过大否则会出现震荡。经过200轮训练后测试集准确率能达到约83%比GCN高出2个百分点。有个实用技巧使用torch.manual_seed固定随机种子这样每次运行结果可以复现。我在不同随机种子下测试发现GAT的性能波动比GCN大说明它对参数初始化更敏感。4. 模型对比与特征可视化训练完模型只是开始真正的乐趣在于分析它们学到了什么。我用t-SNE将最后一层的节点嵌入降维到2D空间结果非常有意思。先看GCN的特征分布同类节点聚集成云状不同类别间有重叠区域边界比较模糊而GAT的特征分布类别边界更清晰同类节点聚集更紧密存在明显的类别过渡带# t-SNE可视化代码片段 ts TSNE(n_components2) embedding ts.fit_transform(out[test_mask].cpu().detach().numpy()) plt.scatter(embedding[:,0], embedding[:,1], clabels[test_mask])通过对比可以直观理解GAT的优势注意力机制让模型能够聚焦关键邻居学到的特征判别性更强。不过GAT计算量也更大在我的笔记本上训练时间比GCN长约40%。建议大家在分析时关注几个关键点不同类别间的混淆情况异常点的分布位置特征空间的密度分布这些观察能帮助我们调整模型结构。比如发现某个类别总是混淆可能需要增加该类的训练样本或者调整损失函数的类别权重。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2502545.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!