model.py篇

news2025/7/18 10:00:19

model.py篇

目录如下:

  • 引言
  • 找LeNet5网络结构
  • 书写代码
  • 测试结果
  • 函数解释

引言

卷积主要用于特征的提取,而model.py则是为了从输入信息中筛选出我们需要的信息。

我们在阅读完论文后,对我们需要的模型进行搭建,下以LeNet5的model为例:

找LeNet5网络结构

在这里插入图片描述

我们使用微信截图或者command+shift+4截图,使图片悬浮于最上层,观察图片书写自己的网络。

书写代码

在该步骤中,我们需要创建网络名class,继承自nn.Module,在该类中需要重写__init__(self)方法和forward(self)方法。__init__()方法用以搭建网络模型,forward()方法用以接收batch个input,正向传播后输出batch个output。

我们使用torchsummary.summary包对模型进行可视化,对输出结果进行输出。

'''
输入[batch, 3, 28, 28]
输出[batch. 10]
'''
# --- add path
import sys, os
from turtle import forward
project_dir = os.path.dirname(__file__)
sys.path.append(project_dir)
# ---
# --- import package
import torch
import torch.nn as nn
# ---


class LeNet(nn.Module):
    """model

    Args:
        torch (_type_): _description_
    """

    def __init__(self):
        """build model"""
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(5, 5))  # (in_channels, out_channels, kernel_size)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)     # (kernel_size, stride)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)       # full connect to 1 dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):                       # x is [batch, channel, height, width]
        """forward propagation

        Args:
            x (Tensor): [batches, channels, height, width]
        """
        x = self.relu((self.conv1(x)))          # input[batch, 3, 32, 32]   output[batch, 16, 28, 28]   tensor in pytorch is [batch, channel, height, width]
        x = self.pool1(x)                       # height and width become 1/2   output(16, 14, 14)
        x = self.relu((self.conv2(x)))          # output[batch, 32, 10, 10]
        x = self.pool2(x)                       # output[batch, 32, 5, 5]
        x = x.view(-1, 32*5*5)                  # output(32*5*5)    # -1 means Automated reasoning
        x = self.relu(self.fc1(x))              # output(120)
        x = self.relu(self.fc2(x))              # output(84)
        x = self.fc3(x)                         # output(10)
        return x


if __name__ == "__main__":
    """Visual model"""
    from torchsummary import summary
    model = LeNet()
    summary(model, input_size=(3, 32, 32))

测试结果

我们使用torchsummary.summary包对模型进行可视化,对输出结果进行输出。

在这里插入图片描述

函数解释

在数学网络结构时,遇到不熟悉的函数就去pytorch官网手册查找是一个不错的学习方法。下列几个常见函数。可以参考大纲当做词典进行查找。

def init(self):

一定要重写父类的__init__方法:

super(LeNet, self).__init__()

torch.nn.Conv2d()

该函数即创建卷积层,他的函数声明如下:

torch.nn.Conv2d(in_channels, out_channels, kernel_size, 
				stride=1, padding=0, dilation=1, groups=1, bias=True, 
				padding_mode='zeros', device=None, dtype=None)
  • in_channels即输入特征矩阵维度,由上一层输出矩阵决定,等于该层卷积核的channels(维数)
  • out_channels即输出特征矩阵维度,由该层卷积核的numbers(个数)决定。
  • kernel_size(卷积核大小),stride(步长),padding(填补)三因素再加in_size(输入特征矩阵大小)共同决定out_size(输出特征矩阵大小),公式为out_size = ( in_size - kernel_size + 2 * padding ) / stride +1

torch.nn.MaxPool2d()

该函数即创建最大池化层,下采样层的一种,它的函数声明如下:

torch.nn.MaxPool2d(kernel_size, 
					stride=None, padding=0, dilation=1, 
					return_indices=False, ceil_mode=False)
  • 该层不改变channels,只影响输出矩阵的大小out_size(stride一般默认等于kernel_size),计算公式为out_size = ( in_size - kernel_size + 2 * padding ) / stride + 1

torch.nn.Linear()

该函数即创建全连接层,全连接层是一维的。在书写该函数时我们先要通过公式计算最终离全连接层最近的那一层的输出矩阵的out_size为多少,即输入全连接层的自变量个数为多少。Linear()函数的声明如下:

torch.nn.Linear(in_features, out_features, 
				bias=True, device=None, dtype=None)
  • in_features 表示输入全连接层的参数个数
  • out_features 表示输出全连接层的参数个数

def forward(self, x):

重写forward方法,即正向传播过程,其中x即输入,x的通道排列顺序为PyTorch接受的Tensor的通道排列顺序:[batch, channel, height, width]

self.conv1(x)

self.conv1是前面在__init__()函数中定义的对象,为何对象名后能直接加小括号添加变量呢?这里其实调用了__call__函数。

call():Python中,只要在创建类型的时候定义了__call__()方法,这个类型就是可调用的。Python中的所有东西都是对象,其中包括int/str/func/class这四类,它们都是对象,都是从一个类创建而来的。元类就是创建这些对象的东西,type就是Python的内建元类。其中,self.conv1是可调用的对象,说明在创建它的类型(父类或它本身)的时候,定义了__call__()方法。

# 下面两种调用方法是等价的
x = torch.nn.functional.relu(self.conv1.__call__(x))
x = torch.nn.functional.relu(self.conv1(x))

self.pool1(x)同理也调用了__call__()

torch.nn.functional.relu()

函数定义如下,传入tensor进行relu处理后,传出tensor。需要注意的是relu()是不需要训练参数的。

torch.nn.functional.relu(input, inplace=False) → Tensor

torch.Tensor.view()

view函数即展平操作,在进入全连接层前需要进行展平处理view(x, y),其中y为你要接受的input参数,如3255,根据Tensor通道排序,x为batch值,我们往往将x=-1进行自动推理。

torch.nn.Sequential()

该函数的主要作用就是在搭建网络过程中,继承nn.Module的网络模型类,不用在其__init__()函数中为每一个卷积层添加变量名,可以将这些层封装进一个对象中,体现面向对象封装的思想。

函数原型如下:

torch.nn.Sequential(*args)

example如下:

# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

torch.nn.ReLU()

ReLU和relu的区别要弄清楚,在使用sequential的时候更推荐使用ReLU()函数,该函数不需要传入tensor参数,直接对上文的tensor进行ReLU。

函数声明如下:

# ReLU
torch.nn.ReLU(inplace=False)
# relu
torch.nn.functional.relu(input, inplace=False) → Tensor

判断torch.Tensor中的每个元素是否相等

使用torch.Tensor.sum()用来计量一维Tensor中每个元素是否相等。使用torch.Tensor.item()将Tensor中的单个元素转成数值。要将他全部转成一维的才可用sum()。

# 判断是否相等
sum = 0
for i in range(2):
    sum += (c[i,] == b[i,]).sum().item()
print(sum)

torch.flatten()

flatten()函数的使用同torch.Tensor.view()函数使用,flatter指定将特征矩阵从第几个维度开始压缩,view指定矩阵任意维度,推荐使用view。如下:

# test flattern
import torch
a = torch.randn(2, 3, 4, 5)
b = a.view(2, 60)
c = torch.flatten(a, start_dim=1)
print(c.size())	# b, c的shape相等
print(b.size())

python 寻找上一级目录

寻找上一级目录使用os.path.dirname(),如下:

# 找上一级目录文件 -- 使用dirname
import os
my_path = __file__
find_path = os.path.dirname()

torch.nn.kaiming_normal_()

一种数学方法,用于初始化权重值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/17778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

子域名访问计数(哈希表、字符串、索引)

力扣地址:力扣 网站域名 "discuss.leetcode.com" 由多个子域名组成。顶级域名为 "com" ,二级域名为 "leetcode.com" ,最低一级为 "discuss.leetcode.com" 。当访问域名 "discuss.leetcode.com&…

【Struts2】idea快速搭建struts2框架

文章目录什么是SSH框架?Struts2框架1、struts2的环境搭建1.1 创建web项目(maven),导入struts2核心jar包1.2 配置web.xml(过滤器),是struts2的入口,先进入1.3 创建核心配置文件struts…

力扣(LeetCode)13. 罗马数字转整数(C++)

模拟 罗马数字和掰手指数数的区别在于,IV/IXIV/IXIV/IX 这类倒着数数的,和阿拉伯数字最大的区别在于 555 的 10k10^k10k 倍 k∈Nk\isin Nk∈N ,需要被表示出来。所以除了记录 I/X/C/MI/X/C/MI/X/C/M ——1/10/100/10001/10/100/10001/10/100…

五种IO模型

文章目录什么是IO操作系统的IO五种IO模型阻塞IO非阻塞IO多路转接IO(复用IO)信号驱动IO异步IO同步异步什么是IO IO,即input/output,IO模型即输入输出模型,而比较常见且听说的便是磁盘IO,网络IO. 按照冯诺依曼结构的来看,假设我们把运算器、控制器、存储器三个设备看做一个整体…

Kruskal算法求最小生成树

输入样例: 4 5 1 2 1 1 3 2 1 4 3 2 3 2 3 4 4输出样例: 6适用于稀疏图,快; 实现步骤: 1.将所有边将权重从小到大排序;sort;O(mlogm) 2.枚举每条边a,b;权…

RabbitMQ的 AMQP协议都是些什么内容呢

之前也讲述过关于 RabbitMQ 的相关内容,比如他们的配置,以及 RabbitMQ 整合 SpringBoot 使用,而且自己使用过之后,就会在自己的简历上面写上自己使用 RabbitMQ 实现了什么功能,但是这就会导致,有些面试官就…

QPainter、QPen 、QBrush(概念)

Qt中的三大绘画类: QPainter :进行绘画QPaintDevice :提供画图设备,是一个二维的抽象(是所有可绘制对象的基类)QPaintEngine :提供了画家用于绘制到不同类型的设备上的界面QPainter(画家) QPainter 提供高度优化的功能来完成 GUI …

Presto 聚合中groupBy分组的实现

一.前言 本文只要探索在Presto中groupby是怎么实现的。在Preso中,groupby的分组主要通过对数据Hash的数值比较进行分组,其中有2种情况,一直是仅有一个groupby字段而且字段是Bigint类型的,此场景下会使用BigintGroupByHash来实现分…

Spring boot使用ProGuard实现代码混淆

目录参考一、 ProGuard简介二、混淆配置要点三、快速开始方案一 配置文件新增proguard.cfg配置插件打包方案二 pom中定义配置参考 Spring boot使用ProGuard实现代码混淆 SpringBoot 玩一玩代码混淆,防止反编译代码泄露 代码混淆常见于安卓的apk安装文件, 服务端的…

11月千言最新评测推荐,覆盖中文对话、视频语义理解、可信AI等前沿方向

千言数据集是百度联合中国计算机学会、中国中文信息学会共同发起的数据共建计划,千言针对每个自然语言处理问题,均收集和整理多个开源数据集,进行统一的处理并提供统一的测评方式,帮助加速模型的研发。截至目前,千言评…

[附源码]java毕业设计上海景宏不锈钢厨房设备报修系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

云原生系列 【轻松入门容器基础操作】

✅作者简介: CSDN内容合伙人,全栈领域新星创作者,阿里云专家博主,华为云云 享专家博主,掘金后端评审团成员 💕前言: 最近云原生领域热火朝天,那么云原生是什么?何为云原生…

数据分析 | Pandas 200道练习题 进阶篇(3)

文章目录DA21 大佬用户成就值比例DA22 牛客网用户最高的正确率DA23 统计牛客网用户的名字长度DA24 去掉信息不全的用户DA25 修补缺失的用户数据DA26 解决牛客网用户重复的数据总结:❤️ 作者简介:大家好我是小鱼干儿♛是一个热爱编程、热爱算法的大三学生…

生信步骤|MAFFT结合HMMER进行多序列比对和基于隐马模型的基因搜索

蛋白质都是由相似的小型结构域组成的。如果我们有若干个已知的蛋白序列,那我们就可以根据这些蛋白序列比较其含有的保守域,寻找在蛋白数据库中上是否也有一样保守域的蛋白。而后根据统计学模型,将显著性较高的蛋白序列预测为同一类基因家族蛋…

Oracle SQL执行计划操作(5)——分区相关操作

5. 分区相关操作 该类操作与SQL语句执行计划中分区表操作相关。根据不同的具体SQL语句及其他相关因素,如下各操作可能会出现于相关SQL语句的执行计划。 1)PARTITION RANGE ALL 对范围分区(RANGE PARTITION)表的所有分区进行子…

内存泄漏检测C版小工具

一 内存泄漏简介 内存泄漏(Memory Leak)是指程序中己动态分配的堆内存由于某种原因程序未释放或无法释放,造成系统内存的浪费,导致程序运行速度减慢甚至系统崩溃等严重后果。 内存泄漏分类: 1.堆内存泄漏&#xff1…

基于LMI的非线性混沌系统滑模控制

目录 前言 1.非线性系统 2.控制器设计 3.仿真分析 3.1仿真混沌系统 3.2 LMI求解反馈阵F 3.3仿真模型 ​​​​3.4仿真结果 3.5注意事项 前言 前面我们介绍了很多种滑模面设计,以及介绍了几篇结合LMI的滑模控制,其核心思想可以看作是用LMI去控制…

【python与数据分析】Numpy数值计算基础——补充

目录 二、矩阵生成与常用操作 1.生成矩阵 2.矩阵转置 3.查看矩阵特征 4.矩阵乘法 5.计算相关系数矩阵 6.计算方差、协方差、标准差 7.行列扩展 8.常用变量 9.矩阵在不同维度上的计算 10.应用 (1)使用蒙特卡罗方法估计圆周率的值 &#xff0…

【Transformers】第 10 章 :从零开始训练 Transformer

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

JS实现复制富文本到剪贴板/粘贴板的最佳实践

背景 最近有想实现一个功能,通过点击一个button按钮,来复制网页内容(含html)来实现复制后粘贴到邮件或者word具有富文本的效果。在网站翻了一些资料,要么就是方法已经被弃用,要么就是兼容性特别差,要么就是不能复制成…