【RecBole-GNN/源码】RecBole-GNN中NCL源码解析

news2025/7/13 11:44:15

如果觉得我的分享有一定帮助,欢迎关注我的微信公众号 “码农的科研笔记”,了解更多我的算法和代码学习总结记录。或者点击链接扫码关注【RecBole-GNN/源码】RecBole-GNN中NCL源码解析

【RecBole-GNN/源码】RecBole-GNN中NCL源码解析


原文:https://arxiv.org/abs/2202.06200

源码:https://github.com/rucaibox/ncl

1 数据

方法开始经过数据处理等进入 ncl.pycalculate_loss(self, interaction) 方法。interaction数据形式如下

interaction数据形式

首先获得交互正负样本对数据

#根据交互interaction数据,获得交互正对和负对
user = interaction[self.USER_ID] #2048*1
pos_item = interaction[self.ITEM_ID] #2048*1
neg_item = interaction[self.NEG_ITEM_ID] #2048*1

2 GCN嵌入表示

利用GCN进行前向传播得到user和item的嵌入表示

#获得所有节点(user和item)的嵌入表示
all_embeddings = self.get_ego_embeddings() #9748*64
embeddings_list = [all_embeddings]

for layer_idx in range(max(self.n_layers, self.hyper_layers * 2)): #max(3,2)
    #采用了LightGCNConv的方式
    all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight) # 9748*64/2*1610886/1610886*1
    embeddings_list.append(all_embeddings)

#得到的embeddings_list是4个9748*64.
lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers + 1], dim=1)
lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

 user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])

3 Loss计算

#embeddings_list是4个9748*64.
center_embedding = embeddings_list[0] #得到初始的embedding,9748*64
context_embedding = embeddings_list[self.hyper_layers * 2] #得到第3个embedding,9748*64

#基于SSL(结构)得到loss (偶数跳的作为邻居正例),user:2048*1,pos_item:2048*1
ssl_loss = self.ssl_layer_loss(context_embedding, center_embedding, user, pos_item)
#基于语义计算loss,
proto_loss = self.ProtoNCE_loss(center_embedding, user, pos_item)

u_embeddings = user_all_embeddings[user]
pos_embeddings = item_all_embeddings[pos_item]
neg_embeddings = item_all_embeddings[neg_item]

# calculate BPR Loss
pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)

mf_loss = self.mf_loss(pos_scores, neg_scores)

u_ego_embeddings = self.user_embedding(user)
pos_ego_embeddings = self.item_embedding(pos_item)
neg_ego_embeddings = self.item_embedding(neg_item)

reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)

3.1【ssl_layer_loss方法介绍】

该代码首先从当前嵌入中提取用户和商品的表征。然后,将用户和商品的表征分别与基于GCN的第二跳表征进行匹配,计算其相似度。接着,使用softmax函数将所有商品与用户的相似度加权平均,得到用户对所有商品的兴趣得分。然后,对每个商品,同样计算其与当前和基于GCN的第二跳表征之间的相似度,并使用softmax函数将其与所有商品的相似度加权平均,得到商品的受欢迎度得分。

接下来,将用户和商品的得分分别用作分子和分母,计算用户和商品的损失。最后,将两个损失加权相加,并乘以一个正则化参数,得到最终的ssl损失。

注意:context_embedding, center_embedding是包含用户和item的embedding。

def ssl_layer_loss(self, current_embedding, previous_embedding, user, item):
        #user和item表征分开
        current_user_embeddings, current_item_embeddings = torch.split(current_embedding, [self.n_users, self.n_items])
        previous_user_embeddings_all, previous_item_embeddings_all = torch.split(previous_embedding, [self.n_users, self.n_items])
###################
        #获取当前user对应embedding
        current_user_embeddings = current_user_embeddings[user]
        previous_user_embeddings = previous_user_embeddings_all[user]
        #将这两个表征进行归一化,以保证它们具有相同的尺度,并计算它们的点积作为相似度得分。
        norm_user_emb1 = F.normalize(current_user_embeddings)
        norm_user_emb2 = F.normalize(previous_user_embeddings)
        norm_all_user_emb = F.normalize(previous_user_embeddings_all)
        pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1)
        #计算用户对所有用户的得分。transpose(0, 1)表示对norm_all_user_emb的第0维和第1维进行转置操作,即将所有用户的表征转置,以便与当前用户的表征进行矩阵乘法。乘积的结果是一个大小为(n_items, 1)的向量,表示当前用户对所有所有用户的得分。
        ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1))
        #self.ssl_temp=0.1
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)
        #
        ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()
####################
        #同理计算
        current_item_embeddings = current_item_embeddings[item]
        previous_item_embeddings = previous_item_embeddings_all[item]
        norm_item_emb1 = F.normalize(current_item_embeddings)
        norm_item_emb2 = F.normalize(previous_item_embeddings)
        norm_all_item_emb = F.normalize(previous_item_embeddings_all)
        pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1)
        ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1))
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)

        ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item)
        return ssl_loss

3.2【ProtoNCE_loss方法介绍】

给定一个节点嵌入向量 node_embedding,以及一个用户 ID user 和一个物品 ID item,该函数会使用这些输入计算出用户和物品的 proto-contrastive loss。proto-contrastive loss 是一个用于无监督学习的损失函数,用于在嵌入空间中学习出具有相似性质的节点之间的距离,并且可以通过聚类算法来获取节点的簇标签。

  • 这个函数首先将 node_embedding 分成两部分,分别对应于所有用户和所有物品的嵌入向量 user_embeddings_all 和 item_embeddings_all。
  • 然后,它会从 user_embeddings_all 中选择 user 对应的嵌入向量 user_embeddings,并对其进行归一化处理,得到 norm_user_embeddings。接着,函数会获取 user 对应的簇标签 user2cluster,并使用它来获取对应的簇中心 user2centroids。然后,函数会计算出用户与其所属簇中心的内积,得到 pos_score_user。pos_score_user 会经过指数函数处理,并除以一个温度参数 self.ssl_temp。接着,函数会计算用户与所有簇中心的内积,得到 ttl_score_user。ttl_score_user 也会经过指数函数处理,并按行求和。然后,函数会使用 pos_score_user 和 ttl_score_user 来计算用户的 proto-contrastive loss,将其取负数并求和。
  • 函数接下来会处理物品的嵌入向量 item_embeddings_all,方法与处理用户的嵌入向量类似,得到 pos_score_item 和 ttl_score_item,并计算出物品的 proto-contrastive loss。
  • 最后,函数会将用户和物品的 proto-contrastive loss 加权求和,得到最终的 proto-contrastive loss。这个加权系数由参数 proto_reg 决定。函数返回最终的 proto-contrastive loss。
def ProtoNCE_loss(self, node_embedding, user, item):
        user_embeddings_all, item_embeddings_all = torch.split(node_embedding, [self.n_users, self.n_items])

        user_embeddings = user_embeddings_all[user]     # [B, e]
        norm_user_embeddings = F.normalize(user_embeddings)

        user2cluster = self.user_2cluster[user]     # [B,]
        user2centroids = self.user_centroids[user2cluster]   # [B, e]
        pos_score_user = torch.mul(norm_user_embeddings, user2centroids).sum(dim=1)
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.matmul(norm_user_embeddings, self.user_centroids.transpose(0, 1))
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        proto_nce_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        item_embeddings = item_embeddings_all[item]
        norm_item_embeddings = F.normalize(item_embeddings)

        item2cluster = self.item_2cluster[item]  # [B, ]
        item2centroids = self.item_centroids[item2cluster]  # [B, e]
        pos_score_item = torch.mul(norm_item_embeddings, item2centroids).sum(dim=1)
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.matmul(norm_item_embeddings, self.item_centroids.transpose(0, 1))
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
        proto_nce_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item)
        return proto_nce_loss

3.3【簇标签以及簇中心】

在每个epoch的时候都会开始进行self.model.e_step()(Running E-step )

    def e_step(self):
        user_embeddings = self.user_embedding.weight.detach().cpu().numpy() #
        item_embeddings = self.item_embedding.weight.detach().cpu().numpy()
        #self.user_centroids:1000*64 ,self.user_2cluster:6041*1
        #self.item_centroids:1000*64, self.item_2cluster:3707*1
        # 给予了user对应的簇,以及簇中心
        self.user_centroids, self.user_2cluster = self.run_kmeans(user_embeddings)
        self.item_centroids, self.item_2cluster = self.run_kmeans(item_embeddings)
    # self.k=1000
    def run_kmeans(self, x):
        """Run K-means algorithm to get k clusters of the input tensor x
        """
        import faiss
        kmeans = faiss.Kmeans(d=self.latent_dim, k=self.k, gpu=True)
        kmeans.train(x)
        cluster_cents = kmeans.centroids

        _, I = kmeans.index.search(x, 1)

        # convert to cuda Tensors for broadcast
        centroids = torch.Tensor(cluster_cents).to(self.device)
        centroids = F.normalize(centroids, p=2, dim=1)

        node2cluster = torch.LongTensor(I).squeeze().to(self.device)
        return centroids, node2cluster

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/362597.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【教程】GitStats代码统计工具(附GitLab API相关)

使用GitStats进行代码统计 官方文档:GitStats - git history statistics generator GitStats是基于Git的数据统计生成器,输出格式为HTML,可直接在浏览器打开查看,展现为图表形式的可视化数据,内容包括: 常…

Spring Boot MyBatis-Plus 连接 Oracle 数据库 自动生成代码

IDEA 创建SpringBoot项目 项目创建移步 IDEA创建SpringBoot项目 添加依赖 <!--MyBatis--><dependency><groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><version>2.1.4</vers…

归因分析笔记21 可解释的机器学习-李宏毅讲座

视频链接: https://www.bilibili.com/video/BV1Wv411h7kN/?p96&vd_source7259e29498a413d91ab48c04f9329855 课件链接: https://view.officeapps.live.com/op/view.aspx?srchttps%3A%2F%2Fspeech.ee.ntu.edu.tw%2F~hylee%2Fml%2Fml2021-course-data%2Fxai_v4.pptx&…

【数据结构与算法】数据结构的基本概念,时间复杂度

&#x1f349;内容专栏&#xff1a;【数据结构与算法】 &#x1f349;本文脉络&#xff1a;数据结构和算法的基本概念&#xff0c;时间复杂度 &#x1f349;本文作者&#xff1a;Melon西西 &#x1f349;发布时间 &#xff1a;2023.2.21 目录 一、引入&#xff1a; 二、数据结…

一键恢复和重装系统的区别是什么

如果电脑出现系统故障问题的时候&#xff0c;我们的电脑系统还原和系统重装都是很好的解决方法之一&#xff0c;不过是二者之间是有区别的。那么我们的电脑系统还原和系统重装有什么区别呢?现在就跟大家聊聊电脑一键恢复和重装系统的区别有哪些。 工具/原料&#xff1a; 系统…

P6软件应用的核心收益

卷首语 提供了多用户、多项目的功能模块&#xff0c;支持多层次项目等级划分&#xff0c;资源分配计划&#xff0c;记录实际数据&#xff0c;自定义视图&#xff0c;并具有用户定义字段的扩展功能。 利用最佳实践&#xff0c;建立企业模板库 P6软件支持用户使用模板编制项目…

Arduino-交通灯

LED交通灯实验实验器件&#xff1a;■ 红色LED灯&#xff1a;1 个■ 黄色LED灯&#xff1a;1 个■ 绿色LED灯&#xff1a;1 个■ 220欧电阻&#xff1a;3 个■ 面包板&#xff1a;1 个■ 多彩杜邦线&#xff1a;若干实验连线1.将3个发光二极管插入面包板&#xff0c;2.用杜邦线…

Sqoop介绍_以及安装_测试---大数据之Apache Sqoop工作笔记001

这个sqoop主要是用来,把数据从mysql中导入到hdoop中,去看看介绍吧. sql to hadoop 然后我们来看看sqoop,可以看到这里稳定版本是1.4.7 然后1.4.7 跟centos6.8 不是太好配置 这里用了1.4.6 但是如果用1.4.7 和centos7 还行 可以看看官网,这里sqoop1 跟sqoop2 这里标注了s…

【论文笔记】Manhattan-SDF == ZJU == CVPR‘2022 Oral

Neural 3D Scene Reconstruction with the Manhattan-world Assumption 本文工作&#xff1a;基于曼哈顿世界假设&#xff0c;重建室内场景三维模型。 1.1 曼哈顿世界假设 参考阅读文献&#xff1a;Structure-SLAM: Low-Drift Monocular SLAM in Indoor EnvironmentsIEEE IR…

【原创】java+swing+mysql宿舍管理系统设计与实现

今天我们主要来介绍如何使用swing图形化gui工具和mysql数据库去开发一个学生宿舍管理系统&#xff0c;这样一个比较经典的项目&#xff0c;学生宿舍管理系统&#xff0c;相信都很多人都不同程度的写过&#xff0c;从实现上来说不难。 功能分析&#xff1a; 学生宿舍管理系统&…

mysql中利用sql语句修改字段名称,字段长度等操作(亲测)

在网站重构中&#xff0c;通常会进行数据结构的修改&#xff0c;所以添加&#xff0c;删除&#xff0c;增加mysql表的字段是难免的&#xff0c;有时为了方便&#xff0c;还会增加修改表或字段的注释&#xff0c;把同字段属性调整到一块儿。这些操作可以在phpmyadmin或者别的mys…

Lazada选品推荐,这些爆品成了东南亚开年大赢家

小编今日整理了最新快消行业情报&#xff0c;带您解读东南亚市场玩具、母婴、美妆、食品、宠物类目的最新热销品类和发展方向&#xff0c;宠物。赶在大促前为商家朋友们助力一波&#xff01;STEM玩具、精细化拟人化宠物食品、便携香水……一大波商机正在赶来&#xff01;准备好…

编译链接实战(9)elf符号表

文章目录符号的概念符号表探索前面介绍了elf文件的两种视图&#xff0c;以及两种视图的各自几个组成部分&#xff1a;elf文件有两种视图&#xff0c;链接视图和执行视图。在链接视图里&#xff0c;elf文件被划分成了elf 头、节头表、若干的节&#xff08;section&#xff09;&a…

C++项目——高并发内存池(2)——thread_cache的基础功能实现

1.并发内存池concurrent memory pool 组成部分 thread cache、central cache、page cache thread cache&#xff1a;线程缓存是每个线程独有的&#xff0c;用于小于64k的内存的分配&#xff0c;线程从这里申请内存不需要加锁&#xff0c;每个线程独享一个cache&#xff0c;这…

算法学习与填充计划---2023.2.21---夏目

&#x1f680;write in front&#x1f680; &#x1f4dd;个人主页&#xff1a;认真写博客的夏目浅石.CSDN &#x1f381;欢迎各位→点赞&#x1f44d; 收藏⭐️ 留言&#x1f4dd;​ &#x1f4e3;系列专栏&#xff1a;ACM周训练题目合集.CSDN &#x1f4ac;总结&#xff1a…

继 承

1.继承继承是面向对象三大特性之一有些类与类之间存在特殊的关系继承的好处: 减少重复代码语法: class 子类: 继承方式 父类子类也称为派生类 父类也称为基类class Python : public BasePage {public :void Content() {}};2.继承方式继承方式一共有三种:公共继承保护继承私有继…

Homekit智能家居一智能吸顶灯

买灯要看什么因素 好灯具的灯光可以说是家居的“魔术师”&#xff0c;除了实用的照明功能外&#xff0c;对细节的把控也非常到位。那么该如何选到一款各方面合适的灯呢&#xff1f; 照度 可以简单理解为清晰度&#xff0c;复杂点套公式来说照度光通量&#xff08;亮度&#x…

ChatGPT为什么不受开发者喜欢?

记得 ChatGPT 最开始上线不久的时候&#xff0c;看到的大部分尝鲜和测试结果都是开发者在做进行敲代码测试&#xff0c;可以说职业危机感非常强的一群人了。 再者&#xff0c;加上 ChatGPT 要使用起来其实是有一些技术门槛的&#xff0c;愿意折腾的人也多是程序员&#xff0c;…

操作系统和进程的资源消耗

free -h 获取操作系统当前内存Mem 行(第二行)是内存的使用情况。Swap 行(第三行)是交换空间的使用情况。total 列显示系统总的可用物理内存和交换空间大小。used 列显示已经被使用的物理内存和交换空间。free 列显示还有多少物理内存和交换空间可用使用。shared 列显示被共享使…

基于龙芯 2K1000 的嵌入式 Linux 系统移植和驱动程序设计(一)

2.1 需求分析 本课题以龙芯 2K1000 处理器为嵌入式系统的处理器&#xff0c;需要实现一个完成的嵌入式软件系统&#xff0c;系统能够正常启动并可以稳定运行嵌入式 Linux。设计网络设备驱 动&#xff0c;可以实现板卡与其他网络设备之间的网络连接和文件传输。设计 PCIE 设备驱…