pytorch深度学习实战lesson22

news2025/7/12 9:25:00

第二十二课 LeNet 

    LeNet神经网络由深度学习三巨头之一的Yan LeCun提出,他同时也是卷积神经网络 (CNN,Convolutional Neural Networks)之父。

    LeNet主要用来进行手写字符的识别与分类,并在美国的银行中投入了使用。LeNet的实现确立了CNN的结构,现在神经网络中的许多内容在LeNet的网络结构中都能看到,例如卷积层,Pooling层,ReLU层。虽然LeNet早在20世纪90年代就已经提出了,但由于当时缺乏大规模的训练数据,计算机硬件的性能也较低,因此LeNet神经网络在处理复杂问题时效果并不理想。虽然LeNet网络结构比较简单,但是刚好适合神经网络的入门学习。

目录

理论部分

实践部分


理论部分

       Lenet最早是用来处理手写数字识别用的,随之出名的是mnist数据集。

Lenet的具体结构:

       如图,LeNet-5(因为有5个卷积层而得名)是由卷积层、池化层、全连接层的顺序连接,网络中的每个层使用一个可微分的函数将激活数据从一层传递到另一层。LeNet-5开创性的利用卷积从直接图像中学习特征,在计算性能受限的当时能够节省很多计算量,同时也指出卷积层的使用可以保证图像的空间相关性(也是基于此,之后的一些网络开始慢慢摒弃全连接层,同时全连接层还会带来过拟合的风险)。

       此外,LeNet也是第一个成功的卷积神经网络应用,在当时主要用于识别数字和邮政编码,其用于手写数字识别的训练结果如图所示。

实践部分

代码:

#卷积神经网络(LeNet)
#LeNet(LeNet-5)由两个部分组成: 卷积编码器和全连接层密集块
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
class Reshape(torch.nn.Module):
    def forward(self, x):
        return x.view(-1, 1, 28, 28)#把x弄成批量数不变,通道数为1,像素为28*28
net = torch.nn.Sequential(Reshape(), #先reshape
                          nn.Conv2d(1, 6,kernel_size=5,padding=2),#第一个卷积层:输入为1 ,输出为6,卷积核大小为5*5,填充为2(输入变为32*32)
                          nn.Sigmoid(),#非线性激活函数
                          nn.AvgPool2d(kernel_size=2, stride=2),#均值池化层,池化核2*2,步幅2
                          nn.Conv2d(6, 16, kernel_size=5),#第2个卷积层:输入为6,输出为16,卷积核5*5,不填充
                          nn.Sigmoid(),#非线性激活函数
                          nn.AvgPool2d(kernel_size=2, stride=2),#均值池化
                          nn.Flatten(),#把4维的输出变成1维的向量:第一维(批量)保持住,后面弄成一个维度。
                          nn.Linear(16 * 5 * 5, 120),#(400,120),括号里5的由来:(28+4-4)/2-4)/2
                          nn.Sigmoid(),#非线性激活函数
                          nn.Linear(120, 84),#84是自己凭经验设置的
                          nn.Sigmoid(),#非线性激活函数
                          nn.Linear(84, 10))#10是自己凭经验设置的
#检查模型,假设有也输出,那么本段代码的作用就是输出每层的特征
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape: \t', X.shape)
'''
由填充计算公式:(nh-kh+ph+1)*(nw-kw+pw+1)
由步幅计算公式:[(nh-kh+ph+sh)/sh]*[(nw-kw+pw+1)/sh]
其中,n是原始的维度,k是卷积核维度,p是填充的行数或列数,s是步幅大小。
输出是:
Reshape output shape: 	 torch.Size([1, 1, 28, 28])
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])第一次卷积的输出维度为(28+4-5+1)*(28+4-5+1)=28*28
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])非线性激活不影响维度
AvgPool2d output shape:  torch.Size([1, 6, 14, 14])第一次池化的输出维度为[(28-2+2)/2]*[(28-2+2)/2]=14*14
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])第二次卷积的输出维度为(14-5+1)*(14-5+1)=10*10
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])非线性激活不影响维度
AvgPool2d output shape:  torch.Size([1, 16, 5, 5])第二次池化的输出维度为[(10-2+2)/2]*[(10-2+2)/2]=5*5
Flatten output shape: 	 torch.Size([1, 400])拉成向量后为[1,16*5*5]
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])
'''
#LeNet在Fashion-MNIST数据集上的表现
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)#加载数据集
#对 evaluate_accuracy函数进行轻微的修改
def evaluate_accuracy_gpu(net, data_iter, device=None):
    """使用GPU计算模型在数据集上的精度。"""
    if isinstance(net, torch.nn.Module):#如果使用torch。nn的版本
        net.eval()#会变成评估模式
        if not device:#如果没指定设备的话
            device = next(iter(net.parameters())).device#就把第一个network(网络层)的device拿出来用
    metric = d2l.Accumulator(2)#累加器
    for X, y in data_iter:#对每一个xy
        if isinstance(X, list):#如果是torch.nn模式的话就让每一个xy都应用torch.nn使用的device
            X = [x.to(device) for x in X]
        else:#反之
            X = X.to(device)#就让每一个xy使用第一个网络层的device,意思就是都统一使用第一个网络层的device
        y = y.to(device)
        metric.add(d2l.accuracy(net(X), y), y.numel())#所有Y的个数
    return metric[0] / metric[1]#分类正确的个数除总数得到正确率
#为了使用 GPU,我们还需要一点小改动
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)。"""
    def init_weights(m):#初始化权重
        if type(m) == nn.Linear or type(m) == nn.Conv2d:#如果是线性层或卷积层
            nn.init.xavier_uniform_(m.weight)#就用定义好的初始化函数,根据输入输出大小,当进行随机初始化的时候,方差之类的数据差不多,
                                             #防止模型一开始梯度爆炸或梯度清零。
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()#计算梯度
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:#打印信息
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')
#训练和评估LeNet-5模型
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

Reshape output shape:      torch.Size([1, 1, 28, 28])
Conv2d output shape:      torch.Size([1, 6, 28, 28])
Sigmoid output shape:      torch.Size([1, 6, 28, 28])
AvgPool2d output shape:      torch.Size([1, 6, 14, 14])
Conv2d output shape:      torch.Size([1, 16, 10, 10])
Sigmoid output shape:      torch.Size([1, 16, 10, 10])
AvgPool2d output shape:      torch.Size([1, 16, 5, 5])
Flatten output shape:      torch.Size([1, 400])
Linear output shape:      torch.Size([1, 120])
Sigmoid output shape:      torch.Size([1, 120])
Linear output shape:      torch.Size([1, 84])
Sigmoid output shape:      torch.Size([1, 84])
Linear output shape:      torch.Size([1, 10])
training on cuda:0
<Figure size 350x250 with 1 Axes>*100
loss 0.466, train acc 0.825, test acc 0.817
43232.2 examples/sec on cuda:0

进程已结束,退出代码0
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/18150.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

KKFileView在线预览禁用复制右键图片保存等操作

KKFileView在线预览禁用复制右键图片保存等操作一、需求背景二、修改kkFileview1.docx、doc文档不可复制、F12、右键、打印限制问题2.图片限制拖拽处理3.限制Excel转换后复制等操作4.PDF模式禁用右上角菜单栏一、需求背景 公司的运营平台&#xff0c;管理了一些如合同等内容&a…

【记录】软件自动修复工具Jaid配置、调试、运行及相关问题的解决方案

文章目录1. 前言2. Jaid原论文3. 环境4. 配置5. 调试6. 碰到的问题7. 一些发现8. 小结9. 参考文献1. 前言 创作开始时间&#xff1a;2022年11月18日20:50:38 如题&#xff0c;本文记录一下软件自动修复工具Jaid配置、调试、运行及相关问题的解决方案。 2. Jaid原论文 【ASE…

Jenkins+Docker+SVN实现SpringBoot项目半自动化部署

起因&#xff1a;入职后公司需要的技能&#xff0c;全部项目都使用的JenkinsDocker部署 Jenkins详细教程&#xff1a;知乎大佬写的文章 Docker详细教程&#xff1a;一个大佬的博客 SVN使用教程&#xff1a;一个大佬的博客 深入了解Jenkins、Docker、SVN&#xff0c;去上面三个大…

图书管理系统【java】

目录 &#x1f947;1.设计背景 &#x1f50e;2.设计思路 &#x1f511;3.book包 &#x1f4d7;3.1 Book类的实现 &#x1f4d5;3.2 BookList类的实现(书架) &#x1f511;4.user包 &#x1f4d9;4.1 User类的实现 &#x1f4d2;4.2 AdminUser&#xff08;管理员&#x…

MySQL导出csv数据文件

之前使用MySQL导出过一次线上数据&#xff0c;当时解决了乱码和数据没有正常分隔的问题。 参见这篇文章: 记一次“曲折“的MySQL数据导出 前几个月换了工作电脑&#xff0c;这几天又需要导出几十万的线上数据&#xff0c;在导出过程中还是出现了一些问题&#xff0c;再记录一…

OpenWrt 固件编译教程

一、编译环境准备 编译平台 阿里云 Ubuntu 20.04.5 LTS 安装编译环境依赖 sudo apt-get -y install build-essential asciidoc binutils bzip2 gawk gettext git libncurses5-dev libz-dev patch python3 python2.7 unzip zlib1g-dev lib32gcc1 libc6-dev-i386 subversion f…

ICME 会议介绍

官网翻译来的&#xff0c;具体内容还是看官网&#xff1a;IEEE ICME23 Author Information and Submission Instructions 目录 常会和特别会议 研讨会 行业/应用文件 演示 一般信息 示例文件、格式化指南和模板 电子论文提交 提交论文的分步说明 感兴趣的主题包括但不…

右键发送到菜单+批处理实现批量自动化为文件名添加统一的后缀

WinR打开运行&#xff0c;并输入shell:sendto&#xff0c;打开系统右键发送到菜单的所在文件夹。 新建记事本文档&#xff0c;修改文件名为公开.bat&#xff0c;编辑并保存如下内容&#xff1a; echo off :loop if not "%~1" "" (ren "%~1" &…

[附源码]SSM计算机毕业设计在线二手车交易信息管理系统JAVA

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Linux系统配置 Samba客户端

配置 Samba客户端 1.Windows 客户端访问 samba 共享 无论Samba共享服务是部署在Windows系统上&#xff0c;还是部署在Linux系统上&#xff0c;通过Windows系统进行访问时&#xff0c;其步骤和方法都是一样的。下面假设Samba共享服务部署在Linux系统上&#xff0c;并通过Wind…

11返场钜惠,格式转换、图片/视频压缩免费小技巧

&#x1f4e3; 话不多说&#xff0c;直接上干货&#xff01; &#x1f525; 11返场钜惠&#xff0c;牛学长转码大师免费送&#xff01;&#xff01;&#xff01;&#x1f525; 那么牛学长转码大师能帮助您些什么呢&#xff1f;一起看看吧~ 一、格式转换 作为一款专业的格式…

Redis集群部署的三种模式

一、Redis简介 Redis 是一款完全开源免费、遵守BSD协议的高性能(NOSQL)的key-value数据库。它使用ANSI C语言编写&#xff0c;支持网络、可基于内存亦可持久化的日志型、Key-Value数据库&#xff0c;并提供多种语言的API。 Redis的使用场景有如下一些&#xff1a; 读写效率要…

计算机网络复习——第四章网络层

9月开始学习的一个月&#xff0c;I hope everthing be fine. 相关知识见&#xff0c;感觉比较容易入手 《计算机网络》&#xff08;谢希仁&#xff09;内容总结 | JavaGuide 重点知识&#xff1a; TCP/IP 协议中的网络层向上只提供简单灵活的&#xff0c;无连接的&#xff…

[附源码]SSM计算机毕业设计在线购物商城JAVA

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

C/C++多进程高并发框架分享【内附可执行源码注释完整】

文章目录&#x1f680;前言&#x1f34e;源码分享&#x1f382;总结&#x1f680;前言 多进程高并发的设计的思想体现在&#xff1a;电脑物理CPU有多少个核&#xff08;core&#xff09;就创建多少个子进程&#xff0c;并且把各个子进程平均分配到各个核&#xff08;core&…

【JavaSE】多态、抽象类

文章目录1. 向上转型2. 重写3. 多态4. 向下转型5. 抽象类1. 向上转型 我们来看看以下程序 class Animal {public String name;public int age;public void eat() {System.out.println("父类的方法");} } class Cat extends Animal {public String hire;public void…

Birdboot第六天 jar包 数据库

实际应用birdboot框架 1.BirdBoot导包 1.新建maven BirdBoot------pom替换 2.删掉static 和 Springboot里面写的&#xff08;controller entity&#xff09; rebuild之后把无用的导包都删掉 主启动类里面把main方法改为run方法 传参&#xff1a;类名和参数&#xff08;复制sp…

动力节点索引优化解决方案学习笔记——索引介绍

1.索引介绍 1.1什么是MySQL的索引 MySQL官方对于索引的定义&#xff1a;索引是帮助MySQL高效获取数据的数据结构。 MySQL在存储数据之外&#xff0c;数据库系统中还维护着满足特定查找算法的数据结构&#xff0c;这些数据结构以某种引用(指向)表中的数据&#xff0c;这样我们…

决策树算法

目录 ​分类算法 决策树算法 外卖订餐决策树 分支处理 分类算法 分类算法是利用训练样本集获得分类函数即分类模型(分类器)&#xff0c;从而实现将数据集中的样本划分到各个类中。分类模型通过学习训练样本中属性集与类别之间的潜在关系&#xff0c;并以此为依据对新样本属…

测试基础——数据库及数据库表的SQL操作(了解即可)

目录 1.数据库基础概念 2.SQL介绍 3.MySQL介绍 4.数据库连接工具Navicat 5.数据类型 6.约束 7.对数据库操作的SQL语句 7.1创建数据库 7.2使用/打开/切换数据库 7.3修改数据库 7.4删除数据库 7.5查看所有数据库 7.6数据库备份 8.数据库表操作的SQL语句 8.1创建数据…