图卷积网络:从理论到实践

news2025/6/8 23:54:05

图卷积网络(Graph Convolutional Networks, GCNs)彻底改变了基于图的机器学习领域,使得深度学习能够应用于非欧几里得结构,如社交网络、引文网络和分子结构。本文将解释GCN的直观理解、数学原理,并提供代码片段帮助您理解和实现基础的GCN。

图表示法基础

定义图G = (V, E),其中:

  • V V V:节点集合
  • E E E:边集合
  • A ∈ R N × N A \in \mathbb{R}^{N \times N} ARN×N:邻接矩阵
  • X ∈ R N × F X \in \mathbb{R}^{N \times F} XRN×F:节点特征矩阵

其中, N N N是节点数量, F F F是每个节点的输入特征数量。

邻接矩阵

邻接矩阵是表示图中节点之间连接(边)的一种方式。

  • 对于具有 N N N个节点的图, A A A是一个 N × N N \times N N×N的矩阵。
  • 如果节点 i i i和节点 j j j之间有边,则 A i j = 1 A_{ij} = 1 Aij=1(如果带权重,则为边的权重);否则 A i j = 0 A_{ij} = 0 Aij=0
  • 在无向图中, A A A是对称的( A i j = A j i A_{ij} = A_{ji} Aij=Aji)。
  • 例如,一个3节点图,其中节点0连接到节点1和2:
    A = [ 0 1 1 1 0 0 1 0 0 ] A = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 0 \\ 1 & 0 & 0 \end{bmatrix} A= 011100100

请添加图片描述

节点特征矩阵

节点特征矩阵存储图中每个节点的特征(属性)。

  • N N N是节点数量, F F F是每个节点的特征数量。
  • 每一行 X i X_i Xi是节点 i i i的特征向量。
  • 例如,如果每个节点有3个特征(比如年龄、收入和组别),共有4个节点:
    X = [ 23 50000 1 35 60000 2 29 52000 1 41 58000 3 ] X = \begin{bmatrix} 23 & 50000 & 1 \\ 35 & 60000 & 2 \\ 29 & 52000 & 1 \\ 41 & 58000 & 3 \end{bmatrix} X= 23352941500006000052000580001213
  • 这些特征是GCN用来学习的输入。

两者共同构成了图卷积网络的基本输入:

  • 邻接矩阵 A A A描述了节点如何连接。
  • 节点特征矩阵 X X X描述了每个节点的特征。

GCN层公式(Kipf & Welling, 2016)

GCN层的核心公式如下:

H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}) H(l+1)=σ(D~1/2A~D~1/2H(l)W(l))

这个公式包含了很多信息,我们将在下面详细解析:

输入:
  • H ( l ) H^{(l)} H(l):上一层的节点特征(对于第一层, H ( 0 ) = X H^{(0)} = X H(0)=X,即输入特征)
  • A ~ = A + I \tilde{A} = A + I A~=A+I:添加了自环的邻接矩阵( I I I是单位矩阵)。图中的自环是指节点与自身相连的边。在邻接矩阵中,节点 i i i的自环表示为 A ~ i i = 1 \tilde{A}_{ii} = 1 A~ii=1。添加自环后,我们得到新矩阵: A ~ = A + I \tilde{A} = A + I A~=A+I。这一步很重要,因为我们希望在聚合时保留节点自身的特征。否则,节点只能从邻居获取信息,而丢失了自身特征。
  • D ~ \tilde{D} D~ A ~ \tilde{A} A~的对角度矩阵(包含每个节点的连接数,包括自环)
  • W ( l ) W^{(l)} W(l):第 l l l层的可训练权重矩阵
  • σ \sigma σ:非线性激活函数(如ReLU)
关键操作:
  • 消息传递:
    • A ~ H ( l ) \tilde{A}H^{(l)} A~H(l):每个节点聚合其邻居的特征向量
    • 添加自环( A ~ = A + I \tilde{A} = A + I A~=A+I)确保节点在聚合时包含自身特征
  • 归一化:防止特征尺度在层间变化过大,通过节点度进行归一化有助于训练稳定性
    • D ~ − 1 / 2 A ~ D ~ − 1 / 2 \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} D~1/2A~D~1/2:这步称为对称归一化或重归一化技巧。
    • 如果没有归一化,具有许多连接(高度数)的节点在聚合后会有更大的特征值,这可能导致数值不稳定和训练困难。
    • D ~ \tilde{D} D~:度矩阵(对角矩阵,其中 D ~ i i = ∑ j A ~ i j \tilde{D}_{ii} = \sum_j \tilde{A}_{ij} D~ii=jA~ij
    • D ~ − 1 / 2 \tilde{D}^{-1/2} D~1/2:度矩阵的逆平方根
    • 左乘( D ~ − 1 / 2 A ~ \tilde{D}^{-1/2} \tilde{A} D~1/2A~):将每一行除以节点度数的平方根。这归一化了每个节点发出消息的影响。
    • 右乘( ⋅ D ~ − 1 / 2 \cdot \tilde{D}^{-1/2} D~1/2):将每一列除以节点度数的平方根。这归一化了每个节点接收消息的影响。

考虑一个简单的3节点图:

节点0连接到节点1
节点1连接到节点0和2
节点2连接到节点1

添加自环后:

A = [[1, 1, 0],
    [1, 1, 1],
    [0, 1, 1]]
    
D = [[2, 0, 0],
    [0, 3, 0],
    [0, 0, 2]]  # 度数:2, 3, 2

D^(-1/2) = [[1/√2, 0,     0   ],
           [0,    1/√3,  0   ],
           [0,    0,     1/√2]]

归一化后的矩阵为:

D^(-1/2)AD^(-1/2) = 
  [[1/2,   1/√6,    0   ],
  [1/√6,  1/3,    1/√6  ],
  [0,     1/√6,    1/2  ]]

在每一层,节点都会聚合来自其邻居(包括自身)的信息。网络越深,信息传播得越远。每个节点的新表示是其自身特征和邻居特征的加权平均。权重通过训练过程学习得到。归一化确保具有许多邻居的节点不会主导学习过程。

在社交网络中,每个人(节点)都有一些特征(如年龄、兴趣等),GCN层让每个人根据其朋友的信息更新自己的理解。归一化确保受欢迎的人(有很多朋友)不会主导学习过程。

在Cora数据集上实现节点分类的GCN

Cora数据集是一个引文网络,其中节点代表学术论文,边代表引用关系。每篇论文都有一组特征(如作者、标题、摘要)和一个标签(如论文主题)。总共有2,780篇论文(节点)和5,429条引用(边)。每篇论文由一个二进制词向量表示,表示1,433个唯一词典单词的存在(1)或不存在(0)。论文被分为7个类别(如神经网络、概率方法等)。目标是根据每篇论文的特征和引用关系预测其类别。

模型架构

GCN模型有2层:

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)  # 输入到隐藏层
        self.conv2 = GCNConv(16, dataset.num_classes)       # 隐藏层到输出

第一层GCN将输入特征(1,433维)降维到16维。第二层GCN将16维降维到7维(类别数)。

前向传播函数
def forward(self):
    x, edge_index = data.x, data.edge_index
    x = self.conv1(x, edge_index)  # 第一层GCN
    x = F.relu(x)                  # 非线性激活
    x = F.dropout(x, training=self.training)  # 可选的dropout
    x = self.conv2(x, edge_index)  # 第二层GCN
    return F.log_softmax(x, dim=1)  # 每个类别的对数概率

x = self.conv1(x, edge_index) 做了几件事:它向图中添加自环,计算归一化邻接矩阵 D ~ − 1 / 2 A ~ D ~ − 1 / 2 \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} D~1/2A~D~1/2,与输入特征和权重 H ( l ) W ( l ) H^{(l)} W^{(l)} H(l)W(l)相乘,并应用归一化和聚合。基本上,所有复杂的数学运算都由GCNConv层处理了。F.relu(x)应用ReLU激活函数,F.dropout(x, training=self.training)应用dropout来防止过拟合。第二层GCN x = self.conv2(x, edge_index) 做同样的事情,但是使用不同的权重 H ( l ) W ( l ) H^{(l)} W^{(l)} H(l)W(l)

训练过程
model = GCN()
data = dataset[0]  # 获取第一个图对象
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

我们使用带权重衰减的Adam优化器。Adam是一种自适应学习率优化算法,它结合了AdaGrad和RMSProp的优点。它维护每个参数的学习率,并使用梯度的移动平均和梯度平方的移动平均。由于稀疏梯度在GNN中很常见,使用Adam是合理的。

它有两个主要参数:lr是学习率,weight_decay是L2正则化参数。权重衰减通过向损失函数添加惩罚项来防止过拟合,并将模型权重推向较小的值,防止任何单个权重变得过大。使用L2时,原始损失 L ( θ ) L(\theta) L(θ)变为 L ( θ ) + λ ∑ θ i 2 L(\theta) + \lambda \sum \theta_i^2 L(θ)+λθi2,其中 λ \lambda λ是权重衰减参数。weight_decay=5e-4意味着 λ = 0.0005 \lambda = 0.0005 λ=0.0005。它通过保持权重较小来防止过拟合,并使模型对未见过的数据更具泛化能力。

loss = F.nll_loss(...)是负对数似然损失(NLL),通常用于分类任务。它衡量模型的预测概率与真实标签的匹配程度。对于单个样本,它表示为 − log ⁡ ( p 真实类别 ) -\log(p_{\text{真实类别}}) log(p真实类别)。如果模型对正确类别100%确信,则损失为0。data.train_mask是一个布尔掩码,指示哪些节点在训练集中。data.y是每个节点的标签。我们只使用train_mask为True的节点进行训练。val_mask用于验证的节点,test_mask用于最终评估的节点。

与许多图数据集一样,标签仅对节点的一个小子集可用,模型通过有监督损失从标记节点学习,并通过图结构从未标记节点学习。因此,这是半监督学习。在Cora数据集中,总共有2,708个节点,其中约140个节点(5%)用于训练,500个用于验证,1000个用于测试。GCN假设相连的节点可能相似。这被称为同质性假设,它被编码到学习算法中。GCN的消息传递直接编码了这些偏差。

模型评估
model.eval()
pred = model().argmax(dim=1)  # 获取预测类别
correct = pred[data.test_mask] == data.y[data.test_mask]
accuracy = int(correct.sum()) / int(data.test_mask.sum())

完整代码如下。首先,安装必要的包:

pip install torch-geometric

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# 加载数据
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# 定义GCN模型
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 训练循环
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model()
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')

# 评估
model.eval()
pred = model().argmax(dim=1)
correct = pred[data.test_mask] == data.y[data.test_mask]
accuracy = int(correct.sum()) / int(data.test_mask.sum())
print(f'测试准确率: {accuracy:.4f}')

运行结果:

Epoch 0, Loss: 1.9515
Epoch 20, Loss: 0.1116
Epoch 40, Loss: 0.0147
Epoch 60, Loss: 0.0142
Epoch 80, Loss: 0.0166
Epoch 100, Loss: 0.0155
Epoch 120, Loss: 0.0137
Epoch 140, Loss: 0.0124
Epoch 160, Loss: 0.0114
Epoch 180, Loss: 0.0107
测试准确率: 0.8100

我们可以看到,模型在只看到少量标记节点的情况下就能达到相当不错的准确率(81%)。这展示了图结构与节点特征结合的力量。在下一篇博客中,我们将介绍EvolveGCN,这是一个可以处理动态图数据的动态GCN模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2404721.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ES 学习总结一 基础内容

ElasticSearch学习 一、 初识ES1、 认识与安装2、 倒排索引2.1 正向索引2.2 倒排索引 3、 基本概念3.1 文档和字段3.2 索引和倒排 4 、 IK分词器 二、 操作1、 mapping 映射属性2、 索引库增删改查3、 文档的增删改查3.1 新增文档3.2 查询文档3.3 删除文档3.4 修改文档3.5 批处…

Maven 构建缓存与离线模式

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c=1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编程,高并发设计,Springboot和微服务,熟悉Linux,ESXI虚拟化以及云原生Docker和K8s,热衷于探…

基于51单片机的光强控制LED灯亮灭

目录 具体实现功能 设计介绍 资料内容 全部内容 资料获取 具体实现功能 具体功能: (1)按下按键K后光敏电阻进行光照检测,LCD1602显示光照强度值; (2)光照值小于15时,上面2个LE…

【Linux操作系统】基础开发工具(yum、vim、gcc/g++)

文章目录 Linux软件包管理器 - yumLinux下的三种安装方式什么是软件包认识Yum与RPMyum常用指令更新软件安装与卸载查找与搜索清理缓存与重建元数据 yum源更新1. 备份现有的 yum 源配置2. 下载新的 repo 文件3. 清理并重建缓存 Linux编辑器 - vim启动vimVim 的三种主要模式常用操…

【Survival Analysis】【机器学习】【3】 SHAP可解釋 AI

前言: SHAP(SHapley Additive explanations) 是一种基于博弈论的可解释工具。 现在很多高分的 论文里面都会带这种基于SHAP 分析的图,用于评估机器学习模型中特征对预测结果的贡献度. pip install -i https://pypi.tuna.tsinghua.edu.cn/sim…

ModuleNotFoundError No module named ‘torch_geometric‘未找到

ModuleNotFoundError: No module named torch_geometric’未找到 试了很多方法,都没成功,安装torch对应版本的torch_geometric都不行, 后来发现是pip被设置了环境变量,所有pip文件都给安装在了一个文件夹了 排查建议 1. 检查 p…

Cell-o1:强化学习训练LLM解决单细胞推理问题

细胞类型注释是分析scRNA-seq数据异质性的关键任务。尽管最近的基础模型实现了这一过程的自动化,但它们通常独立注释细胞,未考虑批次水平的细胞背景或提供解释性推理。相比之下,人类专家常基于领域知识为不同细胞簇注释不同的细胞类型。为模拟…

vue3: bingmap using typescript

项目结构&#xff1a; <template><div class"bing-map-market"><!-- 加载遮罩层 --><div class"loading-overlay" v-show"isLoading || errorMessage"><div class"spinner-container"><div class&qu…

超大规模芯片验证:基于AMD VP1902的S8-100原型验证系统实测性能翻倍

引言&#xff1a; 随着AI、HPC及超大规模芯片设计需求呈指数级增长原型验证平台已成为芯片设计流程中验证复杂架构、缩短迭代周期的核心工具。然而&#xff0c;传统原型验证系统受限于单芯片容量&#xff08;通常<5000万门&#xff09;、多芯片分割效率及系统级联能力&#…

【工作记录】接口功能测试总结

如何对1个接口进行接口测试 一、单接口功能测试 1、接口文档信息 理解接口文档的内容&#xff1a; 请求URL: https://[ip]:[port]/xxxserviceValidation 请求方法: POST 请求参数: serviceCode(必填), servicePsw(必填) 响应参数: status, token 2、编写测试用例 2.1 正…

Dubbo Logback 远程调用携带traceid

背景 A项目有调用B项目的服务&#xff0c;A项目使用 logback 且有 MDC 方式做 traceid&#xff0c;调用B项目的时候&#xff0c;traceid 没传递过期&#xff0c;导致有时候不好排查问题和链路追踪 准备工作 因为使用的是 alibaba 的 dubbo 所以需要加入单独的包 <depend…

NLP学习路线图(二十):FastText

在自然语言处理(NLP)领域,词向量(Word Embedding)是基石般的存在。它将离散的符号——词语——转化为连续的、富含语义信息的向量表示,使得计算机能够“理解”语言。而在众多词向量模型中,FastText 凭借其独特的设计理念和卓越性能,尤其是在处理形态丰富的语言和罕见词…

力扣面试150题--除法求值

Day 62 题目描述 做法 此题本质是一个图论问题&#xff0c;对于两个字母相除是否存在值&#xff0c;其实就是判断&#xff0c;从一个字母能否通过其他字母到达&#xff0c;做法如下&#xff1a; 遍历所有等式&#xff0c;为每个变量分配唯一的整数索引。初始化一个二维数组 …

美业破局:AI智能体如何用数据重塑战略决策(5/6)

摘要&#xff1a;文章深入剖析美业现状与挑战&#xff0c;指出其市场规模庞大但竞争激烈&#xff0c;面临获客难、成本高、服务标准化缺失等问题。随后阐述 AI 智能体与数据驱动决策的概念&#xff0c;强调其在美业管理中的重要性。接着详细说明 AI 智能体在美业数据收集、整理…

生成模型+两种机器学习范式

生成模型&#xff1a;从数据分布到样本创造 生成模型&#xff08;Generative Model&#xff09; 是机器学习中一类能够学习数据整体概率分布&#xff0c;并生成新样本的模型。其核心目标是建模输入数据 x 和标签 y 的联合概率分布 P(x,y)&#xff0c;即回答 “数据是如何产生的…

【学习笔记】Python金融基础

Python金融入门 1. 加载数据与可视化1.1. 加载数据1.2. 折线图1.3. 重采样1.4. K线图 / 蜡烛图1.5. 挑战1 2. 计算2.1. 收益 / 回报2.2. 绘制收益图2.3. 累积收益2.4. 波动率2.5. 挑战2 3. 滚动窗口3.1. 创建移动平均线3.2. 绘制移动平均线3.3 Challenge 4. 技术分析4.1. OBV4.…

A Execllent Software Project Review and Solutions

The Phoenix Projec: how do we produce software? how many steps? how many people? how much money? you will get it. i am a pretty judge of people…a prank

windows命令行面板升级Git版本

Date: 2025-06-05 11:41:56 author: lijianzhan Git 是一个 ‌分布式版本控制系统‌ (DVCS)&#xff0c;由 Linux 之父 Linus Torvalds 于 2005 年开发&#xff0c;用于管理 Linux 内核开发。它彻底改变了代码协作和版本管理的方式&#xff0c;现已成为软件开发的事实标准工具&…

大故障,阿里云核心域名疑似被劫持

2025年6月5日凌晨&#xff0c;阿里云多个服务突发异常&#xff0c;罪魁祸首居然是它自家的“核心域名”——aliyuncs.com。包括对象存储 OSS、内容分发 CDN、镜像仓库 ACR、云解析 DNS 等服务在内&#xff0c;全部受到波及&#xff0c;用户业务连夜“塌房”。 更让人惊讶的是&…

SQLMesh实战:用虚拟数据环境和自动化测试重新定义数据工程

在数据工程领域&#xff0c;软件工程实践&#xff08;如版本控制、测试、CI/CD&#xff09;的引入已成为趋势。尽管像 dbt 这样的工具已经推动了数据建模的标准化&#xff0c;但在测试自动化、工作流管理等方面仍存在不足。 SQLMesh 应运而生&#xff0c;旨在填补这些空白&…