PyTorch搭建基于图神经网络(GCN)的天气推荐系统(附源码和数据集)

news2025/7/22 22:53:35

需要源码和数据集请点赞关注收藏后评论区留言~~~

一、背景

极端天气情况一直困扰着人们的工作和生活。部分企业或者工种对极端天气的要求不同,但是目前主流的天气推荐系统是直接将天气信息推送给全部用户。这意味着重要的天气信息在用户手上得不到筛选,降低用户的满意度,甚至导致用户的经济损失。我们计划开发一个基于图神经网络的天气靶向模型,根据用户的历史交互行为,判断不同天气对他的利害程度。如果有必要,则将该极端天气情况推送给该用户,让其有时间做好应对准备。该模型能够减少不必要的信息传递,提高用户的体验感。

二、模型介绍

四、模型介绍

(一)数据集共有三个txt文件,分别是user.txt,weather.txt,rating.txt。这些文件一共包含900名用户,1600个天气状况,95964条用户的历史交互记录。

  1. user.txt

用户的信息记录在user.txt中。格式如下:

用户ID\t年龄\t性别\t职业\t地理位置

  1. weather.txt

天气的信息记录在weather.txt中。格式如下:

天气ID\t天气类型\t温度\t湿度\t风速 

  1. rating.txt

用户的历史交互记录在rating.txt中。格式如下:

用户ID\t天气ID\t评分

三、项目结构

如下图 data里面存放了数据集

四、运行结果

开始训练  可以看到第一行显示了一些训练的基本配置内容 包括用的设备cpu 训练批次 学习率等等

 可以看出随着训练次数的增加 损失率在不断降低

最后会自动选出一个最佳的测试和训练集的损失值

 结果可视化如下

 五、代码

部分源码如下

train类

import pandas as pd
import time
from utils import fix_seed_torch, draw_loss_pic
import argparse
from model import GCN
from Logger import Logger
from mydataset import MyDataset
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
import sys
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'


# 固定随机数种子
fix_seed_torch(seed=2021)
# 设置训练的超参数
parser = argparse.ArgumentParser()
parser.add_argument('--gcn_layers', type=int, default=2, help='the number of gcn layers')
parser.add_argument('--n_epochs', type=int, default=20, help='the number of epochs')
parser.add_argument('--embedSize', type=int, default=64, help='dimension of user and entity embeddings')
parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--ratio', type=float, default=0.8, help='size of training dataset')
args = parser.parse_args()
# 设备是否支持cuda
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
args.device = device
# 读取用户特征、天气特征、评分
user_feature = pd.read_csv('./data/user.txt', encoding='utf-8', sep='\t')
item_feature = pd.read_csv('./data/weather.txt', encoding='utf-8', sep='\t')
rating = pd.read_csv('./data/rating.txt', encoding='utf-8', sep='\t')
# 构建数据集
dataset = MyDataset(rating)
trainLen = int(args.ratio * len(dataset))
train, test = random_split(dataset, [trainLen, len(dataset) - trainLen])
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test, batch_size=len(test))
# 记录训练的超参数
start_time = '{}'.format(time.strftime("%m-%d-%H-%M", time.localtime()))
logger = Logger('./log/log-{}.txt'.format(start_time))
logger.info(' '.join('%s: %s' % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
# 定义模型
model = GCN(args, user_feature, item_feature, rating)
model.to(device)
# 定义优化器
optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=0.001)
# 定义损失函数
loss_function = MSELoss()
train_result = []
test_result = []
# 最好的epoch
best_loss = sys.float_info.max
# 训练
for i in range(args.n_epochs):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        prediction=model(batch[0].to(device),batch[1].to(device))
        train_loss=torch.sqrt(loss_function(batch[2].float().to(device),prediction))
        train_loss.backward()
        optimizer.step()
    train_result.append(train_loss.item())
    model.eval()
    for data in test_loader:
        prediction=model(data[0].to(device),data[1].to(device))
        test_loss=torch.sqrt(loss_function(data[2].float().to(device),prediction))
        test_loss=test_loss.item()
        if best_loss>test_loss:
            best_loss=test_loss
            torch.save(model.state_dict(),'./model/bestModeParms-{}.pth'.format(start_time))
    test_result.append(test_loss)
    logger.info("Epoch{:d}:trainLoss{:.4f},testLoss{:.4f}".format(i,train_loss,test_loss))
else:
    model.load_state_dict(torch.load("./model/bestModeParms-11-18-19-47.pth"))
    user_id=input("请输入用户id")
    item_num=rating['itemId'].max()+1
    u=torch.tensor([int(user_id)for i in range(item_num)],dtype=float)
  气ID".format(user_id))
    print(i[0]for i in result)





# 画图
draw_loss_pic(train_result, test_result)

Logger类

import sys
import os
import logging


class Logger(object):

    def __init__(self, filename):

        self.logger = logging.getLogger(filename)
        self.logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s: %(message)s',
                                      datefmt='%Y-%m-%d %H-%M-%S')

        # write into file
        fh = logging.FileHandler(filename)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)

        # show on console
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)

        # add to Handler
        self.logger.addHandler(fh)
        self.logger.addHandler(ch)


    def _flush(self):
        for handler in self.logger.handlers:
            handler.flush()

    def info(self, message):
        self.logger.info(message)
        self._flush()

model类

import numpy as np
import torch.nn
import torch.nn as nn
from utils import *
from torch.nn import Module
import scipy.sparse as sp


class GCN_Layer(Module):
    def __init__(self,inF,outF):
        super(GCN_Layer,self).__init__()
        self.W1=torch.nn.Linear(in_features=inF,out_features=outF)
        self.W2=torch.nn.Linear(in_features=inF,out_features=outF)
    def forward(self,graph,selfLoop,features):
        part1=self.W1(torch.sparse.mm(graph+selfLoop,features))
        part2 = self.W2(torch.mul(torch.sparse.mm(graph,features),features))
        return nn.LeakyReLU()(part1+part2)



    ######################
    #    请你补充代码       #
    ######################


class GCN(Module):
    def __init__(self, args, user_feature, item_feature, rating):
        super(GCN, self).__init__()
        self.args = args
        self.device = args.device
        self.user_feature = user_feature
        self.item_feature = item_feature
        self.rating = rating
        self.num_user = rating['user_id'].max() + 1
        self.num_item = rating['item_id'].max() + 1
        # user embedding
        self.user_id_embedding = nn.Embedding(user_feature['id'].max() + 1, 32)
        self.user_age_embedding = nn.Embedding(user_feature['age'].max() + 1, 4)
        self.user_gender_embedding = nn.Embedding(user_feature['gender'].max() + 1, 2)
        self.user_occupation_embedding = nn.Embedding(user_feature['occupation'].max() + 1, 8)
        self.user_location_embedding = nn.Embedding(user_feature['location'].max() + 1, 18)
        # item embedding
        self.item_id_embedding = nn.Embedding(item_feature['id'].max() + 1, 32)
        self.item_type_embedding = nn.Embedding(item_feature['type'].max() + 1, 8)
        self.item_temperature_embedding = nn.Embedding(item_feature['temperature'].max() + 1, 8)
        self.item_humidity_embedding = nn.Embedding(item_feature['humidity'].max() + 1, 8)
        self.item_windSpeed_embedding = nn.Embedding(item_feature['windSpeed'].max() + 1, 8)
        # 自循环
        self.selfLoop = self.getSelfLoop(self.num_user + self.num_item)
        # 堆叠GCN层
        self.GCN_Layers = torch.nn.ModuleList()
        for _ in range(self.args.gcn_layers):
            self.GCN_Layers.append(GCN_Layer(self.args.embedSize, self.args.embedSize))
        self.graph = self.buildGraph()
        self.transForm = nn.Linear(in_features=self.args.embedSize * (self.args.gcn_layers + 1),
                                   out_features=self.args.embedSize)

    def getSelfLoop(self, num):
        i = torch.LongTensor(
            [[k for k in range(0, num)], [j for j in range(0, num)]])
        val = torch.FloatTensor([1] * num)
        return torch.sparse.FloatTensor(i, val).to(self.device)

    def buildGraph(self):
        rating=self.rating.values
        graph=sp.coo_matrix(
            (rating[:,2],(rating[:,0],rating[:,1])),shape=(self.num_user,self.num_item)).tocsr()
        graph=sp.bmat([[sp.csr_matrix((graph.shape[0],graph.shape[0])),graph],
                       [graph.T,sp.csr_matrix((graph.shape[1],graph.shape[1]))]])

        row_sum_sqrt=sp.diags(1/(np.sqrt(graph.sum(axis=1).A.ravel())+1e-8))
        col_sum_sqrt = sp.diags(1 / (np.sqrt(graph.sum(axis=0).A.ravel()) + 1e-8))
        graph=row_sum_sqrt@graph@col_sum_sqrt
        graph=graph.tocoo()
        values=graph.data
        indices=np.vstack((graph.row,graph.col))
        graph=torch.sparse.FloatTensor(torch.LongTensor(indices),torch.FloatTensor(values),torch.Size(graph.shape))
        return graph.to(self.device)
        ######################
        #    请你补充代码       #
        ######################

    def getFeature(self):
        # 根据用户特征获取对应的embedding
        user_id = self.user_id_embedding(torch.tensor(self.user_feature['id']).to(self.device))
        age = self.user_age_embedding(torch.tensor(self.user_feature['age']).to(self.device))
        gender = self.user_gender_embedding(torch.tensor(self.user_feature['gender']).to(self.device))
        occupation = self.user_occupation_embedding(torch.tensor(self.user_feature['occupation']).to(self.device))
        location = self.user_location_embedding(torch.tensor(self.user_feature['location']).to(self.device))
        user_emb = torch.cat((user_id, age, gender, occupation, location), dim=1)
        # 根据天气特征获取对应的embedding
        item_id = self.item_id_embedding(torch.tensor(self.item_feature['id']).to(self.device))
        item_type = self.item_type_embedding(torch.tensor(self.item_feature['type']).to(self.device))
        temperature = self.item_temperature_embedding(torch.tensor(self.item_feature['temperature']).to(self.device))
        humidity = self.item_humidity_embedding(torch.tensor(self.item_feature['humidity']).to(self.device))
        windSpeed = self.item_windSpeed_embedding(torch.tensor(self.item_feature['windSpeed']).to(self.device))
        item_emb = torch.cat((item_id, item_type, temperature, humidity, windSpeed), dim=1)
        # 拼接到一起
        concat_emb = torch.cat([user_emb, item_emb], dim=0)
        return concat_emb.to(self.device)

    def forward(self, users, items):
        features=self.getFeature()
        final_emb=features.clone()
        for GCN_Layer in self.GCN_Layers:
            features=GCN_Layer(self.graph,self.selfLoop,features)
            final_emb=torch.cat((final_emb,features.clone()),dim=1)
        user_emb,item_emb=torch.split(final_emb,[self.num_user,self.num_item])
        user_emb=user_emb[users]
        item_emb=item_emb[items]
        user_emb=self.transForm(user_emb)
        item_emb=self.transForm(item_emb)

        prediction=torch.mul(user_emb,item_emb).sum(1)
        return prediction
        ######################
        #    请你补充代码       #
        ######################

mydataset类

from torch.utils.data import Dataset
import pandas as pd


class MyDataset(Dataset):
    def __init__(self, rating):
        super(Dataset, self).__init__()
        self.user = rating['user_id']
        self.weather = rating['item_id']
        self.rating = rating['rating']

    def __len__(self):
        return len(self.rating)

    def __getitem__(self, item):
        return self.user[item], self.weather[item], self.rating[item]

utils类

from torch.utils.data import Dataset
import pandas as pd


class MyDataset(Dataset):
    def __init__(self, rating):
        super(Dataset, self).__init__()
        self.user = rating['user_id']
        self.weather = rating['item_id']
        self.rating = rating['rating']

    def __len__(self):
        return len(self.rating)

    def __getitem__(self, item):
        return self.user[item], self.weather[item], self.rating[item]

创作不易 觉得有帮助请点赞关注收藏~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/16901.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(C语言)背答案

[#4练习赛]背答案 题目描述 传智专修学院“Java程序设计”的期末考试来源于一个选择库,共有 nnn 道题目,每道题目由问题和答案组成,都是一个字符串,保证所有题目题面互不相同。这个题库已经发给同学进行备考准备。 正式考试中&…

Labview+STM32无线温湿度采集

一.介绍 该项目采用正点原子的STM32ZET6精英板DHT11温湿度模块泽耀科技的无线串口作为下位机,Labview无线串口作为上位机读取下位机发来的数据并处理。 泽耀科技的产品是我在开发过程中经常用到的,他们不仅产品做的非常不错,而且资料齐全售后…

远离cmd,拥抱powershell

简介:cmd命令提示符是在操作系统中,提示进行命令输入的一种工作提示符。在不同的操作系统环境下,命令提示符各不相同。 在windows环境下,命令行程序为cmd.exe,是一个32位的命令行程序,微软Windows系统基于W…

动态规划--区间dp

区间dp题目列表:(1)石子合并(2)环形石子合并(3)能量项链(4)加分二叉树(5)凸多边形的划分(6)棋盘分割题目列表: (1)石子合并 在复习石子合并之前,为了直接进入专题“区间dp“,做一个区间dp的基础题,这个题目具有代表性…

1.2 Android 5.0 的特点

和其他版本相比, Android 5.0 的突出特性如下所示。 (1)全新的 Material 界面设计 Android 5.0 Lollipop 界面设计的灵感来源于自然、 物理学 以及基于打印效果的粗体、图标化的设计,换句话说,它的设 计是一种基于高品…

智慧建筑BIM解决方案-最新全套文件

智慧建筑BIM解决方案-最新全套文件一、建设背景为什么要发展智慧建筑二、思路架构三、建设方案智慧建筑建设时应考虑下面3个方面:1、减少耗能,促进资源利用效率2、优化工作和生活环境3、确保运营安全可靠四、获取 - 智慧建筑BIM全套最新解决方案合集一、…

m超外差单边带接收机的simulink仿真

目录 1.算法概述 2.仿真效果预览 3.MATLAB部分代码预览 4.完整MATLAB程序 1.算法概述 超外差是利用本地产生的振荡波与输入信号混频,将输入信号频率变换为某个预先确定的频率的方法。这种方法是为了适应远程通信对高频率、弱信号接收的需要,在外差原…

基于springboot在线玩具商城交易平台的设计与实现

随着科技创新不断突破玩具界限,特别是随着智能时代到来,电子游戏的兴起对传统玩具行业带来了冲击,智能玩具应运而生,成为新产品方向。智能玩具受消费者青睐, 随着电子商务的发展,其在我国的经济地位越来越…

spring boot酒店会员点餐系统毕业设计源码072005

Springboot酒店会员点餐系统 摘 要 进入21世纪以来,计算机有了迅速的发展。计算机应用、信息技术全面渗透到了人类社会的各个方面,信息化已成为世界经济和社会发展的大趋势。―企业的管理也从人工操作变得更加自动化、智能化和高效化。如果复杂的工作光靠…

PMP大家都是怎么备考的?使用什么工具可以分享一下吗?

这里分享PMP理论中的4个工具,在人生管理和项目管理中是通用的。所有的工具,只有在对的时间,用在对的地方,才能真正指导实践。 项目经理应符合PMI人才三角。分别为:技术项目管理;领导力;战略和…

腾讯云服务器后台重装后需要配置的一些东西

1、adduser 用户名(创建普通用户) 2、passwd 用户名(给普通用户设置密码) 3、userdel -r 用户名(删除普通用户) 4、修改/etc/sudoers文件(给普通用户可以提权的机会) 5、sudo yum in…

Hive——Hive常用内置函数总结

✅作者简介:最近接触到大数据方向的程序员,刚入行的小白一枚 🍊作者博客主页:皮皮皮皮皮皮皮卡乒的博客 🍋当前专栏:Hive学习进阶之旅 🍒研究方向:大数据方向,数据汇聚&a…

vdsm:添加接口调试demo

目录 添加API接口 2.添加api方法 3.Vdsm-api.yml添加参数 暴露jsonrpc接口: 需要重启vdsmd vdsm-client 调试 本文通过添加一个配置ovs全局参数的接口 添加API接口 文件路径:API.py 2.添加api方法 文件路径:network/api.py 3.Vdsm-ap…

4.2——Node.js的npm和包

目录初识node.jsnode.js的安装和查看版本使用node命令对js文件运行窗口的快捷键fs 文件系统模块fs.readFile() 方法写入文件fs.writeFile()案例——考试成绩整理路径问题path 路径模块路径拼接path.join()获取路径中的文件名path.basename()获取路径中的文件扩展名path.extname…

用Python的Django框架来制作一个RSS阅读器

Django带来了一个高级的聚合生成框架,它使得创建RSS和Atom feeds变得非常容易。 什么是RSS? 什么是Atom? RSS和Atom都是基于XML的格式,你可以用它来提供有关你站点内容的自动更新的feed。 了解更多关于RSS的可以访问 http://www…

[附源码]SSM计算机毕业设计足球队管理系统JAVA

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

[附源码]java毕业设计企业记账系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

QT 发布文章遇到问题解决方案

提供了两种可以发布 Qt 程序的方案,建议使用第二种直接生成对应的文件,直接打包就可以 1. 手动复制需要的文件到运行目录下 我们写完 QT 程序当然是要发布或者发给其他需要用到的人,由于找不到Qt6Core.dll,无法继续执行代码,打开 realease …

Python基础语法

一、字面量:在代码中,被写下来的固定的值 二、注释 /增加代码的可读性 单行注释 #空格注释文字内容 (加空格只是规范)#右边 多行注释 一对三个双引号 """注释内容""" 三、变量 -->程序运行时…

Linux基础内容(12)—— 程序地址空间

目录 1.误区和它的由来 2.虚拟地址的证明 3.虚拟地址的实现 1.虚拟空间的解释 2.操作系统管理和规划虚拟空间 3.虚拟地址与物理地址的联系 4.多进程的虚拟地址解释 5.磁盘中可执行文件的地址 6.进程地址空间出现的原因 接上面内容 Linux基础内容(11&#…