pytorch深度学习实战lesson28

news2025/7/14 12:59:24

第二十八课 resnet的梯度计算(如何缓解梯度问题)

沐神说:“假设你在卷积神经网络里面,只要了解一个神经网络的话,你就了解 rest net 就行了。 rest net 是一个很简单的也是很好用的一个网络。这也是大家会经常在实际中使用的一个网络。”

目录

理论部分

实践部分


理论部分

如果不断地去加深神经网络,其实不一定会带来好处吗的。举个例子,假设上图的星星是最优值的位置,然后F 1是一个函数,可以认为函数区域的大小代表了这个函数的复杂程度。按照左图来看,就是模型越复杂,学得越偏,就是说虽然f6这个模型更加复杂了,但是实际上你学偏了,还不如一个小模型和最优离得近呢。

那么应该怎么做呢?就是如果每一次增加模型的复杂度,每一次那个更复杂的模型包含前面的小模型的话,模型就不会变差。

具体来讲就是不光要新加网络层,还要引出一条残差连接与新加层的结果相加。

    上图是resnet的具体设计细节。

残差连接的加入位置可以进行调换。

整个restnet的架构:restnet最核心的就是通过残差连接加入了一个加法。就是说就它有两种 resnet block 第一种是高宽减半的 resnet block ,所谓的高宽减半,就是说在第一个卷积层里面幅度等于2,就等于是把高宽减半了。然后通常来说也会把通道数增加一倍。然后通过调整 rest net 块以及它的输出通道数,可以得到不同的 resnet 的架构。

再回顾一下 resnet 是怎么来处理梯度消失,使得能够训练1000层的样子。其实resnet的基本思想就是将乘法变加法。

蓝字和紫字是没有残差连接的网络,就是说在之后的更新时,会出现梯度消失的现象;但是如绿色字所示,加上残差连接后,就它的梯度的计算方法从原来的乘法改成了加法,这就使得梯度不会急剧减小,可以保持层数多的情况下的梯度。

实践部分

#残差网络(ResNet)
#残差块
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import matplotlib.pyplot as plt
class Residual(nn.Module):
    '''第一个卷积层的话你可以直接那个 strat 第二个第二个就是 strata 就是不变了。然后你的 padding 因为你的 kernel 数都是等于,
    我就把这个回撤回下来。 kernel 数等于3,你 padding 等于1,就是高宽不变。就第一个是可以指定 strat 你可以等于2剩下的就是说这一个是不会让你指定 strat 它就等默认等于1了。
    如果你要使用1乘1的卷积层的话,我会再构造一个空3出来,它就是一个,就是它会把你的 input channel 变成你的 output channel 就1如果你这两个东西不一样的话,
    所以这个东西是必须的,所以它把你 input 那个 channel 数给变换到 output channel 数克动差异等于1, thread 也会是等于你要的那个 thread 这样子你能够 match 到你的高宽。'''

    '''ResNet沿用了VGG完整的卷积层设计。 残差块里首先有2个有相同输出通道数的卷积层。 
    每个卷积层后接一个批量规范化层和ReLU激活函数。 然后我们通过跨层数据通路,跳过这2个卷积运算,
    将输入直接加在最后的ReLU激活函数前。 这样的设计要求2个卷积层的输出与输入形状一样,
    从而使它们可以相加。 
    如果想改变通道数,就需要引入一个额外的卷积层来将输入变换成需要的形状后再做相加运算。 '''
    def __init__(self, input_channels, num_channels, use_1x1conv=False,strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3,padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3,padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)
#输入和输出形状一致
blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
print(Y.shape)
#增加输出通道数的同时,减半输出的高和宽
blk = Residual(3, 6, use_1x1conv=True, strides=2)
print(blk(X).shape)
#ResNet模型
'''ResNet的前两层跟之前介绍的GoogLeNet中的一样: 在输出通道数为64、步幅为2的7*7卷积层后,
接步幅为2的3*3的最大汇聚层。 不同之处在于ResNet每个卷积层后增加了批量规范化层。'''
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
'''GoogLeNet在后面接了4个由Inception块组成的模块。 ResNet则使用4个由残差块组成的模块,
每个模块使用若干个同样输出通道数的残差块。 第一个模块的通道数同输入通道数一致。 
由于之前已经使用了步幅为2的最大汇聚层,所以无须减小高和宽。 
之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。
下面我们来实现这个模块。注意,我们对第一个模块做了特别处理。'''
def resnet_block(input_channels, num_channels, num_residuals,first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True,strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk
#接着在ResNet加入所有残差块,这里每个模块使用2个残差块。
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
net = nn.Sequential(b1, b2, b3, b4, b5, nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(512, 10))
#最后,与GoogLeNet一样,在ResNet中加入全局平均汇聚层,以及全连接层输出。
#观察一下ResNet中不同模块的输入形状是如何变化的
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
#训练模型
lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()
#每个模块有4个卷积层(不包括恒等映射的卷积层)。 加上第一个卷积层和最后一个全连接层,共有18层。
# 因此,这种模型通常被称为ResNet-18。 通过配置不同的通道数和模块里的残差块数可以得到不同的ResNet模型,
# 例如更深的含152层的ResNet-152。 虽然ResNet的主体架构跟GoogLeNet类似,
# 但ResNet架构更简单,修改也更方便。这些因素都导致了ResNet迅速被广泛使用。

torch.Size([4, 3, 6, 6])
torch.Size([4, 6, 3, 3])
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 128, 28, 28])
Sequential output shape:     torch.Size([1, 256, 14, 14])
Sequential output shape:     torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:     torch.Size([1, 512, 1, 1])
Flatten output shape:     torch.Size([1, 512])
Linear output shape:     torch.Size([1, 10])
training on cuda:0
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
。。。。
loss 0.010, train acc 0.998, test acc 0.820
2106.3 examples/sec on cuda:0

进程已结束,退出代码0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/38252.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenCV-Python小应用(六):车道线检测

OpenCV-Python小应用&#xff08;六&#xff09;&#xff1a;车道线检测前言前提条件实验环境基于霍夫变换的车道线检测参考文献前言 本文是个人使用OpenCV-Python的应用案例&#xff0c;由于水平有限&#xff0c;难免出现错漏&#xff0c;敬请批评改正。更多精彩内容&#xff…

【成为红帽工程师】第五天 NFS服务器

目录 一、NFS服务器简介 二、NFS的使用 三、客户端使用autofs自动挂载 四、相关实验 一、NFS服务器简介 NFS&#xff08;网络文件系统&#xff09;&#xff0c;是FreeBSD支持的文件系统中的一种&#xff0c;它允许网络中的计算机&#xff08;不同的计算机、不同的操作系统&…

Go学习之路:流程控制语句:for、if、else、switch 和 defer(DAY 1)

文章目录前引流程控制语句&#xff1a;for、if、else、switch 和 defer1.1、for循环语句/语法格式&#xff08;一&#xff09;1.2、for循环语句/省略前置后置语句&#xff08;二&#xff09;1.3、for循环语句/while&#xff08;三&#xff09;1.4、for循环语句/无限循环&#x…

美新科技过会:收入依赖美国、产能利用率低,林东亮等均为香港籍

11月25日&#xff0c;深圳证券交易所创业板披露的信息显示&#xff0c;美新科技股份有限公司&#xff08;下称“美新科技”&#xff09;获得上市委会议通过。据贝多财经了解&#xff0c;美新科技于2022年3月31日在创业板递交上市申请。 本次冲刺创业板上市&#xff0c;美新科技…

SpringCloudGateway--谓词(断言)

目录 一、定义 二、谓词使用 1、After 2、Before 3、Between 4、Cookie 5、Header 6、Host 7、Method 8、Path 9、Query 10、RemoteAddr 11、Weight 一、定义 SpringCloudGateway中三个重要词汇&#xff1a; 路由&#xff08;Route&#xff09;&#xff1a;配置网…

傻白入门芯片设计,芯片键合(Die Bonding)(四)

目录 一、键合( Bonding) 1. 什么是键合(Bonding)&#xff1f; 2. 芯片键合步骤 3&#xff0e;芯片拾取与放置(Pick & Place) 4. 芯片顶出(Ejection)工艺 5. 使用环氧树脂(Epoxy)实现粘合的芯片键合工艺 6. 使用晶片黏结薄膜&#xff08;DAF&#xff09;的芯片键合工…

Redis实战篇(三)秒杀

一、全局唯一ID &#xff08;1&#xff09;定义 全局ID生成器&#xff0c;是一种在分布式系统下用来生成全局唯一ID的工具&#xff0c;一半满足下列特性&#xff1a; 唯一性高可用高性能递增性安全性 为了增加ID的安全性&#xff0c;我们不直接使用Redis自增的数值&#xf…

OpenCV-Python快速入门(十五):霍夫变换

OpenCV-Python快速入门&#xff08;十五&#xff09;&#xff1a;霍夫变换前言前提条件实验环境霍夫变换基本原理霍夫直线变换&#xff08;cv2.HoughLines()&#xff09;概率霍夫变换&#xff08;cv2.HoughLinesP()&#xff09;霍夫圆变换&#xff08;cv2.HoughCircles()&#…

移动端测试理论

App测试基础 App功能测试及专项测试 前言: 对于APP项目的测试&#xff0c;一般是进行系统测试。 测试主要从业务功能和非业务功能两个方面考虑。业务功能测试 根据软件说明&#xff0c;设计文档或用户需求验证App的各个功能的实现。 专项测试 兼容性测试 兼容性测试的关注点…

阿里Java研发面经(已拿offer)

一、自我总结&#xff1a; 1&#xff09;首先最重要的一点。对自己的要求高点。不要以简单的实习生来要求自己。你要想 你会的别人都会 你的核心竞争力是什么呢。所以楼主建议以Java高级工程师来要求自己。不会的就学嘛。人面对未知的事物 本能反应是恐惧与退缩。可当你尝试去…

Xxl-Job 初次体验

Xxl-Job 初次体验一、定时任务-前置知识二、演变机制三、xxl-Job 设计思想四、xxl-job 实战1. 调度中心部署2. 编写执行器简单使用一下2.1. 让执行器run起来&#xff01;2.2. 在调度中心配置任务&#xff0c;调度一下&#xff01;3. XxlJob 任务的生命周期4. 路由策略5. 父子任…

汇川PLC编程软件AutoShop的使用

文章目录一、数据类型二、系统参数.1、内存容量设置2、“掉电保持范围”设置3、系统设置三、符号表1、编辑符号表2、符号表的打印四、元件监控表1、新建元件监控表2、编辑元件监控表3、复制元件监控表4、快速监控表五、元件使用信息表六、交叉引用表七、软元件内存表1、新建和复…

windows的小米11真机appium微信爬虫

1、下载appium 仓库地址 2、下载python的包 pip install Appium-Python-Client -i https://pypi.tuna.tsinghua.edu.cn/simple 3、下载android-sdk 先下SDK Tools 国内一个镜像网站 参考这个教程 安装好后&#xff0c;运行这个SDK Manager.exe 然后install&#xff0c;同意协…

Alphalens使用方法细节判断

首先alphalens的数据格式&#xff1a; factor: MultiIndex&#xff08;用stack()方法来转换&#xff09; prices: DataFrame #转换成MultiIndex factor alpha_mom.stack() print (factor.tail()) datetime 2017-11-20 15:00:00 601857.XSHG 1…

小学生python游戏编程arcade----烟花粒子

小学生python游戏编程arcade----烟花粒子前言烟花粒子1、Vector向量类1.1 arcade中的向量类1.2 应用2、绘制粒子所有纹理图片2.1 给定直径和颜色的圆的纹理2.2 arcade.make_circle_texture函数原码2.3 make_soft_circle_texture 函数原码2.4 公共纹理代码3 效果图4 代码源码获取…

【读点论文】Densely Connected Convolutional Networks用残差连接大力出奇迹,进一步叠加特征图,以牺牲显存为代价

Densely Connected Convolutional Networks Abstract 如果卷积网络在靠近输入的层和靠近输出的层之间包含较短的连接&#xff0c;则卷积网络可以训练得更深入、更准确和有效。在本文中&#xff0c;接受了这种观察&#xff0c;并介绍了密集卷积网络(DenseNet)&#xff0c;它以…

Linux - Linux下Java安装路径查找;配置Java环境变量

一、查看Java的安装路径 1、已经安装好了JDK&#xff0c;也配置了环境变量 1、执行 java -version java -version 出现了版本号&#xff0c;表示安装过了JDK&#xff0c;配置了环境变量 2、在配置过jdk的情况下&#xff0c;执行java -verbose指令&#xff0c;在打印出的文本…

java stream中的peek()用法

文章目录前言最终操作&#xff08;terminal operation&#xff09;peek() vs forEach()peek() 的典型用法&#xff1a;协助调试总结前言 最近看到一段代码&#xff1a; aeFormList.stream().peek(object -> saveInfomation(object, params)).collect(Collectors.toList())…

std::shared_ptr(基础、仿写、安全性)

目录 一、c参考手册 1、解释说明 2、代码示例 3、运行结果 二、对std::shared_ptr分析 1、shared_ptr基础 2、创建shared_ptr实例 3、访问所指对象 4、拷贝和赋值操作 5、检查引用计数 三、仿写std::shared_ptr代码 1、单一对象 2、数组对象 四、shared_ptr遇到问…

MyBatis 环境搭建

MyBatis 环境搭建步骤 1.创建一张表和表对应的实体类 2.创建一个 maven 项目&#xff0c;把项目添加到 git 仓库 创建maven项目 教程见&#xff1a;Maven[项目构建工具]_chen☆的博客-CSDN博客 添加到git仓库&#xff1a; 3.在文件 pom.xml 添加 mybiatis 相关依赖(导入 MyBa…