【模型学习之路】PyG的使用+基于点的任务

news2025/7/19 14:13:13

这一篇是关于PyG的基本使用

目录

前言

PyG的数据结构

演示

图的可视化

基于点的任务

任务分析

MLP

GCN


前言

对图结构感兴趣的朋友可以学一下常用的有关图结构的库:networkx详细介绍 `networkx` 库,探讨它的基本功能、如何创建图、操作图以及其常用参数。-CSDN博客

PyG零基础的朋友可以看一下这个视频的11~14集1-PyTorch Geometric工具包安装与配置方法_哔哩哔哩_bilibili

PyG的数据结构

演示

我们用一个简单的数据集作为演示。

import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub
dataset = KarateClub()  # 这是一个数据集,里面只有一个图
len(dataset)

# output
1

简单介绍一下每个数据的维度信息。

data = dataset[0]
print(data)

# output
Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
  • x: 节点特征 [m, f] (m: 节点数,f: 特征数)。

  • edge_index: 边的索引 [2, e] (e: 边数)。可以看作是e条边,两两相连,然后转置了。

  • y: 标签 [m] (m: 节点数)。自然可以做成多输出的,那么维度就会是[m, n_tasks]

  • mask: [m] 一个相对玄学一点的东西,之后在不同场景中介绍

图的可视化

可以利用networks做可视化

import networkx as nx
from matplotlib import pyplot as plt
from torch_geometric.utils import to_networkx
from visualize import visualize_graph

def visualize_graph(G, color):
    plt.figure(figsize=(5, 5))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), node_color=color,
                        with_labels=False, node_size=100, cmap='Set2')

G = to_networkx(data)
visualize_graph(G, color=data.y)

在PyG中,邻接矩阵edge_index是按照稀疏矩阵的方式存的,我们可以把转化为我们平时之前用的密集矩阵。

# 转化为密集型的邻接矩阵
G_adj = nx.to_numpy_array(G)
print(G_adj)

# output
[[0. 1. 1. ... 1. 0. 0.]
 [1. 0. 1. ... 0. 0. 0.]
 [1. 1. 0. ... 0. 1. 0.]
 ...
 [1. 0. 0. ... 0. 1. 1.]
 [0. 0. 1. ... 1. 0. 1.]
 [0. 0. 0. ... 1. 1. 0.]]

顺便画个热力图

import seaborn as sns
sns.heatmap(G_adj, cmap='Blues')

顺便把邻接矩阵和度矩阵都画一下

import numpy as np
adj = G_adj
d = np.diag(np.sum(adj, axis=1))
adj = adj + np.eye(adj.shape[0])
d = d + np.eye(adj.shape[0])

plt.figure(figsize=(10, 4))
plt.subplot(121)
sns.heatmap(adj, cmap='Blues')
plt.subplot(122)
sns.heatmap(d, cmap='Blues')
plt.show()

基于点的任务

任务分析

先像前言中教程一样,先拿到数据集,然后做一些分析。

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())

print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

#output
Number of graphs: 1
Number of features: 1433
Number of classes: 7


# 就一张图,直接取出来便是
data = dataset[0]
print(data)
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Training node label ratio: {data.train_mask.sum().item() / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

#output
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Training node label ratio: 0.05
Has isolated nodes: False
Has self-loops: False
Is undirected: True

上面的很多指标意思都很明显,就不过多解释了。

重点是,基于点的任务到底要干什么?

基于点的任务,往往会约束在同一张图中。这张图有很多节点,如果它们有标签值,就可以划分到训练集(train_mask为True的点)、验证集(valid_mask为True的点)和测试集(test_mask为True的点),用于模型的训练与预测。

我们的目标是,对于一些没有标签值的点,我们就用训练好的模型去预测它们。

所以,其实这个任务非常加了邻接矩阵的MLP。我们将两者对比一下。

MLP

模型定义很简单,这是一个7分类问题。

import torch
import torch.nn as nn

class MLP(torch.nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        torch.manual_seed(666)
        self.fc = nn.Sequential(
            nn.Linear(1433, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 7)
        )
        
    def forward(self, x):
        return self.fc(x)

训练的时候,注意一下细节:

1. 这个数据集用train_mask和test_mask来划分训练集和测试集。在training时,只用在train_mask为True的上面计算损失即可。同样在计算testing的acc时,也只需要用到test_mask为True的即可。
2. x的shape为[2708, 1433],这里整个训练采用的是最传统的方式,即整个数据集不划分batch(或者说整个数据集就是一个batch)

model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

def test():
    model.eval()
    out = model(data.x)
    pred = out.argmax(dim=1)  # [2708, 7] -> [2708]
    acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return acc

mlp_loss_lst = []
mlp_acc_lst = []
for epoch in range(1, 201):
    """
    input: [2708, 1433]
    output: [2708, 7]
    """
    loss = train()
    acc = test()
    mlp_loss_lst.append(loss)
    mlp_acc_lst.append(acc)

GCN

先来看看模型。

PyG将图神经网络里面各种经典的架构基本都有实现,这里的GCNConv直接调库,原理的话我们上一个专栏详细说过了:【模型学习之路】手写+分析GAT_手写gat-CSDN博客

仔细观察,GCNConv在维度上的表现简直就跟nn.Linear一模一样!

当然,内部自然要复杂的多。此外,它在调用时还要有edge_index,格式和前面的data.edge_index是统一的。

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(666)
        self.conv1 = GCNConv(1433, 16)
        self.conv2 = GCNConv(16, 7)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)  # [2708, 1433] -> [2708, 16]
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)  # [2708, 16] -> [2708, 7]
        return x

训练过程大差不差。

model = GCN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # [2708, 7] -> [2708]
    acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return acc

gcn_loss_lst = []
gcn_acc_lst = []
for epoch in range(1, 201):
    """
    input: [2708, 1433]
    output: [2708, 7]
    """
    loss = train()
    acc = test()
    gcn_loss_lst.append(loss)
    gcn_acc_lst.append(acc)

两者的对比也是挺明显的:

plt.figure(figsize=(10, 5))
plt.plot(mlp_acc_lst, label='MLP')
plt.plot(gcn_acc_lst, label='GCN')
plt.ylim(0, 1)
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2248948.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何监控Elasticsearch集群状态?

大家好,我是锋哥。今天分享关于【如何监控Elasticsearch集群状态?】面试题。希望对大家有帮助; 如何监控Elasticsearch集群状态? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 监控 Elasticsearch 集群的状态对于确保…

Edify 3D: Scalable High-Quality 3D Asset Generation

Deep Imagination Research | NVIDIA 目录 一、Abstract 二、核心内容 1、多视图扩散模型 3、重建模型: 4、数据处理模块: 三、结果 1、文本到 3D 生成结果 2、图像到 3D 生成结果 3、四边形网格拓扑结构 一、Abstract NVIDIA 开发的用于高质量…

QUAD-MxFE平台

QUAD-MxFE平台 16Tx/16Rx直接L/S/C频段采样相控阵/雷达/电子战/卫星通信开发平台 概览 优势和特点 四通道MxFE数字化处理卡 使用MxFE的多通道、宽带系统开发平台 与Xilinx VCU118评估板(不包括)搭配使用 16个RF接收(Rx)通道(32个数字Rx通道…

操作系统 锁——针对实习面试

目录 操作系统 锁什么是死锁?说说死锁产生的条件?死锁如何预防?死锁如何避免?银行家算法具体怎么操作?死锁如何解决?死锁会产生什么影响?乐观锁与悲观锁有什么区别? 操作系统 锁 什么…

UI设计-色彩、层级、字体、边距(一)

一.色彩:色彩可以影响人的心理与行动,具有不同的象征意义;有冷暖,轻重,软硬等等。 1.色彩情绪:最直观的视觉感受 一个活动的页面所用的颜色必须要与其内容相适应,让人看起来舒服。有时我们会不…

从入门到精通数据结构----四大排序(上)

目录 首言: 1. 插入排序 1.1 直接插入排序 1.2 希尔排序 2. 选择排序 2.1 直接选择排序 2.2 堆排序 3. 交换排序 3.1 冒泡排序 3.2 快排 结尾: 首言: 本篇文章主要介绍常见的四大排序:交换排序、选择排序、插入排序、归并排…

【C++第三方库】Muduo库结合ProtoBuf库搭建服务端和客户端的过程和源码

每日激励:“不设限和自我肯定的心态:I can do all things。 — Stephen Curry” 绪论​: 本章我将结合之前的这俩个第三方库快速上手protobuf序列化和反序列化框架和muduo网络,来去实现muduo库在protocol协议搭建服务端和客户端。…

Scala—Map用法详解

Scala—Map用法详解 在 Scala 中,Map 是一种键值对的集合,其中每个键都是唯一的。Scala 提供了两种类型的 Map:不可变 Map 和可变 Map。 1. 不可变集合(Map) 不可变 Map 是默认的 Map 实现,位于 scala.co…

文本处理之sed

1、概述 sed是文本编辑器,作用是对文本的内容进行增删改查。 和vim不一样,sed是按行进行处理。 sed一次处理一行内容,处理完一行之后紧接着处理下一行,一直到文件的末尾 模式空间:临时储存,修改的结果临…

了解网络威胁情报:全面概述

网络威胁情报 CTI 是指系统地收集和分析与威胁相关的数据,以提供可操作的见解,从而增强组织的网络安全防御和决策过程。 在数字威胁不断演变的时代,了解网络威胁情报对于组织来说至关重要。复杂网络攻击的兴起凸显了制定强有力的策略以保护敏…

Python 海龟绘图 turtle 的介绍

python的计算生态中包含标准库和第三方库 标准库:随着解释器直接安装到操作系统中的功能模块 第三方库:需要经过安装才能使用的功能模块 库Library 包 Package 模块Module 统称为模块 turtle 是一个图形绘制的函数库,是标准库&#…

学习日志017--python的几种排序算法

冒泡排序 def bubble_sort(alist):i 0while i<len(alist):j0while j<len(alist)-1:if alist[j]>alist[j1]:alist[j],alist[j1] alist[j1],alist[j]j1i1l [2,4,6,8,0,1,3,5,7,9] bubble_sort(l) print(l) 选择排序 def select_sort(alist):i 0while i<len(al…

java集合及源码

目录 一.集合框架概述 1.1集合和数组 数组 集合 1.2Java集合框架体系 常用 二. Collection中的常用方法 添加 判断 删除 其它 集合与数组的相互转换 三Iterator(迭代器)接口 3.0源码 3.1作用及格式 3.2原理 3.3注意 3.4获取迭代器(Iterator)对象 3.5. 实现…

⭐️ GitHub Star 数量前十的工作流项目

文章开始前&#xff0c;我们先做个小调查&#xff1a;在日常工作中&#xff0c;你会使用自动化工作流工具吗&#xff1f;&#x1f64b; 事实上&#xff0c;工作流工具已经变成了提升效率的关键。其实在此之前我们已经写过一篇博客&#xff0c;跟大家分享五个好用的工作流工具。…

【Jenkins】自动化部署 maven 项目笔记

文章目录 前言1. Jenkins 新增 Maven 项目2. Jenkins 配置 Github 信息3. Jenkins 清理 Workspace4. Jenkins 配置 后置Shell脚本后记 前言 目标&#xff1a;自动化部署自己的github项目 过程&#xff1a;jenkins 配置、 shell 脚本积累 相关连接 Jenkins 官方 docker 指导d…

杂7杂8学一点之多普勒效应

最重要的放在最前面&#xff0c;本文学习资料&#xff1a;B站介绍多普勒效应的优秀视频。如果上学时老师这么讲课&#xff0c;我估计会爱上上课。 目录 1. 多普勒效应 2. 多普勒效应对通信的影响 3. 多普勒效应对低轨卫星通信的影响 1. 多普勒效应 一个小石头扔进平静的湖面…

【python数据结构算法】排序算法 #冒泡 #选择排序 #快排 #插入排序

思维导图 一、经典冒泡 冒泡排序&#xff1a;是一种简单的排序算法&#xff0c;它重复的遍历要排序的序列&#xff0c;一次比较两个元素&#xff0c;如果他们的顺序错误&#xff0c;就把他们交换过来。 冒泡排序算法的运作如下&#xff1a; 比较相邻的元素。如果第一个比第二…

Linux系统之fuser命令的基本使用

Linux系统之fuser命令的基本使用 一、fuser命令介绍二、fuser命令使用帮助2.1 help帮助信息2.1 基本语法①通用选项②文件/设备相关选项③网络相关选项④进程操作选项⑤其他选项 三、fuser命令的基本使用3.1 查找挂载点的进程3.2 查看指定设备进程信息3.3 查找监听特定端口的进…

stable Diffusion官方模型下载

v2-1_768-ema-pruned.safetensors 下载地址&#xff1a; https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main 下载完成后&#xff0c;放到&#xff1a;E:\AITOOLS\stable-diffusion-webui\models\Stable-diffusion 模型&#xff1a;sd_xl_base_1.0.safetens…

《并查集算法详解及实用模板》

《重生我要成为并查集高手&#x1f354;&#x1f354;&#x1f354;》 并查集&#xff1a;快速查询和快速合并&#xff0c; 路径压缩&#xff0c; 按大小&#xff0c;高度&#xff0c;秩合并。 静态数组实现 &#x1f607;前言 在数据的海洋中&#xff0c;有一种悄然流淌的力量…