PointNet网络模型代码解析

news2025/9/17 20:57:30

PointNet网络模型代码解析

    • T-Net3d
    • T-Netkd
    • FeatNet
    • PointNetCls
    • 网络结构可视化

论文地址:https://arxiv.org/pdf/1612.00593
参考代码地址:https://github.com/fxia22/pointnet.pytorch;

T-Net3d

首先数据输入为n*3,然后接一个T-net;

根据论文中的介绍:第一个变换网络是一个迷你PointNet,它以原始点云为输入,回归到3×3矩阵。它由每个点上的共享M LP(64,128,1024)网络(层输出大小为64,128,1024)、跨点的最大池和两个输出大小为512,256的完全连接层组成。输出矩阵初始化为单位矩阵。除最后一层外,所有层都包括ReLU和批处理规范化。

class TNet3d(nn.Module): # 注意这里一般都是继承pytorch中的基类模型
    def __init__(self):

        #super()函数用于调用父类的一个方法。
        # 具体来说,当你在一个类的方法中使用 super().method() 形式时,
        # 你实际上是在调用这个方法的父类实现。
        super(TNet3d,self).__init__()
        #MLP(多层感知机),主要作用是数据的升维处理
        self.conv1 = torch.nn.Conv1d(3,64,1)
        self.conv2 = torch.nn.Conv1d(64,128,1)
        self.conv3 = torch.nn.Conv1d(128,1024,1)

        #FC(全连接层),主要作用为数据降维,聚合特征
        self.fc1 = nn.Linear(1024,512)
        self.fc2 = nn.Linear(512,256)
        self.fc3 = nn.Linear(256,9)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    # forward 函数通常用于定义数据通过神经网络时的前向传播逻辑
    # forward 方法定义了如何将输入数据传递通过各个层、激活函数,以及其他可能的操作,然后返回输出结果
    # 对一般网络而言查看forward函数最容易了解该网络结构
    def forward(self,x):
        batchsize = x.size()[0]

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x,2,keepdim=True)[0]

        x = x.view(-1,1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # 该代码的目的是创建一个包含单位矩阵(经过展平后)的张量,并重复这个张量,使其适应批次处理。
        # 假设 batchsize 为 3,那么最终的结果是:
        # tensor([[1, 0, 0, 0, 1, 0, 0, 0, 1],
        #         [1, 0, 0, 0, 1, 0, 0, 0, 1],
        #         [1, 0, 0, 0, 1, 0, 0, 0, 1]])
        # 每一行都是一个展平的3x3单位矩阵。
        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)

        # iden和x是否使用GPU要一致
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1,3,3)
        return x

T-Netkd

这个类和前面的T-Net3d结果完全一样,只是利用k作为变量进行输入输出

class TNetkd(nn.Module):
    def __init__(self,k=64):
        super(TNetkd,self).__init__()
        self.conv1 = torch.nn.Conv1d(k,64,1)
        self.conv2 = torch.nn.Conv1d(64,128,1)
        self.conv3 = torch.nn.Conv1d(128,1024,1)
        self.fc1 = nn.Linear(1024,512)
        self.fc2 = nn.Linear(512,256)
        self.fc3 = nn.Linear(256,k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k
    def forward(self,x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(-1,1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)

        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1,self.k,self.k)
        return x

FeatNet

PointNet中的特征提取层,结合了两个T-Net网络,进行数据升维,提取global feature

class PointNetFeat(nn.Module):
    def __init__(self,global_feat = True,feature_transform = False):
        super(PointNetFeat,self).__init__()
        self.stn = TNet3d()
        self.conv1 = torch.nn.Conv1d(3,64,1)
        self.conv2 = torch.nn.Conv1d(64,128,1)
        self.conv3 = torch.nn.Conv1d(128,1024,1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = TNetkd(64)

    def forward(self,x):
        n_pts = x.size()[2]     #记录点的数量
        trans = self.stn(x)     #第一TNet结构预测一个点的仿射变换矩阵,该仿射变换矩阵用于对输入进行空间变换,以提高模型的几何不变性和泛化能力。
        x = x.transpose(2,1)
        x = torch.bmm(x,trans)
        x = x.transpose(2,1)    #这三行代码是将点集与变换矩阵相乘

        x = F.relu(self.bn1(self.conv1(x))) #对变换后的点进行特征提取

        if self.feature_transform:      # 对特征进行仿射变换
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x,trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x   # 将特征层进行暂存,为后面分割使用
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))

        x = torch.max(x,2,keepdim=True)[0] # 最大池化层聚合所有点的信息
        x = x.view(-1,1024)
        if self.global_feat:
            return x,trans,trans_feat
        else:
            x = x.view(-1,1024,1).repeat(1,1,n_pts)
            return  torch.cat([x,pointfeat],1),trans,trans_feat

PointNetCls

特征设计完成后,根据论文就可以写出PointNet中的分类网络结果,论文中分类网络在得到global feature之后接了一个全连接层,就输出结果了,仔细想想是不是差了点什么,没错就是local feature,这也是PointNet++改进的地方,后续跟进PointNet++的讲解;

在类分数预测之前,在输出维度为256的最后一个完全连接层上使用保持比为0.7的丢弃;批量归一化的衰减率从0.5开始,逐渐增加到0.99;使用初始学习率为0.001、动量为0.9、批量大小为32的adam优化器,学习率每20个时期除以2;

class PointNetCls(nn.Module):
    def __init__(self,k=2,feature_transform=False):
        super(PointNetCls,self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetFeat(global_feat=True,feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024,512)
        self.fc2 = nn.Linear(512,256)
        self.fc3 = nn.Linear(256,k)
        self.dropout = nn.Dropout(p=0.3) # 论文写的p=0.7,代码中给的0.3后续效果还需要测试之后在看
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self,x):
        x,trans,trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)

        return F.log_softmax(x,dim=1),trans,trans_feat

网络结构可视化

利用神经网络可视化工具torchview画出Tnet网络结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1822405.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows上安装redis,并且用pycharm联通调用测试

在 Windows 上启动 Redis,官网版本不支持windows直接安装,你可以按照以下步骤进行操作: 使用Github Redis 版本启动 Redis 如果你想使用 Redis 在 Windows 上启动 Redis,以下是基本的步骤: 下载 Redis: 访…

在录制视频的时候,自动出现英伟达(nvidia)显卡弹窗关闭方式

具体情况具体分析,我遇见的是录制视频在视频里面出现弹窗问题 显示效果 在使用录制视频工具进行录制,回放视频的时候,视频会自动弹出英伟达显卡的弹窗,这个我们不需要,就把他关闭 关闭方式 找到GeForce Experienc…

能耗监控与管理平台

在当今社会,随着工业化、城市化的快速发展,能源消耗问题日益凸显,节能减排已成为全社会共同关注的焦点。在这个背景下,一款高效、智能的能耗监控与管理平台显得尤为重要。 一、HiWoo Cloud平台的概念 HiWoo Cloud是一款集数据采…

【权威出版/投稿优惠】2024年智慧城市与信息化教育国际会议(SCIE 2024)

2024 International Conference on Smart Cities and Information Education 2024年智慧城市与信息化教育国际会议 【会议信息】 会议简称:SCIE 2024 大会时间:点击查看 大会地点:中国北京 会议官网:www.iacscie.com 会议邮箱&am…

江协科技51单片机学习- p7 独立按键控制LED灯

前言: 本文是根据哔哩哔哩网站上“江协科技51单片机”视频的学习笔记,在这里会记录下江协科技51单片机开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了江协科技51单片机教学视频和链接中的内容。 引用: 51单片机入门教程-2…

Windows中LoadLibrary加载动态库失败,详细解释(解决思路)

今天在开发的过程中,需要用到动态库里的一些接口,又不希望全部载入,在这过程中使用LoadLibrary加载dll时,出现问题,特此记录一下自己怎么解决的思路。 目录 先介绍一下这几个函数为以下错误分析做准备 GetProcAddres…

使用asyncua模块如何在opcua框架的Server端添加方法及在Client端调用方法

1. 在opcua框架的Server端添加方法 参考文章: freeopcua调用方法输入参数| Python解析数组到输入列表 为OPC UA python服务器/客户端添加安全性(异步) OPCUA和asyncua — [3] 添加方法 OPC UA的Server端新增方法的关键代码如下:…

如何评估员工在新版FMEA培训后应用知识的效果?

随着制造业的快速发展,新版FMEA已成为企业提升产品质量、减少故障风险的关键一环。然而,培训只是第一步,如何有效评估员工在新版FMEA培训后应用知识的效果,才是确保培训成果转化的关键所在。 评估员工知识应用效果的首要步骤是制定…

[深度学习]基于C++和onnxruntime部署yolov10的onnx模型

基于C和ONNX Runtime部署YOLOv10的ONNX模型,可以遵循以下步骤: 准备环境:首先,确保已经下载后指定版本opencv和onnruntime的C库。 模型转换:按照官方源码:https://github.com/THU-MIG/yolov10 安装好yolov…

揭秘裂变客户背后的心理学:如何触动用户分享欲望?

在当今的社交媒体时代,裂变客户——即用户主动分享并推广某一产品或服务,已成为企业营销的重要策略。那么,如何触动用户的分享欲望呢?这背后其实隐藏着深刻的心理学原理。本文将以looka这个知名的国外设计工具为例,为s…

最新下载:EasyRecovery易恢复软件安装视频教程

EasyRecovery电脑数据丢失如何恢复?有时候我们在清理电脑的时候会不小心把一些文件夹的数据误删了,在数据恢复时大家会寻找一些数据恢复软件,比如Easyrecover数据恢复软件,但是许多小伙伴还不知道要怎么操作,文件恢复的操作和原理…

有什么好用的ai智能写作手机版?6个软件帮助你快速进行智能写作

有什么好用的ai智能写作手机版?6个软件帮助你快速进行智能写作 AI智能写作在现代社会中扮演着越来越重要的角色,许多人依赖这些工具来提高写作效率和质量。以下是六款不同类型的AI智能写作手机应用,它们可以帮助你快速进行智能写作&#xff…

3X+3问题,角谷猜想的姐妹问题

3X3问题是角谷猜想(3X1)的延伸,可以说是孪生问题。 对于任何奇数x,乘以3再加3,再析出偶数,即,除以(m1,2,3,...。),&#x…

借助ollama实现AI绘画提示词自由,操作简单只需一个节点!

只需要将ollama部署到本地,借助comfyui ollama节点即可给你的Ai绘画提示词插上想象的翅膀。具体看详细步骤! 第一步打开ollama官网:https://ollama.com/,并选择models显存太小选择的是llama3\8b参数的instruct-q6_k的这个模型。 运…

Ubuntu18.04 安装 colmap

安装依赖 sudo apt-get install \git \cmake \ninja-build \build-essential \libboost-program-options-dev \libboost-filesystem-dev \libboost-graph-dev \libboost-system-dev \libeigen3-dev \libflann-dev \libfreeimage-dev \libmetis-dev \libgoogle-glog-dev \libgt…

GenICam标准(一)

系列文章目录 GenICam标准(一) GenICam标准(二) 文章目录 系列文章目录1、概述GenApiGenTLSFNC(标准特征命名约定)CLProtocolGenCP 参考 emva 1、概述 如今的数码摄相机包含了很多的功能,而不仅…

为什么Mamba模型被拒?

Mamba模型问世 最近,国际学习表征会议(ICLR)公布了2024年会议的最终决定,其中引起广泛关注的是一个名为Mamba的模型。这个模型最初被认为是对抗著名的Transformer架构进行语言建模任务的主要竞争者,但最终被拒绝&…

植物大战僵尸杂交版 fatal error及问题解决闪退

echo off set KEY_NAMESoftware\PopCap\PlantsVsZombies set VALUE_NAMEScreenmode set DATA0 reg add HKCU%KEY_NAME% /v %VALUE_NAME% /t REG_DWORD /d %DATA% /f if %errorlevel% neq 0 ( echo 注册表数值数据修改失败 ) else ( echo 注册表数值数据已成功修改为0 ) 将上述…

什么是场外期权?场外期权有几种做法?

今天带你了解什么是场外期权?场外期权有几种做法?期权分为场内期权,场外期权。场内期权我们都知道,是在期货盘里购买的期权,但场外期权呢? 什么是场外期权? 场外期权是一种在交易所之外进行交易…

组件二次封装,通过属性事件透传,插槽使用,组件实例方法的绑定,深入理解 Vue.js 组件扩展与插槽

透传,插槽,组件实例方法的绑定,深入理解 Vue.js 组件扩展与插槽 前言 Vue.js 提供了强大的组件化系统,允许开发者构建可复用、可组合的UI组件。在实际项目中,直接使用第三方库提供的基础组件(如Element UI…