ResNet残差神经网络的模型结构定义(pytorch实现)

news2025/5/13 11:08:20

ResNet残差神经网络的模型结构定义(pytorch实现)

ResNet‑34

在这里插入图片描述

ResNet‑34的实现思路。核心在于:

  1. 定义残差块(BasicBlock)
  2. _make_layer 方法堆叠多个残差块
  3. 按照 ResNet‑34 的通道和层数配置来搭建网络

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1  # 对于 BasicBlock,输出通道 = base_channels * expansion

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # 第一个 3×3 卷积
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_channels)
        # 第二个 3×3 卷积
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_channels)

        # 如果输入输出通道或下采样不一致,则用 1×1 卷积做一下“shortcut”
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * BasicBlock.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # 残差连接
        out += self.shortcut(x)
        return F.relu(out)

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        """
        block:       残差块类型(BasicBlock 或 Bottleneck)
        layers:      每个 stage 包含多少个 block,例如 [3, 4, 6, 3] 对应 ResNet‑34
        num_classes: 最后分类数
        """
        super().__init__()
        self.in_channels = 64

        # Stem:7×7 conv + maxpool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1   = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 四个 stage,通道分别是 [64,128,256,512]
        self.layer1 = self._make_layer(block, 64,  layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # 全局平均池化 + 全连接
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc      = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """
        构造一个 stage,由 num_blocks 个 block 组成。
        第一个 block 可能带 stride 下采样,其余 block stride=1。
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, stride=s))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)

        x = self.layer1(x)   # output size  /4
        x = self.layer2(x)   # output size  /8
        x = self.layer3(x)   # output size /16
        x = self.layer4(x)   # output size /32

        x = self.avgpool(x)  # [B, C, 1, 1]
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def resnet34(num_classes=1000):
    """返回一个 ResNet-34 实例"""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)

关键点解析

  • BasicBlock

    • 两个连续的 3×3 卷积,均附带 BatchNorm 和 ReLU
    • 当通道数或步幅不匹配时,用 1×1 卷积对输入做一下线性变换,才能做元素相加
  • _make_layer

    • 每个 stage 第一个残差块如果要做下采样,则 stride=2;其余都保持 stride=1
    • layers 参数 [3,4,6,3] 精确对应了图中红、粉、灰、蓝四部分每层 block 的数量
  • 整体流程

    1. 7×7、stride=2 下采样 → 最大池化 →
    2. 四个 stage(通道 64→128→256→512,每段下采样一次)→
    3. 全局平均池化 → 全连接分类

这样就完整复现了图中右侧那张“34-layer residual”结构。你可以直接调用 resnet34(),并像下面这样测试一下输出形状:

if __name__ == "__main__":
    model = resnet34(num_classes=1000)
    x = torch.randn(8, 3, 224, 224)
    y = model(x)
    print(y.shape)   # torch.Size([8, 1000])

ResNet‑50

PyTorch 实现 ResNet‑50 。它与 ResNet‑34 唯一不同之处在于使用了 Bottleneck 模块,并且每个 stage 的 block 数量依次为 [3, 4, 6, 3](同 ResNet‑34),但每个 block 内部由三个卷积层组成,expansion 值为 4。

import torch
import torch.nn as nn
import torch.nn.functional as F

class Bottleneck(nn.Module):
    # 输出通道 = base_channels * expansion
    expansion = 4

    def __init__(self, in_channels, base_channels, stride=1):
        super().__init__()
        # 1×1 降维
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=1,
                               bias=False)
        self.bn1   = nn.BatchNorm2d(base_channels)
        # 3×3 卷积(可能下采样)
        self.conv2 = nn.Conv2d(base_channels, base_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(base_channels)
        # 1×1 升维
        self.conv3 = nn.Conv2d(base_channels, base_channels * Bottleneck.expansion,
                               kernel_size=1, bias=False)
        self.bn3   = nn.BatchNorm2d(base_channels * Bottleneck.expansion)

        # shortcut 分支
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != base_channels * Bottleneck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, base_channels * Bottleneck.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(base_channels * Bottleneck.expansion)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_channels = 64

        # Stem:7×7 conv + maxpool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1   = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 四个 stage
        self.layer1 = self._make_layer(block,  64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # 池化 + 全连接
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc      = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, base_channels, num_blocks, stride):
        """
        构造一个 stage,由 num_blocks 个 block 组成。
        第一个 block 可能下采样(stride>1),其余保持 stride=1。
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, base_channels, stride=s))
            self.in_channels = base_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)

        x = self.layer1(x)   # /4
        x = self.layer2(x)   # /8
        x = self.layer3(x)   # /16
        x = self.layer4(x)   # /32

        x = self.avgpool(x)  # [B, C, 1, 1]
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def resnet50(num_classes=1000):
    """返回一个 ResNet-50 实例"""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)


# 简单测试
if __name__ == "__main__":
    model = resnet50(num_classes=1000)
    x = torch.randn(4, 3, 224, 224)
    y = model(x)
    print(y.shape)  # -> torch.Size([4, 1000])

说明

  • Bottleneck 模块:三个卷积层依次为 1×1 → 3×3 → 1×1,最后一个 1×1 用来恢复维度(乘以 expansion=4)。
  • shortcut 分支:当需下采样(stride=2)或输入输出维度不一致时,使用 1×1 卷积对齐后相加。
  • layers 参数 [3,4,6,3]:分别对应四个 stage 中 Bottleneck block 的个数。

这样就完成了 ResNet‑50 的全结构定义。你可以直接调用 resnet50() 并将其与预训练权重或自己的数据集一起使用。


参考:Kaiming He 等人,Deep Residual Learning for Image Recognition (CVPR 2016).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2374606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp|商品列表加入购物车实现抛物线动画效果、上下左右抛入、多端兼容(H5、APP、微信小程序)

以uniapp框架为基础,详细解析商品列表加入购物车抛物线动画的实现方案。通过动态获取商品点击位置与购物车坐标,结合CSS过渡动画模拟抛物线轨迹,实现从商品图到购物车图标的动态效果。 目录 核心实现原理坐标动态计算抛物线轨迹模拟​动画元素控制代码实现详解模板层设计脚本…

谈AI/OT 的融合

过去的十几年间,工业界讨论最多的话题之一就是IT/OT 融合,现在,我们不仅要实现IT/OT 的融合,更要面向AI/OT 的融合。看起来不太靠谱,却留给我们无限的想象空间。OT 领域的专家们不要再当“九斤老太”,指责这…

USB传输模式

USB有四种传输模式: 控制传输, 中断传输, 同步传输, 批量传输 1. 中断传输 中断传输一般用于小批量, 非连续的传输. 对实时性要求较高. 常见的使用此传输模式的设备有: 鼠标, 键盘等. 要注意的是, 这里的 “中断” 和我们常见的中断概念有差异. Linux中的中断是设备主动发起的…

.NET10 - 尝试一下Open Api的一些新特性

1.简单介绍 .NET9中Open Api有了很大的变化,在默认的Asp.NET Core Web Api项目中,已经移除了Swashbuckle.AspNetCore package,同时progrom中也变更为 builder.Servers.AddOpenApi() builder.Services.MapOpenApi() 2025年微软将发布…

RabbitMQ 工作模式

RabbitMQ 一共有 7 中工作模式,可以先去官网上了解一下(一下截图均来自官网):RabbitMQ 官网 Simple P:生产者,要发送消息的程序;C:消费者,消息的接受者;hell…

基于C++的多线程网络爬虫设计与实现(CURL + 线程池)

在当今大数据时代,网络爬虫作为数据采集的重要工具,其性能直接决定了数据获取的效率。传统的单线程爬虫在面对海量网页时往往力不从心,而多线程技术可以充分利用现代多核CPU的计算能力,显著提升爬取效率。本文将详细介绍如何使用C…

【日撸 Java 三百行】Day 11(顺序表(一))

目录 Day 11:顺序表(一) 一、关于顺序表 二、关于面向对象 三、代码模块分析 1. 顺序表的属性 2. 顺序表的方法 四、代码及测试 拓展: 小结 Day 11:顺序表(一) Task: 在《数…

软考 系统架构设计师系列知识点之杂项集萃(55)

接前一篇文章:软考 系统架构设计师系列知识点之杂项集萃(54) 第89题 某软件公司欲开发一个Windows平台上的公告板系统。在明确用户需求后,该公司的架构师决定采用Command模式实现该系统的界面显示部分,并设计UML类图如…

保持Word中插入图片的清晰度

大家有没有遇到这个问题,原本绘制的高清晰度图片,插入word后就变模糊了。先说原因,word默认启动了自动压缩图片功能,分享一下如何关闭这项功能,保持Word中插入图片的清晰度。 ①在Word文档中,点击左上角的…

Linux复习笔记(三) 网络服务配置(web)

遇到的问题,都有解决方案,希望我的博客能为你提供一点帮助。 二、网络服务配置 2.3 web服务配置 2.3.1通信基础:HTTP协议与C/S架构(了解) ​​HTTP协议的核心作用​​ Web服务基于HTTP/HTTPS协议实现客户端&#xff…

springboot旅游小程序-计算机毕业设计源码76696

目 录 摘要 1 绪论 1.1研究背景与意义 1.2研究现状 1.3论文结构与章节安排 2 基于微信小程序旅游网站系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统…

uniapp自定义导航栏搭配插槽

<uni-nav-bar dark :fixed"true" shadow background-color"#007AFF" left-icon"left" left-text"返回" clickLeft"back"><view class"nav-bar-title">{{ navBarTitle }}</view><block v-slo…

MFC listctrl修改背景颜色

在 MFC 中修改 ListCtrl 控件的行背景颜色&#xff0c;需要通过自绘&#xff08;Owner-Draw&#xff09;机制实现。以下是详细的实现方法&#xff1a; 方法一&#xff1a;通过自绘&#xff08;Owner-Draw&#xff09;实现 步骤 1&#xff1a;启用自绘属性 在对话框设计器中选…

SpringBoot+Dubbo+Zookeeper实现分布式系统步骤

SpringBootDubboZookeeper实现分布式系统 一、分布式系统通俗解释二、环境准备&#xff08;详细版&#xff09;1. 软件版本2. 安装Zookeeper&#xff08;单机模式&#xff09; 三、完整项目结构&#xff08;带详细注释&#xff09;四、手把手代码实现步骤1&#xff1a;创建父工…

Linux进程10-有名管道概述、创建、读写操作、两个管道进程间通信、读写规律(只读、只写、读写区别)、设置阻塞/非阻塞

目录 1.有名管道 1.1概述 1.2与无名管道的差异 2.有名管道的创建 2.1 直接用shell命令创建有名管道 2.2使用mkfifo函数创建有名管道 3.有名管道读写操作 3.1单次读写 3.2多次读写 4.有名管道进程间通信 4.1回合制通信 4.2父子进程通信 5.有名管道读写规律&#xff…

精品可编辑PPT | 全面风险管理信息系统项目建设风控一体化标准方案

这份文档是一份全面风险管理信息系统项目建设风控一体化标准方案&#xff0c;涵盖了业务架构、功能方案、系统技术架构设计、项目实施及服务等多个方面的详细内容。方案旨在通过信息化手段提升企业全面风险管理工作水平&#xff0c;促进风险管理落地和内部控制规范化&#xff0…

YOLOv8网络结构

YOLOv8的网络结构由输入端(Input)、骨干网络(Backbone)、颈部网络(Neck)和检测头(Head)四部分组成。 YOLOv8的网络结构如下图所示&#xff1a; 在整个系统架构中&#xff0c;图像首先进入输入处理模块&#xff0c;该模块承担着图像预处理与数据增强的双重任务。接着&#xff0c…

笔记本电脑升级实战手册【扩展篇1】:flash id查询硬盘颗粒

文章目录 前言&#xff1a;一、硬盘颗粒介绍1、MLC&#xff08;Multi-Level Cell&#xff09;2、TLC&#xff08;Triple-Level Cell&#xff09;3、QLC&#xff08;Quad-Level Cell&#xff09; 二、硬盘与主控1、主控介绍2、主流主控厂家 三 、硬盘颗粒查询使用flash id工具查…

AutoDL租用服务器教程

在跑ai模型的时候&#xff0c;容易遇到算力不够的情况。此时便需要租用服务器。autodl是个较为便宜的服务器租用平台&#xff0c;h20仅需七点几元每小时。下面是简单的介绍。 打开网站AutoDL算力云 | 弹性、好用、省钱。租GPU就上AutoDL&#xff0c;并登录账号 登录后&#xff…

goner/otel 在Gone框架接入OpenTelemetry

文章目录 背景与意义快速上手&#xff1a;五步集成 OpenTelemetry运行效果展示代码详解与实践目录结构说明组件加载&#xff08;module.load.go&#xff09;业务组件示例&#xff08;your_component.go&#xff09;程序入口&#xff08;main.go&#xff09; 进阶用法与最佳实践…