现代卷积神经网络(ResNet)

news2025/7/16 3:38:34

专栏:神经网络复现目录


本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet)。
文章部分文字和代码来自《动手学深度学习》

文章目录

  • 残差网络(ResNet)
  • 恒等变换
  • 跳跃连接
  • 残差块
  • ResNet模型
    • 结构
    • 实现
      • 残差块
      • ResNet
  • 利用ResNet50进行CIFAR10分类
    • 数据集
    • 损失函数优化器
    • 训练
    • 可视化


残差网络(ResNet)

残差网络(Residual Network,简称 ResNet)是由微软研究院于 2015 年提出的一种深度卷积神经网络。它的主要特点是在网络中添加了“残差块”(Residual Block),有效地解决了深层网络的梯度消失和梯度爆炸问题,从而使得更深的网络结构可以训练得更好。

ResNet 的核心思想是学习残差,即在训练过程中让神经网络学习一个残差映射,该映射将输入直接映射到输出,即 F ( x ) = H ( x ) − x F(x) = H(x) - x F(x)=H(x)x。这里 x x x 表示残差块的输入, H ( x ) H(x) H(x) 表示残差块的输出。在传统的网络结构中,网络的每一层都需要学习一个映射函数,即 H ( x ) H(x) H(x),但这种方法会出现梯度消失和梯度爆炸问题。而使用残差块可以使网络只学习残差部分,即 F ( x ) F(x) F(x),这使得训练更加容易,并且可以构建更深的网络结构。

ResNet 的基本结构是残差块,其中包含了跨层连接(skip connection)的机制。跨层连接可以将输入直接传递到输出端,从而避免了梯度消失问题,同时也减轻了梯度爆炸问题。除了跨层连接之外,ResNet 还采用了批量归一化(Batch Normalization)和池化操作等技巧来提高网络的训练效率和泛化能力。

总体上来说,ResNet 是一种十分有效的深度神经网络结构,其在许多计算机视觉任务上都取得了优异的表现,例如图像分类、物体检测和语义分割等

恒等变换

恒等变换(Identity Transformation)指的是一种变换,使得输入和输出完全相同。在数学上,恒等变换可以用一个函数f(x) = x来表示。

在深度学习中,恒等变换通常用于残差块(Residual Block)中。在残差块中,恒等变换被用作跳跃连接(Shortcut Connection),将输入直接传递给输出,这样可以加速梯度的传播和网络的训练。

举个例子,假设有一个残差块的输入x和输出y,它们的维度都为d。那么该残差块可以表示为:

y = f(x) + x

其中,f(x)是残差块的变换,它会对输入进行处理。而x则是恒等变换,它使得输入可以直接传递给输出。通过这样的设计,残差块可以保留输入中的有用信息,同时仍然能够对输入进行一定程度的处理,从而提高网络的性能。

跳跃连接

跳跃连接(Skip Connection),也称为残差连接(Residual Connection),是深度神经网络中的一种连接方式,用于解决网络训练过程中梯度消失问题。

在跳跃连接中,网络的某一层的输出不仅会传递给下一层进行计算,还会直接传递到距离当前层较远的层。这样可以使得网络中的信息能够更快地传递和共享,同时也可以减轻梯度消失的问题,使得训练过程更加稳定。

在ResNet网络中,跳跃连接被用于将网络的输入直接连接到卷积层的输出上,形成一个残差块。这样,网络的前向传播就变成了从输入到输出的“路径”加上一个残差块的“跳跃”,即“shortcut”。跳跃连接的使用大大改善了ResNet的性能,使得它在ImageNet图像分类等任务中取得了非常优秀的表现。

残差块

残差块(Residual Block)是深度学习中常用的一种模块,可以用来构建深度神经网络。残差块的主要作用是使得神经网络的训练更加容易,并且能够加速神经网络的收敛。

让我们聚焦于神经网络局部:如图所示,假设我们的原始输入为x,而希望学出的理想映射为f(x)作为上方激活函数的输入,左图虚线框中的部分需要直接拟合出该映射f(x),而右图虚线框中的部分则需要拟合出残差映射f(x)-x。残差映射在现实中往往更容易优化。 以本节开头提到的恒等映射作为我们希望学出的理想映射f(x),我们只需将图中右图虚线框内上方的加权运算(如仿射)的权重和偏置参数设成0,那么f(x)即为恒等映射,实际中,当理想映射f(x)极接近于恒等映射时,残差映射也易于捕捉恒等映射的细微波动。
在这里插入图片描述
在深度神经网络中,很容易出现梯度消失或梯度爆炸的问题,这会导致深度神经网络的训练非常困难。残差块的设计可以缓解这个问题。残差块的主要思想是在网络中增加一条跳跃连接(Shortcut Connection),这条连接可以让输入直接跳过一些层,从而更加容易地传递梯度。具体来说,残差块可以分为如下几个步骤:

  1. 输入x经过一个卷积层,得到特征图y1。
  2. 将y1经过一个Batch Normalization层和ReLU激活函数。
  3. 将y1再次经过一个卷积层,得到特征图y2。
  4. 将y2经过一个Batch Normalization层。
  5. 将输入x和y2相加,得到残差特征图y3。
  6. 将y3再经过一个ReLU激活函数。

在这个过程中,输入x可以看做是一种残差,因为它会直接和特征图y2相加。这个残差块的设计可以让网络更容易地学习残差,从而更好地拟合训练数据。此外,残差块也可以增加网络的深度,从而提升网络的效果。

ResNet模型

结构

ResNet原文中给出了几种基本的网络结构配置,本文以ResNet50为例。

在这里插入图片描述
ResNet50结构详解:

  1. 输入层(Input Layer):输入图像的大小为224 x 224 x 3。

  2. 卷积层(Convolution Layer):7x7的卷积核,步长为2,输出通道为64,padding为3。

  3. 标准化层(Batch Normalization Layer):对每个通道的输出做标准化处理,包括均值和方差。

  4. 激活函数(Activation Layer):使用ReLU激活函数。

  5. 最大池化层(Max Pooling Layer):3x3的池化核,步长为2,padding为1,对每个通道的输出取最大值。

  6. 残差块(Residual Block)1:包含3个卷积层和标准化层。第一个卷积层的卷积核为1x1,输出通道为64;第二个卷积层的卷积核为3x3,输出通道为64;第三个卷积层的卷积核为1x1,输出通道为256(因为残差块的输入和输出通道数不同,需要用1x1的卷积核进行通道变换)。

  7. 残差块(Residual Block)2:包含4个卷积层和标准化层。第一个卷积层的卷积核为1x1,输出通道为128;第二个卷积层的卷积核为3x3,输出通道为128;第三个卷积层的卷积核为1x1,输出通道为512。

  8. 残差块(Residual Block)3:包含6个卷积层和标准化层。第一个卷积层的卷积核为1x1,输出通道为256;第二个卷积层的卷积核为3x3,输出通道为256;第三个卷积层的卷积核为1x1,输出通道为1024。

  9. 残差块(Residual Block)4:包含3个卷积层和标准化层。第一个卷积层的卷积核为1x1,输出通道为512;第二个卷积层的卷积核为3x3,输出通道为512;第三个卷积层的卷积核为1x1,输出通道为2048。

  10. 平均池化层(Average Pooling Layer):使用全局平均池化,将输出的特征图转化为向量。

  11. 全连接层(Fully Connected Layer):将向量连接到最终的输出层,该层包含1000个神经元,每个神经元对应于一个类别,表示图像属于该类别的概率。

  12. Softmax层(Softmax Layer):使用softmax函数将全连接层的输出转化为概率分布,每个类别的概率为0到1之间的实数,概率之和为1。

总结:ResNet50网络结构由多个残差块组成,每个残差块内部包含多个卷积层和标准化层。通过使用残差学习的方法,ResNet50网络能够在训练深度神经网络时解决梯度消失和梯度爆炸的问题,同时在图像分类等任务中表现出色。

实现

残差块

import torch.nn as nn
import torch.onnx


class Residual(nn.Module):
    def __init__(self, in_channels, channels, stride, downsample=None):
        super(Residual, self).__init__()
        # 1x1的卷积降维操作
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=channels, kernel_size=(1, 1),
                               bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        # 3x3的卷积提取特征操作
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3),
                               stride=stride,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        # 1x1的卷积升维操作
        self.conv3 = nn.Conv2d(in_channels=channels, out_channels=channels * 4, kernel_size=(1, 1),
                               bias=False)
        self.bn3 = nn.BatchNorm2d(channels * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # 4组卷积层的头一层网络会做一次降采样,目的是使out和identity维度一致可以做加法
        if self.downsample is not None:
            self.dconv = nn.Conv2d(in_channels, channels * 4, stride=stride, kernel_size=(1, 1), bias=False)
            self.dbn = nn.BatchNorm2d(channels * 4)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.dconv(identity)
            identity = self.dbn(identity)

        out += identity
        out = self.relu(out)

        return out

ResNet

class Resnet50(nn.Module):
    def __init__(self, num_classes):
        super(Resnet50, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         # 对应第1组网络层,3*Resnet的基本结构
        self.conv64_1 = Residual(64, 64, stride=1, downsample=True)
        self.conv64_2 = Residual(256, 64, stride=1)
        self.conv64_3 = Residual(256, 64, stride=1)
        # 对应第2组网络层,4*Resnet的基本结构
        self.conv128_1 = Residual(256, 128, stride=2, downsample=True)
        self.conv128_2 = Residual(128 * 4, 128, stride=1)
        self.conv128_3 = Residual(128 * 4, 128, stride=1)
        self.conv128_4 = Residual(128 * 4, 128, stride=1)
        # 对应第3组网络层,6*Resnet的基本结构
        self.conv256_1 = Residual(512, 256, stride=2, downsample=True)
        self.conv256_2 = Residual(256 * 4, 256, stride=1)
        self.conv256_3 = Residual(256 * 4, 256, stride=1)
        self.conv256_4 = Residual(256 * 4, 256, stride=1)
        self.conv256_5 = Residual(256 * 4, 256, stride=1)
        self.conv256_6 = Residual(256 * 4, 256, stride=1)
        # 对应第4组网络层,3*Resnet的基本结构
        self.conv512_1 = Residual(1024, 512, stride=2, downsample=True)
        self.conv512_2 = Residual(512 * 4, 512, stride=1)
        self.conv512_3 = Residual(512 * 4, 512, stride=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv64_1(x)
        x = self.conv64_2(x)
        x = self.conv64_3(x)
        x = self.conv128_1(x)
        x = self.conv128_2(x)
        x = self.conv128_3(x)
        x = self.conv128_4(x)
        x = self.conv256_1(x)
        x = self.conv256_2(x)
        x = self.conv256_3(x)
        x = self.conv256_4(x)
        x = self.conv256_5(x)
        x = self.conv256_6(x)
        x = self.conv512_1(x)
        x = self.conv512_2(x)
        x = self.conv512_3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

利用ResNet50进行CIFAR10分类

数据集

# 导入数据集
from torchvision import datasets
import torch
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
])

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'trunk')
cifar_train = datasets.CIFAR10(root="/data",train=True, download=True, transform=transform)
cifar_test = datasets.CIFAR10(root="/data",train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(cifar_train, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(cifar_test, batch_size=16, shuffle=False)

损失函数优化器

# 定义损失函数和优化器
net=Resnet50(10);
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
epoch = 5
net = net.to(device)
total_step = len(train_loader)
train_all_loss = []
val_all_loss = []

训练

import numpy as np

for i in range(epoch):
    net.train()
    train_total_loss = 0
    train_total_num = 0
    train_total_correct = 0

    for iter, (images,labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = net(images)
        loss = criterion(outputs,labels)
        train_total_correct += (outputs.argmax(1) == labels).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_total_num += labels.shape[0]
        train_total_loss += loss.item()
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))
    net.eval()
    test_total_loss = 0
    test_total_correct = 0
    test_total_num = 0
    for iter,(images,labels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = net(images)
        loss = criterion(outputs,labels)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
        test_total_loss += loss.item()
        test_total_num += labels.shape[0]
    print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
        i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100

    ))
    train_all_loss.append(np.round(train_total_loss / train_total_num,4))
    val_all_loss.append(np.round(test_total_loss / test_total_num,4))

可视化

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
plt.title("Train Loss and Test Loss Curve")
plt.xlabel('plot_epoch')
plt.ylabel('loss')
plt.plot(train_all_loss)
plt.plot(val_all_loss)
plt.legend(['train loss', 'test loss'])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/395990.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【微信小程序】-- 生命周期(二十八)

💌 所属专栏:【微信小程序开发教程】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! &…

VRRP多网关负载分担实验

1、VRRP专业术语 VRRP备份组框架图如图14-1所示: 图14-1:VRRP备份组框架图 VRRP路由器(VRRP Router):运行VRRP协议的设备,它可能属于一个或多个虚拟路由器,如SwitchA和SwitchB。虚拟路由器(Virtual Router):又称VRR…

Windows安装Qt与VS2019添加QT插件

一、通过Qt安装包方式http://download.qt.io/archive/qt/5.12/5.12.3/.安装可以就选中这个MSVC 2017 64-bit,其他就暂时不用了二、通过vs2019安装Qt插件方式方法1下面这种方式本人安装不起来,一直卡住下不下来。拓展->管理拓展->联机->搜索Qt&a…

【计算机视觉 自然语言处理】什么是多模态?

文章目录一、多模态的定义二、多模态的任务2.1 VQA(Visual Question Answering)视觉问答2.2 Image Caption 图像字幕2.3 Referring Expression Comprehension 指代表达2.4 Visual Dialogue 视觉对话2.5 VCR (Visual Commonsense Reasoning) 视觉常识推理…

让你眼前一亮的不是流行软件,而是这五款小众软件

让你眼前一亮的软件,不一定是市面上最流行的。今天,我将推荐给你五款非常小众,但是十分好用的软件。它们功能强大,使用起来也非常方便,而且经过我个人的测试,确保质量有保障。如果你用完后觉得不好用&#…

Java VisualVM 安装 Visual GC 插件图文教程

文章目录1. 通过运行打开 Java VisualVM 监控工具2. 菜单栏初始视图说明3. 工具插件菜单说明4. 手工安装插件5. 重启监控工具查看 Visual GC1. 通过运行打开 Java VisualVM 监控工具 首先确保已安装 Java 环境,如此处安装版本 JDK 1.8.0_161 C:\Users\niaonao>j…

从零开始学GeoServer源码十一(如何解决No Multipart-config for Servlet错误)

目录前言1.现象2.排查问题3.找到问题4.解决问题5.总结前言 本文起源于我们遇到的一个问题,本来 GeoServer 使用的好好的,但是有天突然发现,无法在 GeoServer 中上传样式的 sld 文件了,报错 “No Multipart-config for Servlet” …

【Python安装配置教程】

Python由荷兰数学和计算机科学研究学会的吉多范罗苏姆于1990年代初设计,作为一门叫做ABC语言的替代品。Python提供了高效的高级数据结构,还能简单有效地面向对象编程。Python语法和动态类型,以及解释型语言的本质,使它成为多数平台…

一篇普通的bug日志——bug的尽头是next吗?

文章目录[bug 1] TypeError: method object is not subscriptable[bug 2] TypeError: unsupported format string passed to numpy.ndarray.__format__[bug 3] ValueError:Hint: Expected dtype() paddle::experimental::CppTypeToDataType<T>::Type()[bug 4] CondaSSLE…

javaweb网上宠物商城管理系统分前后台(源码+数据库+开题报告+ppt+文档)

一、 系统运行环境 硬件配置&#xff1a;2.4G以上处理器&#xff0c;4G以上内存&#xff0c;250G以上硬盘空间&#xff1b; 操作系统&#xff1a;Windows XP、Windows 7或更高版本&#xff1b; 数据库&#xff1a;MySQL&#xff1b; Web服务器&#xff1a;Tomcat 7.0&#xff…

标准信号转高电压高电流输出放大转换器0-5v/0-24v转4-20mA/0-500mA

概述导轨安装DIN11HVI 系列模拟信号隔离放大器是一种将输入信号隔离放大、转换成按比例输出的直流信号混合集成厚模电路。产品广泛应用在电力、远程监控、仪器仪表、医疗设备、工业自控等需要直流信号隔离测控的行业。此系列产品内部采用了线性光电隔离技术相比电磁隔离具有更好…

Java中异常(异常的处理方式(JVM默认的处理方式、自己处理(灵魂四问)、抛出异常(throws、throw))、异常中的常见方法、小练习、自定义异常)

编译时异常&#xff1a;在编译阶段&#xff0c;必须要手动处理&#xff0c;否则代码报错&#xff08;提醒程序员检查本地信息&#xff09; 运行时异常&#xff1a;在编译阶段是不需要处理的&#xff0c;是代码运行时出现的异常&#xff08;代码出错而导致程序出现的问题&#…

3D软件开发工具HOOPS 2023 更新亮点合集——增强了对建筑环境和自然环境中3D图形的真实感

HOOPS SDK是全球领先开发商TechSoft 3D旗下的原生产品&#xff0c;专注于Web端、桌面端、移动端3D工程应用程序的开发。长期以来&#xff0c;HOOPS通过卓越的3D技术&#xff0c;帮助全球600多家知名客户推动3D软件创新&#xff0c;这些客户包括SolidWorks、SIEMENS、Oracle、Ar…

Java高级-----多线程

多线程JAVA高级--多线程1、基本概念&#xff1a;程序、进程、线程1.1进程与线程1.2使用多线程的优点1.3何时需要多线程2、线程的创建和使用2.1线程的创建和启动2.2Thread 类2.3API 中创建线程的四种方式2.3.1方式一继承 Thread 类2.3.1.1 步骤2.3.1.2创建过程中的两个问题说明2…

JMU软件20 计算机网络复习

文章目录题型单位换算第一章协议与划分层次、网络协议的三个组成要素&#xff0c;分层的思想等协议网络协议的三个组成要素分层的思想⭐计算机网络体系结构OSI 的七层协议TCP/IP 的四层协议五层协议发送时延、传播时延、总时延、往返时间RTT计算第二章 物理层传输媒体导向性传输…

如何用SaleSmartly完善您的实时聊天页面

众所周知&#xff0c;第一印象在业务中非常重要&#xff0c;需要确保您的网站是可以促进您与客户之间的顺畅联系。想想您访问商家联系页面时&#xff0c;你通常看到什么&#xff1f;可能是用于发送电子邮件的对话框&#xff0c;也可能是要呼叫的电话号码&#xff0c;虽然这是一…

【LeetCode】剑指 Offer(18)

目录 题目&#xff1a;剑指 Offer 35. 复杂链表的复制 - 力扣&#xff08;Leetcode&#xff09; 题目的接口&#xff1a; 解题思路&#xff1a; 代码&#xff1a; 过啦&#xff01;&#xff01;&#xff01; 写在最后&#xff1a; 题目&#xff1a;剑指 Offer 35. 复杂链…

Mysql8.0的特性

Mysql8.0的特性 建议使用8.0.17及之后的版本&#xff0c;更新的内容比较多。 新增降序索引 -- 如下所示&#xff0c;我们可以在创建索引时 在字段名后面指定desc进行降序排序 create table t1(c1 int,c2 int,index idx_c1_c2(c1,c2 desc));group by 不再隐式排序 mysql5.7的版…

使用Chakra-UI封装简书的登录页面组件(React)

要求&#xff1a;使用chakra ui和react 框架将简书的登录页面的表单封装成独立的可重用的组件使用到的API&#xff1a;注册API请求方式&#xff1a;POST 请求地址&#xff1a;https://conduit.productionready.io/api/users请求数据: {"user":{ "username&quo…

typora-beta-0.11.18版本又提示过期的解决方案

很实用&#xff0c;所以照搬一下下面的作者的回答&#xff0c;省得以后再找~~~ 知乎的作者来源如下&#xff1a; 作者&#xff1a;吴小皓 链接&#xff1a;typora打开报错&#xff1a;This beta version of Typora is expired, please download and install a newer version …