Pytorch从零开始训练模型【识别数字模型】并测试

news2025/7/13 22:56:37

1 准备数据集

import torch
import torchvision
# 去网上下载CIFAR10数据集【此数据集为经典的图像数字识别数据集】
# train = True 代表取其中得训练数据集;
# transform 参数代表将图像转换为Tensor形式
# download 为True时会去网上下载数据集到指定路径【root】中,若本地已有此数据集则直接使用
train_data=torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
# 测试数据集
test_data=torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 查看数据集长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))
print('测试数据集长度为:{}'.format(test_data_size))
image-20230223193856588

2 加载数据集

from torch.utils.data import DataLoader
# 将数据集加载,每64张作为一组
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

3 搭建神经网络

3.1 原理演示

image-20230223194808453

Pytorch图像输入输出格式图像变换官方文档

此处仅举一个例子,即输入图像为 3 @ 32 × 32 3@32\times 32 3@32×323通道,高为32,宽为32】以通过 5 × 5 5\times5 5×5卷积核进行卷积操作时需要填写哪些参数,使其变成 32 @ 32 × 32 32@32\times 32 32@32×3232通道,高为32,宽为32】的图像

image-20230223200623643

不知道卷积是啥意思的可以参考如下👇视频

B站土堆说卷积操作

官方文档中关于二维图像的卷积函数中的参数是这样解释的【此处仅对需要用到的部分进行注释】

image-20230223201030306

输入图像与输出图像的对应数学关系如下👇

image-20230223201538197

图上没有采用空洞卷积,dilation参数默认是1,假设stride通常都是1,先尝试
H o u t = ⌊ H i n + 2 × p a d d i n g [ 0 ] − d i l a t i o n [ 0 ] × ( k e r n e l _ s i z e [ 0 ] − 1 ) − 1 s t r i d e [ 0 ] + 1 ⌋ 32 = ⌊ 32 + 2 × p a d d i n g [ 0 ] − 1 × ( 5 − 1 ) − 1 1 + 1 ⌋ p a d d i n g [ 0 ] = 2 W o u t = ⌊ W i n + 2 × p a d d i n g [ 1 ] − d i l a t i o n [ 1 ] × ( k e r n e l _ s i z e [ 1 ] − 1 ) − 1 s t r i d e [ 1 ] + 1 ⌋ ∴ 同理可得, p a d d i n g [ 1 ] 也为 2 H_{out}=\left \lfloor \dfrac{H_{in}+2\times padding[0]-dilation[0]\times(kernel\_size[0]-1)-1}{stride[0]}+1 \right \rfloor \\ 32=\left \lfloor \dfrac{32+2\times padding[0] - 1\times (5-1)-1}{1}+1 \right \rfloor \\ padding[0]=2\\ \\ W_{out}=\left \lfloor \dfrac{W_{in}+2\times padding[1]-dilation[1]\times(kernel\_size[1]-1)-1}{stride[1]}+1 \right \rfloor \\ \therefore 同理可得,padding[1]也为2 Hout=stride[0]Hin+2×padding[0]dilation[0]×(kernel_size[0]1)1+132=132+2×padding[0]1×(51)1+1padding[0]=2Wout=stride[1]Win+2×padding[1]dilation[1]×(kernel_size[1]1)1+1同理可得,padding[1]也为2
padding直接传入参数2即可,默认会将padding这个tuple都赋为2

# 因此进行如上变换需要传参如下
self.conv1 = Conv2d(in_channels=3, out_channels=32,kernel_size=5, stride=1, padding=2)

3.2 代码实现

from torch.nn import Conv2d,MaxPool2d,Flatten,Linear
# 神经网络模型
class TrainModule(nn.Module):
    def __init__(self):
        super().__init__()
        # 按照上边参考图对数据依次进行如下处理,最后得到的就是关于当前图像的分类概率预测
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32,
                      kernel_size=5, stride=1, padding=(2, 2)),
            # 最大池化操作
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            # 扁平化
            nn.Flatten(),
            # 线性处理
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x


net = TrainModule()
# 以下部分均为测试神经网络功能
# 模拟输入数据:64张图片。每个图的格式为3*32*32
input = torch.ones((64, 3, 32, 32))
# 调用神经网络类时会调用forward方法
output = net(input)
# 输出结果为torch.Size([64, 10]),说明返回结果是64张图片再十个数字类别中的概率
print(output.shape)

4 损失函数

我们通过比对预测值和真实值的差距从而得知模型的训练优劣,损失函数将这个指标数据化,损失函数越小,说明模型训练得越好

import tensorboard
from torch.utils.tensorboard import SummaryWriter
# 损失函数
loss_fn = nn.CrossEntropyLoss()
# 优化器,SGD为随机梯度下降法
# 传入需要梯度下降的参数以及学习率,1e-2等价于0.01
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)

# 记录训练次数
total_train_step = 0
# 记录测试次数
total_test_step = 0
# 训练轮数【电脑性能有限,只打个样例】
epoch = 2
writer = SummaryWriter('./logs')
for i in range(epoch):
    print('------第{}轮训练开始------'.format(i+1))
    # 训练步骤开始
    for data in train_dataloader:
        imgs, targets = data
        # 将图片传入神经网络后得到输出结果
        outputs = net(imgs)
        # 将输出结果与原标签进行比对,计算损失函数
        loss = loss_fn(outputs, targets)
        # 在应用层面可以简单理解梯度清零,反向传播,优化器优化为三个固定步骤
        # 梯度清零
        optimizer.zero_grad()
        # 反向传播,更新权重
        loss.backward()
        # 对得到的参数进行优化
        optimizer.step()
        total_train_step += 1
        # 为避免打印太多,训练100次才打印1次
        if total_train_step % 100 == 0:
            # loss.item()作用是把tensor转为一个数字
            print('------训练次数:{},Loss:{}------'.format(total_train_step, loss.item()))
            writer.add_scalar('train_loss', loss.item(), total_train_step)

    # 测试步骤开始
    total_test_loss = 0
    total_accuracy = 0
    # with的这段语句可以简单理解为提升运行效率
    with torch.no_grad():
        # 拿测试集中的数据来验证模型
        for data in test_dataloader:
            imgs, targets = data
            outputs = net(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            # agrmax(1)是将tensor对象按行看最大值下标进行存储,此处是数字图像,因此最大值下标实则就是我们的预测值
            # 此处是拿标签进行验证,统计预测正确的概率,方便后边计算正确率
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    print('整体测试集上的Loss:{}'.format(total_test_loss))
    print('整体测试集上的正确率:{}'.format(total_accuracy/test_data_size))
    writer.add_scalar('test_loss', total_test_loss, total_test_step)
    total_test_step += 1
    # 将每轮训练的模型都进行保存,方便以后直接调用训练完毕的模型
    torch.save(net, 'tarin_{}.pth'.format(total_test_step))

writer.close()
image-20230224135713072 image-20230224140022012

5. 利用GPU训练【优化训练速度】

如何电脑有GPU的话,优先 利用GPU进行训练速度会快很多

# 在有cuda方法的部分都加上
if torch.cuda.is_available():
    # 把数据交给GPU处理
	imgs = imgs.cuda()
	targets = targets.cuda()

如下方法也可以【常用】

# 指定device指向设备的第一张显卡
device = torch.device('cuda:0')
# 优先使用这张显卡处理
imgs = imgs.to(device)

# 也可以这样指定防止出错
device = torch.device( "cuda" if torch.cuda.is_available() else "cpu")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/369327.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用Python获取弹幕的两种方式(一种简单但量少,另一量大管饱)

前言 弹幕可以给观众一种“实时互动”的错觉,虽然不同弹幕的发送时间有所区别,但是其只会在视频中特定的一个时间点出现,因此在相同时刻发送的弹幕基本上也具有相同的主题,在参与评论时就会有与其他观众同时评论的错觉。 在国内…

零基础想转行学习Python,该如何学习,有学习路线分享吗?(2023年给初学者的建议)

Python属于一种面向对象、解释性的高级语言,它如今在众多领域都被应用,包括操作系统管理、Web开发、服务器运维的自动化脚本、科学计算、桌面软件、服务器软件(网络软件)、游戏等方面,且Python在今后将被大规模地应用到大数据和人工智能方面。…

MPI编程size总为1 解决方案

之前遇到的问题在这儿mpi编程 comm.Get_rank()全为0而comm.Get_size()全为1应该怎么办?_wennyLee的博客-CSDN博客 后来,我尝试了在pycharm的terminal中运行程序 mpiexec -n 4 test.py 但是又出现了新的问题↓ 然后为了解决”不是有效的 Win32 应用程序…

前端白屏的检测方案,让你知道自己的页面白了

前言 页面白屏,绝对是让前端开发者最为胆寒的事情,特别是随着 SPA 项目的盛行,前端白屏的情况变得更为复杂且棘手起来( 这里的白屏是指页面一直处于白屏状态 ) 要是能检测到页面白屏就太棒了,开发者谁都不…

线性数据结构:链表 LinkList

一、前言 链表的历史 于1955-1956年,由兰德公司的Allen Newell、Cliff Shaw和Herbert A. Simon开发了链表,作为他们的信息处理语言的主要数据结构。链表的另一个早期出现是由 Hans Peter Luhn 在 1953 年 1 月编写的IBM内部备忘录建议在链式哈希表中使…

推荐使用什么样的平台表单制作工具好?

在办公自动化迅猛发展的今天,传统的表单制作工具已经不能满足各行各业的生产需求,引用专业的低代码开发平台表单制作工具可以助力企业提高作业协作效率。那么,什么平台的表单制作工具可以实现这一目的?今天,我们就一起…

月薪过3W的软件测试工程师,都是怎么做到的?

对任何职业而言,薪资始终都会是众多追求的重要部分。前几年的软件测试行业还是一个风口,随着不断地转行人员以及毕业的大学生疯狂地涌入软件测试行业,目前软件测试行业“缺口”已经基本饱和。当然,我说的是最基础的功能测试的岗位…

用规则来搭建团队:写周报不一定是坏事

你好,我是Smile,一位有二十年工作经验的技术专家。今天我会结合我的经历,和你聊聊搭建技术团队这个话题。 众所周知,技术团队很大程度上决定了一个公司业务的生命力和生命周期,因此技术团队的投入成本往往很高&#x…

快到金3银4了,准备跳槽的可以看看

前两天跟朋友感慨,今年的铜九铁十、裁员、疫情导致好多人都没拿到offer!现在已经12月了,具体明年的金三银四只剩下两个月。 对于想跳槽的职场人来说,绝对要从现在开始做准备了。这时候,很多高薪技术岗、管理岗的缺口和市场需求也…

百度Q4:亮剑AI,重塑金身

2月22日港股盘后,港交所最纯正的AI概念股百度发布了2022年第四季度以及全年的业绩报告,在新冠疫情冲击宏观经济的第四季度,百度经营利润、经营利润率也均实现同比增长。 凭借在AI领域的长期投入,尽管有疫情侵扰、外部环境所带来的…

HACKTHEBOX——Curling

nmap还是老规矩,先扫描目标对外开放端口情况,只发现了22和80端口对外开启nmap -sV -sC -oA nmap 10.10.10.150http80端口对外开启,从扫描结果来看好像运行着Joomla,所以先访问看看,可以看到帖子由super user撰写&#…

教你用反射机制如何几分钟搭建完后端

如果想快速搭建后台跨域使用这些技术 反射mybatis-plusjson 反射可以实现动态数据的传输 一般对数据库进行操作肯定离不开这些代码 如果我们用反射机制只需要这一个就行 而说到反射的好处,一般情况下我们做增删改查需要大量的接口才能完成,而用反射我…

2023如果纯做业务测试的话,在测试行业有出路吗?

直接抛出我的结论:手工做业务类测试,没有前途。 个人建议赶紧从业务测试跳出来,立即学习代码,走自动化测试方向。目前趋势,业务测试需要用自动化做。 为了让大家能够信服我的观点,本文将从以下方面进行阐…

LeetCode题目笔记——2357. 使数组中所有元素都等于零

文章目录题目描述题目链接题目难度——简单方法一:直接模拟代码/Python方法二:哈希表代码/Python总结题目描述 给你一个非负整数数组 nums 。在一步操作中,你必须: 选出一个正整数 x ,x 需要小于或等于 nums 中 最小…

嵌入式系统硬件设计与实践(学习方法)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 刚读书的时候,对什么是嵌入式,其实并不太清楚。等到自己知道的时候,已经毕业很多年了。另外对于计算机毕业的学…

RK3588关键电路 PCB Layout设计指南

1、音频接口电路 PCB 设计(1)所有 CLK 信号建议串接 22ohm 电阻,并靠近 RK3588 放置,提高信号质量;(2)所有 CLK 信号走线不得挨在一起,避免串扰;需要独立包地&#xff0c…

jianzhiOffer第二版难重点记录

04. 二维数组中的查找https://leetcode.cn/problems/er-wei-shu-zu-zhong-de-cha-zhao-lcof/ 思路:可以每层用以恶搞二分查找,优化思路:从左下角出发直接用二分。 ​​​​​​07. 重建二叉树https://leetcode.cn/problems/zhong-jian-er-cha…

Redis 常用数据类型之 zset

目录 一、zset数据结构 二、Redis的zset 三、详细操作 基础操作(zadd、zcrad、zcount) 排序操作(zrange 、zrevrange ) 根据分数显示元素(zrangebyscore) 删除操作(zrem、zremrangebyran…

DSPE-PEG-TCO;磷脂-聚乙二醇-反式环辛烯科研用化学试剂简介

中文名称 磷脂-聚乙二醇-反式环辛烯 英文名称 DSPE-PEG-TCO 外观:粉末或半固体,取决于分子量。 溶剂:溶于大部分有机溶剂,如:DCM、DMF、DMSO、THF等等。在水中有很好的溶解性 稳定性:冷藏保存&#xff…

安装包UI美化之路-通过nsNiuniuSkin来做Electron程序的打包、发布与升级

nsNiuniuSkin从发布之初,因其简单、简洁、高效,受到了非常多公司的青睐,现在已经越来越多的公司采用我们的这套解决方案来制作安装包了! 从一个安装包UI插件,逐步演化成一套集美观、安全、简洁、自动化为一体的完整的…