Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py环境安装教程及图数据集制作

news2025/7/24 4:10:31

最近需要训练图卷积神经网络(Graph Convolution Neural Network, GCNN),在配置GCNN环境上总结了一些经验。

我觉得对于初学者而言,图神经网络的训练会有2个难点:

①环境配置

②数据集制作

一、环境配置

我最初光想到要给GCNN配环境就觉得有些困难,感觉相比于目标检测、分类识别这些任务用规则数据,图神经网络的模型、数据都是图,所以内心觉得会比较难。

我之前更有一个误区,就是觉得不规则结构的图数据不能用CUDA进行并行加速。实际上,图,在电脑里也是以张量这种规则结构数据存在的,完全能用CUDA进行加速计算,训练GCN前配置CUDA完全OK。


以下是我配置的环境,可用CUDA成功运行link_pred.py

几个关键包的版本:

torch                                2.4.1
torch-geometric               2.3.1
torchaudio                       2.4.1
torchvision                       0.14.0
torchviz                            0.0.2

pandas                             1.0.3

numpy                              1.20.0

 CUDA: 11.8

注意要先安装好CUDA,显示了:

 

再安装GPU版本的torch,不然python检测安装的是cpu版本的torch。这时,就得卸载重新安装了

环境配置成功:

print(torch.__version__)
print(torch.cuda.is_available())

如果CUDA环境安装失败,会打印:

2.4.1+cpu
False

其实只安装torch和CUDA还好,如果你的python中有numpy和pandas可能解决版本之间的冲突会耗费不少时间,我就是在numpy和pandas版本上试了很久,最终找到现在的版本是相互兼容的。

CUDA的版本切换可以参考我的另一篇博客:

CUDA版本切换

二、数据集制作

掌握图数据集制作的关键在于掌握slices切片:

for ...
    data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index,             edge_label=Edge_label)
    
    data_list.append(data)
data_, slices = self.collate(data_list)  # 将不同大小的图数据对齐,填充
torch.save((data_, slices), self.processed_paths[0])

和CNN不同的是,GCN没有样本维度,需要把所有样本拼成一张大图喂给GCN进行训练 

数据集生成代码:

#作者:zhouzhichao
#创建时间:2025/5/30
#内容:生成200个样本的PYG数据集


import h5py
import hdf5storage
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import negative_sampling

base_dir = "D:\\无线通信网络认知\\论文1\\experiment\\直推式拓扑推理实验\\拓扑生成\\200样本\\"

N = 30
grapg_size = N
train_n = 31
M = 3000


class graph_data(InMemoryDataset):
    def __init__(self, root, signals=None, tp_list = None, transform=None, pre_transform=None):
        # self.Signals = Signals
        # self.Tp_list = Tp_list
        self.signals = signals
        self.tp_list = tp_list
        super().__init__(root, transform, pre_transform)
        # self.data, self.slices = torch.load(self.processed_paths[0])

        self.data = torch.load(self.processed_paths[0])


    # 返回process方法所需的保存文件名。你之后保存的数据集名字和列表里的一致
    @property
    def processed_file_names(self):
        return ['gcn_data.pt']

    # 生成数据集所用的方法
    def process(self):
        # data_list = []
        # for k in range(200):
        # signals = self.Signals[:, :, k]
        # tp_list = np.array(mat_file[self.Tp_list[0, k]])
        signals = self.signals
        tp_list =self.tp_list
        # tp = Tp[:,:,k]

        X = torch.tensor(signals, dtype=torch.float)

        # 所有的边
        Edge_index = torch.tensor(tp_list, dtype=torch.long)

        # 所有的边1标签
        edge_label = np.ones((tp_list.shape[1]))
        # edge_label = np.zeros((tp_list.shape[1]))
        Edge_label = torch.tensor(edge_label, dtype=torch.float)

        neg_edge_index = negative_sampling(
            edge_index=Edge_index, num_nodes=grapg_size,
            num_neg_samples=Edge_index.shape[1], method='sparse')
        # 拼接正负样本索引

        # c = 0
        # for i in range(31):
        #     for i in range(31):
        #         if torch.equal(Edge_index[:, i], neg_edge_index[:, i]):
        #             c = c + 1
        # print("c: ",c)

        Edge_label_index = Edge_index
        perm = torch.randperm(Edge_index.size(1))
        Edge_index = Edge_index[:, perm]
        Edge_index = Edge_index[:, :train_n]

        Edge_label_index = torch.cat(
            [Edge_label_index, neg_edge_index],
            dim=-1,
        )
        # 拼接正负样本
        Edge_label = torch.cat([
            Edge_label,
            Edge_label.new_zeros(neg_edge_index.size(1))
        ], dim=0)
        # Edge_label = torch.cat([
        #     Edge_label,
        #     Edge_label.new_ones(neg_edge_index.size(1))
        # ], dim=0)

        data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index, edge_label=Edge_label)
        torch.save(data, self.processed_paths[0])
            # data_list.append(data)
        # data_, slices = self.collate(data_list)  # 将不同大小的图数据对齐,填充

        # torch.save((data_, slices), self.processed_paths[0])



for snr in [0,20,40]:

    print("snr: ", snr)

    mat_file = h5py.File(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')
    # mat_file = hdf5storage.loadmat(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')
    # 获取数据集
    Signals = mat_file["Signals"][()]
    # signals = np.swapaxes(signals, 1, 0)
    Tp = mat_file["Tp"][()]
    Tp_list = mat_file["Tp_list"][()]
    # tp_list = tp_list - 1
    # 关闭文件
    # mat_file.close()
    # graph_data("gcn_data")

    # n = Signals.shape[2]
    n = 10
    for i in range(n):
        signals = Signals[:,:,i]
        tp_list = np.array(mat_file[Tp_list[0, i]])
        root = "gcn_data-"+str(i)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)
        graph_data(root, signals = signals, tp_list = tp_list)
        print("")

print("...图数据生成完成...")

训练代码:

#作者:zhouzhichao
#创建时间:25年5月29日
#内容:统计图中有关系节点和无关系节点的GCN特征欧式距离



import sys
import torch
import random
import numpy as np
import pandas as pd
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
sys.path.append('D:\无线通信网络认知\论文1\experiment\直推式拓扑推理实验\GCN推理')
from gcn_dataset import graph_data
print(torch.__version__)
print(torch.cuda.is_available())

mode = "gcn"

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(Input_L, 1000)
        self.conv2 = GCNConv(1000, 20)

    def encode(self, x, edge_index):
        x1 = self.conv1(x, edge_index)
        x1_1 = x1.relu()
        x2 = self.conv2(x1_1, edge_index)
        x2_2 = x2.relu()
        return x2_2

    def decode(self, z, edge_label_index):
        # 节点和边都是矩阵,不同的计算方法致使:节点->节点,节点->边
        # nodes_relation = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
        # distances  = torch.norm(z[edge_label_index[0]] - z[edge_label_index[1]], dim=-1)
        distance_squared = torch.sum((z[edge_label_index[0]] - z[edge_label_index[1]]) ** 2, dim=-1)
        # print("distance_squared: ",distance_squared)
        return distance_squared

    def decode_all(self, z):
        prob_adj = z @ z.t()  # 得到所有边概率矩阵
        return (prob_adj > 0).nonzero(as_tuple=False).t()  # 返回概率大于0的边,以edge_index的形式

    @torch.no_grad()
    def test(self,input_data):
        model.eval()

        z = model.encode(input_data.x, input_data.edge_index)
        out = model.decode(z, input_data.edge_label_index).view(-1)
        out = 1 - out


N = 30
train_n = 31
M = 3000
# snr = -20
# for train_n in range(1,51):
# for M in range(3000, 499, -100):
for snr in [0,20,40]:

    print("snr: ", snr)

    for I in range(10):
        root = "gcn_data-"+str(I)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)
        gcn_data = graph_data(root)

        Input_L = gcn_data.x.shape[1]

        model = Net()
        # model = Net().to(device)
        optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
        criterion = torch.nn.BCEWithLogitsLoss()


        def train():
            model.train()
            optimizer.zero_grad()
            z = model.encode(gcn_data.x, gcn_data.edge_index)
            # out = model.decode(z, train_data.edge_label_index).view(-1).sigmoid()
            out = model.decode(z, gcn_data.edge_label_index).view(-1)
            out = 1 - out
            loss = criterion(out, gcn_data.edge_label)
            loss.backward()
            optimizer.step()
            return loss

        min_loss = 99999
        count = 0#早停
        for epoch in range(10000):
            loss = train()
            if loss<min_loss:
                min_loss = loss
                count = 0
            count = count + 1
            if count>100:
                break
            print("epoch:  ",epoch,"   loss: ",round(loss.item(),2), "   min_loss: ",round(min_loss.item(),2))


        z = model.encode(gcn_data.x, gcn_data.edge_index)
        out = model.decode(z, gcn_data.edge_label_index).view(-1)

        list_0 = []
        list_1 = []

        for i in range(len(gcn_data.edge_label)):
            true_label = gcn_data.edge_label[i].item()
            euclidean_distance_value = out[i].item()
            if true_label==1:
                list_1.append(euclidean_distance_value)
            if true_label==0:
                list_0.append(euclidean_distance_value)

        minlength = min(len(list_1), len(list_0))

        list_1 = random.sample(list_1, minlength)
        list_0 = random.sample(list_0, minlength)

        value = list_1 + list_0
        large_class = list(np.full(len(value), snr))
        small_class = list(np.full(len(list_1), 1)) + list(np.full(len(list_0), 0))


        data = {
            'large_class': large_class,
            'small_class': small_class,
            'value': value
        }

        # 创建一个 DataFrame
        df = pd.DataFrame(data)
        #
        # # 保存到 Excel 文件
        file_path = 'D:\无线通信网络认知\论文1\大修意见\图聚类、阈值相似性图实验补充\\' + mode + '_similarity_' + str(snr) + 'db_'+str(I)+'.xlsx'
        df.to_excel(file_path, index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2397452.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

React---day6、7

6、组件之间进行数据传递 **6.1 父传子&#xff1a;**props传递属性 父组件&#xff1a; <div><ChildCpn name"蒋乙菥" age"18" height"1,88" /> </div>子组件&#xff1a; export class ChildCpn extends React.Component…

hook组件-useEffect、useRef

hook组件-useEffect、useRef useEffect 用法及执行机制 WillMount -> render -> DidMount ShouldUpdate -> WillUpdate -> render -> DidUpdate WillUnmount(只有这个安全) WillReceiveProps useEffect(callback) 默认所有依赖都更新useEffect(callback, [])&am…

随机游动算法解决kSAT问题

input&#xff1a;n个变量的k-CNF公式 ouput&#xff1a;该公式的一组满足赋值或宣布没有满足赋值 算法步骤&#xff1a; 随机均匀地初始化赋值 a ∈ { 0 , 1 } n a\in\{0,1\}^n a∈{0,1}n.重复t次&#xff08;后面会估计这个t&#xff09;&#xff1a; a. 如果在当前赋值下…

《Discuz! X3.5开发从入门到生态共建》第1章 Discuz! 的前世今生-优雅草卓伊凡

《Discuz! X3.5开发从入门到生态共建》第1章 Discuz! 的前世今生-优雅草卓伊凡 第一节 从康盛创想到腾讯收购&#xff1a;PC时代的辉煌 1.1 Discuz! 的诞生&#xff1a;康盛创想的开源梦想 2001年&#xff0c;中国互联网正处于萌芽阶段&#xff0c;个人网站和论坛开始兴起。…

笔试强训:Day6

一、小红的口罩&#xff08;贪心优先级队列&#xff09; 登录—专业IT笔试面试备考平台_牛客网 #include<iostream> #include<queue> #include<vector> using namespace std; int n,k; int main(){//用一个小根堆 每次使用不舒适度最小的cin>>n>&…

谷歌Stitch:AI赋能UI设计,免费高效新利器

在AI技术日新月异的今天&#xff0c;各大科技巨头都在不断刷新我们对智能工具的认知。最近&#xff0c;谷歌在其年度I/O开发者大会期间&#xff0c;除了那些聚光灯下的重磅发布&#xff0c;还悄然上线了一款令人惊喜的AI工具——Stitch。这是一款全新的、完全免费的AI驱动UI&am…

运营商地址和ip属地一样吗?怎么样更改ip属地地址

‌在互联网时代&#xff0c;IP属地和运营商地址是两个经常被提及的概念&#xff0c;但它们是否相同&#xff1f;如何更改IP属地地址&#xff1f;这些问题困扰着许多网民。本文将深入探讨这两个概念的区别&#xff0c;并详细介绍更改IP属地地址的方法。 一、运营商地址和IP属地一…

在QT中,利用charts库绘制FFT图形

第1章 添加charts库 1.1 .pro工程添加chart库 1.1.1 在.pro工程里面添加charts库 1.1.2 在需要使用的地方添加这两个库函数&#xff0c;顺序一点不要搞错&#xff0c;先添加.pro&#xff0c;否则编译器会找不到这两个.h文件。 第2章 Charts关键绘图函数 2.1 QChart 类 QChart 是…

流媒体协议分析:流媒体传输的基石

在流媒体传输过程中&#xff0c;协议的选择至关重要&#xff0c;它决定了数据如何封装、传输和解析&#xff0c;直接影响着视频的播放质量和用户体验。本文将深入分析几种常见的流媒体传输协议&#xff0c;探讨它们的特点、应用场景及优缺点。 协议分类概述 流媒体传输协议根据…

vscode中让文件夹一直保持展开不折叠

vscode中让文件夹一直保持展开不折叠 问题 很多小伙伴使用vscode发现空文件夹会折叠显示, 让人看起来非常难受, 如下图 解决办法 首先打开设置->setting, 搜索compact Folders, 去掉勾选即可, 如下图所示 效果如下 看起来非常爽 ! ! !

JAVA-springboot整合Mybatis

SpringBoot从入门到精通-第15章 MyBatis框架 学习MyBatis心路历程 2022年学习java基础时候&#xff0c;想着怎么使用java代码操作数据库&#xff0c;咨询了项目上开发W同事&#xff0c;没有引用框架&#xff0c;操作数据库很麻烦&#xff0c;就帮我写好多行代码&#xff0c;就…

深度学习pycharm debug

深度学习中&#xff0c;Debug 是定位并解决代码逻辑错误&#xff08;如张量维度不匹配&#xff09;、训练异常&#xff08;如 Loss 波动&#xff09;、数据问题&#xff08;如标签错误&#xff09;的关键手段&#xff0c;通过打印维度、可视化梯度等方法确保模型正常运行、优化…

MicroPython+L298N+ESP32控制电机转速

要使用MicroPython控制L298N电机驱动板来控制电机的转速&#xff0c;你可以通过PWM&#xff08;脉冲宽度调制&#xff09;信号来调节电机速度。L298N是一个双H桥驱动器&#xff0c;可以同时控制两个电机的正反转和速度。 硬件准备&#xff1a; 1. L298N 电机控制板 2. ESP32…

在部署了一台mysql5.7的机器上部署mysql8.0.35

在已部署 MySQL 5.7 的机器上部署 MySQL 8.0.35 的完整指南 在同一台服务器上部署多个 MySQL 版本需要谨慎规划&#xff0c;避免端口冲突和数据混淆。以下是详细的部署步骤&#xff1a; 一、规划配置 端口分配 MySQL 5.7&#xff1a;使用默认端口 3306MySQL 8.0.35&#xff1…

QT入门学习(一)---新建工程与、信号与槽

一: 新建QT项目 二:QT文件构成 2.1 first.pro 项目管理文件&#xff0c;下面来看代码解析 QT core guigreaterThan(QT_MAJOR_VERSION, 4): QT widgetsCONFIG c11TARGET main# The following define makes your compiler emit warnings if you use # any Qt feature …

UE5.4.4+Rider2024.3.7开发环境配置

文章目录 一、UE5安装 安装有两种方式一种的源码编译安装、一种是EPIC安装&#xff0c;推荐后者&#xff0c;只需要注册一个EPIC账号就可以一键安装。 二、C环境安装 1.下载VisualStudioSetup 下载链接如下下载 Visual Studio Tools - 免费安装 Windows、Mac、Linux 选择社…

Windows环境下PHP,在PowerShell控制台输出中文乱码

解决方法&#xff1a; 以管理员运行PowerShell , 输入&#xff1a; chcp 65001 重启控制台&#xff1b;然后就正常输出中文&#xff1b;

性能优化 - 理论篇:性能优化的七类技术手段

文章目录 Pre引言性能优化的七类技术手段性能优化策略一览表1. 复用优化2. 计算优化2.1 并行执行2.2 变同步为异步2.3 惰性加载 3. 结果集优化3.1 数据格式与协议选择3.2 字段精简与按需返回3.3 批量处理与分页3.4 索引与位图加速 4. 资源冲突优化4.1 锁的分类与特点4.2 无锁与…

华为IP(7)

端口隔离技术 产生的背景 1.以太交换网络中为了实现报文之间的二层隔离&#xff0c;用户通常将不同的端口加入不同的VLAN&#xff0c;实现二层广播域的隔离。 2.大型网络中&#xff0c;业务需求种类繁多&#xff0c;只通过VLAN实现二层隔离&#xff0c;会浪费有限的VLAN资源…

AIGC与影视制作:技术革命、产业重构与未来图景

文章目录 一、AIGC技术全景&#xff1a;从算法突破到产业赋能1. **技术底座&#xff1a;多模态大模型的进化路径**2. **核心算法&#xff1a;从生成对抗网络到扩散模型的迭代** 二、AIGC在影视制作全流程中的深度应用1. **剧本创作&#xff1a;从“灵感枯竭”到“创意井喷”**2…