PyTorch搭建LeNet训练集详细实现

news2025/5/26 8:19:42

一、下载训练集

导包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

 ToTensor()函数:

把图像[heigh x width x channels] 转换为 [channels x height x width]

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

参数解释: 

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建 

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为: 

 下载完成后生成了data文件夹

二、导入训练集 

# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,
                                          shuffle=True, num_workers=0)

参数解释: 

        trainset:把刚刚下载的数据导入进来

        batch_size=36:一批数据的大小

        shuffle=True:训练集中的数据是否打乱(一般默认打乱)

        num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                         shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()

classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

 四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

def imshow(img):
    img = img / 2 + 0.5
    nping = img.numpy()
    plt.imshow(np.transpose(nping, (1, 2, 0)))
    plt.show()


# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

 图片很模糊,因为像素很低。

上面识别出来的结果都对了。

 我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。


五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

# 将创建的模型实例化
net = LeNet()  # 实例化
loss_fuction = nn.CrossEntropyLoss()  # 定义损失函数

# 通过优化器将所有可训练的参数都进行训练,lr是learningrate学习率
optimizer = optim.Adam(net.parameters(), lr=0.001)

#通过for循环实现训练过程,循环几次就是将训练集迭代多少次
for epoch in range(5):

    running_loss = 0.0  # 用来累加在学习过程中的损失
    for step, data in enumerate(trainloader, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()   # 历时损失梯度清零。

        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_fuction(outputs, labels)  # 计算神经网络的预测值和真实标签之间的损失
        loss.backward()
        optimizer.step()  # step()函数实现参数更新

        # print statistics  打印数据的过程
        running_loss += loss.item()
        if step % 500 == 499:  # 每隔500步打印一次数据的信息
            with torch.no_grad():  # 上下文管理器
                outputs = net(test_image)
                predict_y = torch.max(outputs, dim=1)[1]
                accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)

                print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                       (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0

print('Finished Training')

# 将模型保存到文件夹中
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

 详细解释:

比较重点的单独解释了,其他的在注释中。

 optimizer.zero_grad()   # 历时损失梯度清零。

 ? 为什么每计算一个batch,就要调用一次 optimizer.zero_grad()函数

=> 通过清楚历史梯度,就会对计算的历史梯度进行累加。通过这个特性,能变相的实现一个很大的batch数值的训练(因为batch数值越大,训练效果越好)

 with torch.no_grad():  # 上下文管理器

 上下文管理器: 在接下来的计算过程中,不再去计算每个节点的误差损失梯度。

如果不调用这个函数,将会在测试过程中占用更多的算力,消耗更多的资源和占用更多的内存资源,导致内存容易崩。

print函数中打印参数解释:

print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                       (epoch + 1, step + 1, running_loss / 500, accuracy))

epoch + 1:迭代到第几轮了

step + 1:某一轮的第几步

running_loss / 500:训练过程中500步平均训练误差

accuracy:准确率

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1504923.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多维时序 | Matlab实现BiGRU-Mutilhead-Attention双向门控循环单元融合多头注意力机制多变量时序预测

多维时序 | Matlab实现BiGRU-Mutilhead-Attention双向门控循环单元融合多头注意力机制多变量时序预测 目录 多维时序 | Matlab实现BiGRU-Mutilhead-Attention双向门控循环单元融合多头注意力机制多变量时序预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.多维时序 …

C++笔记之给枚举类型的变量赋值

C++笔记之给枚举类型的变量赋值 —— 杭州 2024-03-10 code review! 在C++中,你可以在结构体内部定义一个枚举类型,并在创建结构体变量时给枚举类型的变量赋值。下面是一个简单的例子展示了如何做到这一点: 代码 #include <iostream>// 定义结构体 struct MyStru…

虚函数与纯虚函数有什么区别?

总的来说有两点区别&#xff1a; 1.虚函数的作用主要是矫正指针&#xff08;口语化的说法&#xff09; 2.虚函数不一定要重新定义&#xff0c;纯虚函数一定要定义&#xff08;口语化的说法&#xff09; 1&#xff09;. 虚函数的作用主要是矫正指针&#xff0c;使得基类的指针…

ArrayDeque集合源码分析

ArrayDeque集合源码分析 文章目录 ArrayDeque集合源码分析一、字段分析二、构造函数分析方法、方法分析四、总结 实现了 Deque&#xff0c;说面该数据结构一定是个双端队列&#xff0c;我们知道 LinkedList 也是双端队列&#xff0c;并且是用双向链表 存储结构的。而 ArrayDequ…

深入探讨AI团队的角色分工

目录 前言1. 软件工程师&#xff1a;构建系统基石的关键执行者2. 机器学习工程师&#xff1a;数据与模型的塑造专家3. 机器学习研究员&#xff1a;引领算法创新的智囊4. 机器学习应用科学家&#xff1a;理论与实践的巧妙连接5. 数据分析师&#xff1a;洞察数据&#xff0c;智慧…

如何在Linux上为PyCharm创建和配置Desktop Entry

在Linux操作系统中&#xff0c;.desktop 文件是一种桌面条目文件&#xff0c;用于在图形用户界面中添加程序快捷方式。本文将指导您如何为PyCharm IDE创建和配置一个 .desktop 文件&#xff0c;从而能够通过应用程序菜单或桌面图标快速启动PyCharm。 步骤 1: 确定PyCharm安装路…

力扣中档题:删除排序链表中的 重复元素

此题可以选择暴力解决&#xff0c;首先将链表中的元素放到数组中&#xff0c;然后将数组中的重复元素放到另一个数组中&#xff0c;最后再判断并将目标值放到第三个数组中排序再改链表&#xff0c;注意链表nextNULL的操作 struct ListNode* deleteDuplicates(struct ListNode*…

【Nestjs实操】环境变量和全局配置

一、环境变量 1、使用dotenv 安装pnpm add dotenv。 根目录下创建.env文件&#xff0c;内容如下&#xff1a; NODE_ENVdevelopment使用 import {config} from "dotenv"; const path require(path); config({path:path.join(__dirname,../.env)}); console.log(…

简介:基于 OpenTiny 组件库的 rendereless 无渲染组件架构

在 HAE 自研阶段&#xff0c;我们实现的数据双向绑定、面向对象的 JS 库、配置式开发的注册表等特性&#xff0c;随着前端技术的高速发展现在已经失去存在的意义&#xff0c;但是在 AUI 阶段探索的新思路新架构&#xff0c;经过大量的业务落地验证&#xff0c;再次推动前端领域…

万用表数据导出变化曲线图——pycharm实现视频数据导出变化曲线图

万用表数据导出变化曲线图——pycharm实现视频数据导出变化曲线图 一、效果展示二、环境配置三、代码构思四、代码展示五、代码、python环境包链接 一、效果展示 图1.1 效果展示 &#xff08;左图&#xff1a;万用表视频截图&#xff1b;右图&#xff1a;表中数据变化曲线图&am…

宽度优先搜索算法(BFS)

宽度优先搜索算法&#xff08;BFS&#xff09;是什么&#xff1f; 宽度优先搜索算法&#xff08;BFS&#xff09;&#xff08;也称为广度优先搜索&#xff09;主要运用于树、图和矩阵&#xff08;这三种可以都归类在图中&#xff09;&#xff0c;用于在图中从起始顶点开始逐层…

字节跳动的 SDXL-LIGHTNING : 体验飞一般的文生图

TikTok 的母公司字节跳动推出了最新的文本到图像生成人工智能模型&#xff0c;名为SDXL-Lightning。顾名思义&#xff0c;这个新模型只需很轻量的推理步骤&#xff08;1&#xff0c;4 或 8 步&#xff09;即可实现极其快速且高质量的文本到图像生成功能。与原始 SDXL 模型相比&…

嵌入式 Linux 学习

在学习嵌入式 Linux 之前&#xff0c;我们先来了解一下嵌入式 Linux 有哪些东西。 1. 嵌入式 Linux 的组成 嵌入式 Linux 系统&#xff0c;就相当于一套完整的 PC 软件系统。 无论你是 Linux 电脑还是 windows 电脑&#xff0c;它们在软件方面的组成都是类似的。 我们一开电…

.NET高级面试指南专题十六【 装饰器模式介绍,包装对象来包裹原始对象】

装饰器模式&#xff08;Decorator Pattern&#xff09;是一种结构型设计模式&#xff0c;用于动态地给对象添加额外的职责&#xff0c;而不改变其原始类的结构。它允许向对象添加行为&#xff0c;而无需生成子类。 实现原理&#xff1a; 装饰器模式通过创建一个包装对象来包裹原…

【数据可视化】动手用matplotlib绘制关联规则网络图

下载文中数据、代码、绘图结果 文章目录 关于数据绘图函数完整可运行的代码运行结果 关于数据 如果想知道本文的关联规则数据是怎么来的&#xff0c;请阅读这篇文章 绘图函数 Python中似乎没有很方便的绘制网络图的函数。 下面是本人自行实现的绘图函数&#xff0c;如果想…

【深度学习笔记】6_9 深度循环神经网络deep-rnn

注&#xff1a;本文为《动手学深度学习》开源内容&#xff0c;部分标注了个人理解&#xff0c;仅为个人学习记录&#xff0c;无抄袭搬运意图 6.9 深度循环神经网络 本章到目前为止介绍的循环神经网络只有一个单向的隐藏层&#xff0c;在深度学习应用里&#xff0c;我们通常会用…

three.js如何实现简易3D机房?(四)点击事件+呼吸灯效果

接上一篇&#xff1a; three.js如何实现简易3D机房&#xff1f;&#xff08;三&#xff09;显示信息弹框/标签&#xff1a;http://t.csdnimg.cn/5W2wA 目录 八、点击事件 1.实现效果 2.获取相交点 3.呼吸灯效果 4.添加点击事件 5.问题解决 八、点击事件 1.实现效果 2.…

ChatGPT发不出消息?GPT发不出消息怎么办?

前言 今天发现&#xff0c;很多人的ChatGPT无法发送信息&#xff0c;我就登陆看一下自己的GPT的情况&#xff0c;结果还真的无法发送消息&#xff0c;ChatGPT 无法发送消息&#xff0c;但是能查看历史的对话&#xff0c;不过通过下面的方法解决了。 第一时间先打开官方的网站&a…

Mint_21.3 drawing-area和goocanvas的FB笔记(七)

FreeBASIC gfx 基本 graphics 绘图 8、ScreenControl与屏幕窗口位置设置 FreeBASIC通过自建屏幕窗口摆脱了原来的屏幕模式限制&#xff0c;既然是窗口&#xff0c;在屏幕坐标中就有它的位置。ScreenControl GET_WINDOW_POS x, y 获取窗口左上角的x, y位置&#xff1b;ScreenC…

【REST2SQL】11 基于jwt-go生成token与验证

【REST2SQL】01RDB关系型数据库REST初设计 【REST2SQL】02 GO连接Oracle数据库 【REST2SQL】03 GO读取JSON文件 【REST2SQL】04 REST2SQL第一版Oracle版实现 【REST2SQL】05 GO 操作 达梦 数据库 【REST2SQL】06 GO 跨包接口重构代码 【REST2SQL】07 GO 操作 Mysql 数据库 【RE…