MindSpore框架学习项目-ResNet药物分类-构建模型

news2025/7/14 6:36:53

目录

2.构建模型

2.1定义模型类

2.1.1 基础块ResidualBlockBase

ResidualBlockBase代码解析

2.1.2 瓶颈块ResidualBlock

ResidualBlock代码解释

2.1.3 构建层

构建层代码说明

2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现

ResNet组建类代码解析

2.1.5 实例化resnet_xx网络

实例化resnet_xx网络代码分析

2.2模型初始化

模型初始化代码解析


参考内容: 昇思MindSpore | 全场景AI框架 | 昇思MindSpore社区官网 华为自研的国产AI框架,训推一体,支持动态图、静态图,全场景适用,有着不错的生态

本项目可以在华为云modelart上租一个实例进行,也可以在配置至少为单卡3060的设备上进行

https://console.huaweicloud.com/modelarts/

Ascend环境也适用,但是注意修改device_target参数

需要本地编译器的一些代码传输、修改等可以勾上ssh远程开发

说明:项目使用的数据集来自华为云的数据资源。项目以深度学习任务构建的一般流程展开(数据导入、处理 > 模型选择、构建 > 模型训练 > 模型评估 > 模型优化)。

主线为‘一般流程’,同时代码中会标注出一些要点(# 要点1-1-1:设置使用的设备

)作为支线,帮助学习mindspore框架在进行深度学习任务时一些与pytorch的差异。

可以只看目录中带数字标签的部分来快速查阅代码。

2.构建模型

2.1定义模型类

要求:

补充如下代码的空白处

主要完成:

1. 实现1个卷积层和1个ReLU激活函数的定义

2. 实现ResidualBlockBase和ResidualBlock模块的残差连接,并补全self.layer4的参数

导入mindspore训练环节(包括模型构建、激活函数、反向传播、损失函数等需要的库)

from mindspore import Model
from mindspore import context
import mindspore.ops as ops
from mindspore import Tensor, nn, set_context, GRAPH_MODE, train
from mindspore import load_checkpoint, load_param_into_net
from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal

初始化:weight_init = Normal(mean=0, sigma=0.02) 用于初始化卷积层;

gamma_init = Normal(mean=1, sigma=0.02) 用于初始化批归一化层

weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

2.1.1 基础块ResidualBlockBase

conv1 和 conv2 的参数设置:

conv1 负责处理输入数据的空间下采样(通过 stride 参数)或通道数变换(通过 out_channels),同时进行第一次特征提取。

conv2 固定为 3×3 卷积,不改变空间尺寸(默认 stride=1),仅对 conv1 的输出进一步提取特征。

(卷积层当池化层用)

class ResidualBlockBase(nn.Cell):
    expansion: int = 1

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlockBase, self).__init__()
        if not norm:
            self.norm = nn.BatchNorm2d(out_channel)
        else:
            self.norm = norm
        # 要点2-1-1:实现1个卷积层和一个ReLU激活函数的定义
        # 1. Conv2d:
        #    in_channels (int) - Conv2d层输入Tensor的空间维度。
        #    out_channels (int) - Conv2d层输出Tensor的空间维度。
        #    kernel_size (Union[int, tuple[int]]) - 指定二维卷积核的高度和宽度, 卷积核大小为3X3;
        #    stride (Union[int, tuple[int]],可选) - 二维卷积核的移动步长。
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.conv2 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, weight_init=weight_init)
        # 2. ReLU:逐元素计算ReLU(Rectified Linear Unit activation function)修正线性单元激活函数。需要调用MindSpore的相关API.
        self.relu = nn.ReLU()
        self.down_sample = down_sample
        
    def construct(self, x):
        """ResidualBlockBase construct."""
        identity = x

        out = self.conv1(x)
        out = self.norm(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)
        # 要点2-1-2: 
        # 1. 实现ResidualBlockBase模块的残差连接
        out = out+identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out

ResidualBlockBase代码解析

核心类定义:ResidualBlockBase

作用:实现残差网络的基础块(Basic Block),包含主分支(卷积路径)和短路连接(Shortcut),解决深层网络梯度消失问题。

输入:

in_channel :输入特征图通道数

out_channel :输出特征图通道数

stride :卷积步长(控制特征图尺寸变化,用于下采样)

norm :归一化层(默认使用 BatchNorm2d)

down_sample :下采样模块(用于调整短路连接的维度,确保与主分支输出维度一致)

要点 2-1-1:定义卷积层和 ReLU 激活函数

Conv2d 层实现

self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                       kernel_size=3, stride=stride,  # 3x3卷积核,步长由参数控制
                       weight_init=weight_init)     # 权重初始化(正态分布,sigma=0.02)
self.conv2 = nn.Conv2d(in_channel, out_channel,  # 第二个卷积层输入通道仍为in_channel(残差块基础版)
                       kernel_size=3, weight_init=weight_init)

关键参数:

kernel_size=3:固定使用 3x3 卷积核,符合残差块基础设计(如 ResNet-18/34 的 Basic Block)。

stride=stride:第一个卷积层的步长由外部传入(用于下采样),第二个卷积层步长固定为 1(保持尺寸)。

weight_init=weight_init:使用正态分布初始化权重(Normal(mean=0, sigma=0.02)),避免梯度爆炸 / 消失。

ReLU 激活函数

self.relu = nn.ReLU()  # 直接调用MindSpore的ReLU模块,逐元素计算max(0, x)

作用:引入非线性,避免网络退化为线性层,同时缓解梯度消失。

要点 2-1-2:实现残差连接

if self.down_sample is not None:
    identity = self.down_sample(x)  # 下采样:调整短路连接的维度(通道数/尺寸)
out = out + identity  # 残差连接核心:主分支输出与短路连接相加
out = self.relu(out)   # 最后一次ReLU激活,输出非线性特征

残差连接逻辑:

短路连接(Identity Mapping):当输入x的维度(通道数 / 尺寸)与主分支输出out一致时,直接相加(identity = x)。

若维度不一致(如通道数增加或尺寸缩小),通过down_sample模块对x进行下采样(通常是 1x1 卷积 + 步长调整),确保形状匹配。

相加操作:

核心公式:输出 = 主分支输出 + 短路连接,强制保留原始输入信息,使梯度能直接回传至浅层。

激活函数位置:相加后再进行一次 ReLU 激活,确保输出为非线性特征,符合 ResNet 设计规范。

关键模块解析

1. 归一化层(Norm)处理

if not norm:
    self.norm = nn.BatchNorm2d(out_channel)  # 默认使用BatchNorm2d
else:
    self.norm = norm  # 支持自定义归一化层(如GroupNorm)

作用:对卷积输出进行归一化,加速训练并提升模型鲁棒性。

位置:每个卷积层后立即接归一化层,再接 ReLU 激活(Conv→Norm→ReLU 顺序)。

2. 下采样模块(down_sample)

self.down_sample = down_sample  # 由外部传入,通常是1x1卷积+步长

触发场景:当in_channel ≠ out_channel或stride > 1时,需通过下采样调整短路连接的维度。

典型实现:

down_sample = nn.SequentialCell([
    nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, weight_init=weight_init),
    nn.BatchNorm2d(out_channel)
])

通过 1x1 卷积调整通道数,步长调整尺寸,确保与主分支输出形状一致。

3. 权重初始化策略

weight_init = Normal(mean=0, sigma=0.02)  # 卷积层权重初始化
gamma_init = Normal(mean=1, sigma=0.02)    # BatchNorm的γ参数初始化(未在当前代码中使用)

正态分布初始化:较小的标准差(σ=0.02)避免初始权重过大导致激活值饱和,符合深度学习框架的常见实践(如 PyTorch 的默认初始化)。

2.1.2 瓶颈块ResidualBlock
class ResidualBlock(nn.Cell):
    expansion = 4

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=1, weight_init=weight_init)
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
                               kernel_size=1, weight_init=weight_init)
        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)

        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):

        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.norm3(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)
        # 2. 实现ResidualBlock模块的残差连接
        out = out+identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out

ResidualBlock代码解释

核心类定义:ResidualBlock(瓶颈块)

作用:实现深层残差网络的瓶颈结构,通过 “降维 - 特征提取 - 升维” 减少计算量,支持构建更深的网络(如 50 层以上)。

核心参数:

expansion=4 :升维因子(固定为 4,符合 ResNet 设计规范),即最后一个 1x1 卷积将通道数扩展为out_channel×4。

in_channel :输入特征图通道数

out_channe :中间层特征图通道数(经 1x1 卷积降维后的通道数)

stride :3x3 卷积的步长(控制特征图尺寸变化,用于下采样)

down_sample :下采样模块(调整短路连接的维度,确保与主分支输出维度一致)

要点:瓶颈块的结构设计

瓶颈块通过三层卷积实现 “降维→特征提取→升维”,显著减少计算量(对比基础块的两层 3x3 卷积):

第一层:1x1 卷积(降维)

self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, weight_init=weight_init)

作用:将输入通道数从in_channel降为out_channel(如输入 256→输出 64),减少后续 3x3 卷积的计算量。

卷积核大小:1x1,仅改变通道数,不改变特征图尺寸(stride=1,无填充)。

第二层:3x3 卷积(特征提取)

self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride, weight_init=weight_init)

作用:在降维后的低维空间提取空间特征(如边缘、纹理)。

关键参数:stride=stride:支持下采样(如 stride=2 时特征图尺寸减半),由外部传入(用于构建不同 stage 的残差块)。

kernel_size=3:保持 3x3 卷积核,确保感受野与基础块一致。

第三层:1x1 卷积(升维)

self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion, kernel_size=1, weight_init=weight_init)

作用:将通道数从out_channel升至out_channel×expansion(如 64→256),与短路连接维度匹配(因 ResNet 的 stage 设计中,输出通道数通常是输入的 4 倍)。

核心公式:输出通道数 = out_channel × expansion(此处expansion=4是 ResNet 瓶颈块的固定设计)。

要点:残差连接与维度匹配

if self.down_sample is not None:
    identity = self.down_sample(x)  # 调整短路连接的维度
out = out + identity  # 残差连接核心:主分支输出与短路连接相加
out = self.relu(out)   # 最后一次ReLU激活

触发下采样的场景:

当以下任意条件成立时,需通过down_sample调整短路连接:输入通道数in_channel ≠ 输出通道数out_channel×expansion(升维导致通道数不匹配)。

stride > 1(特征图尺寸缩小,短路连接需同步下采样)。

下采样模块实现(通常由外部传入):

down_sample = nn.SequentialCell([
    nn.Conv2d(in_channel, out_channel*self.expansion, kernel_size=1, stride=stride, weight_init=weight_init),
    nn.BatchNorm2d(out_channel*self.expansion)
])

通过 1x1 卷积调整通道数,步长调整尺寸,确保identity与主分支输出out形状一致(通道数、高度、宽度均相同)。

归一化层与激活函数的顺序

# 每一层的处理流程:Conv → BatchNorm → ReLU
out = self.conv1(x)       # 1x1卷积(降维)
out = self.norm1(out)     # BatchNorm2d
out = self.relu(out)      # ReLU激活
out = self.conv2(out)     # 3x3卷积(特征提取)
out = self.norm2(out)     # BatchNorm2d
out = self.relu(out)      # ReLU激活
out = self.conv3(out)     # 1x1卷积(升维)
out = self.norm3(out)     # BatchNorm2d(升维后归一化)

设计原则:符合 ResNet 的 “Post-Normalization” 架构,即在卷积后立即归一化,再激活,确保每一层输入处于稳定分布,加速训练收敛。

权重初始化策略

weight_init = Normal(mean=0, sigma=0.02)  # 与基础块一致,小方差初始化避免梯度爆炸

作用:对 1x1 和 3x3 卷积的权重进行正态分布初始化,确保初始权重较小,激活值不会因过大输入导致饱和(如 ReLU 的负数区域失活)。

与基础残差块(ResidualBlockBase)的区别

2.1.3 构建层

根据给定参数构建由指定数量残差块组成的网络层,包括处理下采样及层间连接等

def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
               channel: int, block_nums: int, stride: int = 1):
    down_sample = None


    if stride != 1 or last_out_channel != channel * block.expansion:

        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channel, channel * block.expansion,
                      kernel_size=1, stride=stride, weight_init=weight_init),
            nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
        ])

    layers = []
    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))

    in_channel = channel * block.expansion

    for _ in range(1, block_nums):

        layers.append(block(in_channel, channel))

    return nn.SequentialCell(layers)

构建层代码说明

功能定位

ResNet 的整体架构就是通过make_layer函数不断堆叠残差块,形成多个 stage,每个 stage 内部保持相同的通道数,相邻 stage 之间通过下采样调整尺寸和通道数,最终构建出深度神经网络。

输入参数:

last_out_channel :上一层输出的特征图通道数(用于判断是否需要下采样)。

block :残差块类型(ResidualBlockBase基础块或ResidualBlock瓶颈块,通过Type[Union]支持两种类型)。

channel :当前 stage 的基础通道数(瓶颈块中为降维后的通道数,基础块中为输出通道数)。

block_nums :当前 stage 包含的残差块数量(如 ResNet-50 的每个 stage 包含 3/4/6/3 个瓶颈块)。

stride :当前 stage 第一个残差块的卷积步长(控制下采样,默认 1 表示不采样)。

输出:

由多个残差块组成的nn.SequentialCell序列(可直接作为网络的一个 stage,如 ResNet 的layer1、layer2等)。

核心代码逻辑解析

1. 下采样模块(down_sample)的条件判断与创建(核心考点)

if stride != 1 or last_out_channel != channel * block.expansion:
    down_sample = nn.SequentialCell([
        nn.Conv2d(last_out_channel, channel * block.expansion,
                  kernel_size=1, stride=stride, weight_init=weight_init),
        nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
    ])

触发条件(满足任意一条即需下采样):

stride != 1:需要对特征图尺寸进行下采样(如 stride=2 时尺寸减半)。

last_out_channel != channel * block.expansion:输入通道数与当前 stage 输出通道数不一致(瓶颈块中输出通道是channel×4,基础块中是channel×1)。

下采样实现:

通过1x1 卷积调整通道数(从last_out_channel到channel×block.expansion)。

卷积步长设为stride,同步调整特征图尺寸(与主分支的 3x3 卷积步长一致)。

接 BatchNorm 层归一化,确保短路连接的输出分布稳定。

核心作用:保证短路连接(identity)的维度与主分支输出一致,使out + identity操作可行。

2. 残差块序列的构建

layers = []
# 添加第一个残差块(可能包含下采样)
layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))
# 更新输入通道为当前stage的输出通道(block.expansion倍)
in_channel = channel * block.expansion
# 添加后续残差块(无下采样,stride=1,通道数已对齐)
for _ in range(1, block_nums):
    layers.append(block(in_channel, channel))  # 输入通道为上一个块的输出通道

第一个块的特殊性:

传入stride和down_sample,处理当前 stage 的下采样和通道对齐(如 ResNet 中layer2的第一个块 stride=2,实现尺寸减半)。

若无需下采样(stride=1 且通道数匹配),down_sample=None,短路连接直接使用输入x。

后续块的一致性:

输入通道in_channel固定为channel×block.expansion(即上一个块的输出通道)。

不再传入stride(默认 1)和down_sample(无需下采样,通道数已对齐),所有后续块仅做恒等残差连接。

2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现
from mindspore import load_checkpoint, load_param_into_net
from mindspore import ops

class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2) # 要点2-1-3:layer4的输出通道参数‘512’的含义
        self.avg_pool = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):
        
        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x,(2,3))
        
        x = self.flatten(x)
        x = self.fc(x)

        return x

ResNet组建类代码解析

1. 类定义与核心参数

class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

block:残差块类型(基础块ResidualBlockBase或瓶颈块ResidualBlock),决定网络层数和计算复杂度。

layer_nums:各 stage 的残差块数量(如[3, 4, 6, 3]对应 ResNet-50)。

num_classes:分类任务的类别数(如中药材分类的 12 类)。

input_channel:全连接层输入通道数(由最后一个 stage 的输出通道决定,如瓶颈块下为512×4=2048)。

2. 主干网络结构

输入层与初始特征提取

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)  # 7x7卷积
self.norm = nn.BatchNorm2d(64)  # 归一化
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')  # 最大池化

7x7 卷积:输入 3 通道(RGB 图像),输出 64 通道,步长 2,初步提取特征并降采样(尺寸减半)。

最大池化:核大小 3x3,步长 2,pad_mode='same'保持空间尺寸对称减半(如 224→112→56)。

四个 stage(layer1-layer4)

self.layer1 = make_layer(64, block, 64, layer_nums[0])  # stride=1(默认)
self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)  # 下采样
self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)  # 关键考点:补全layer4参数

make_layer功能:动态构建残差块序列,每个 stage:第一个块通过stride=2实现下采样(layer2-layer4),通道数翻倍(如 64→128→256→512)。

block.expansion控制通道升维(基础块 = 1,瓶颈块 = 4),例如瓶颈块下64×4=256作为下 stage 输入。

输出层

self.avg_pool = ops.ReduceMean(keep_dims=True)  # 全局平均池化
self.flatten = nn.Flatten()  # 展平特征
self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)  # 全连接分类

全局平均池化:替代全连接层前的全连接操作,减少参数数量,输出特征图尺寸为(batch, 512×expansion, 1, 1)。

全连接层:将特征映射到num_classes维空间,输出分类概率。

3. 前向传播逻辑

def construct(self, x):
    x = self.conv1(x) → self.norm(x) → self.relu(x) → self.max_pool(x)  # 初始特征提取
    x = self.layer1(x) → self.layer2(x) → self.layer3(x) → self.layer4(x)  # 四级残差块特征提取
    x = self.avg_pool(x, (2, 3))  # 对空间维度(H=2, W=3,假设输入为7x7)做平均池化
    x = self.flatten(x)  # 展平为一维向量(shape: [batch, input_channel])
    x = self.fc(x)  # 分类输出
    return x

空间尺寸变化:假设输入 224x224,经过conv1(stride=2)和max_pool(stride=2)后尺寸为 56x56,每层 stage 若stride=2则尺寸减半(56→28→14→7),最终layer4输出 7x7。

通道数变化:随 stage 递增(64→128→256→512),经block.expansion后瓶颈块通道数为 256→512→1024→2048。

4. 核心要点与设计原则

layer4参数补全(题目要求):输入通道为256×block.expansion(上一 stage 输出),当前 stage 基础通道512,stride=2实现最后一次下采样。

残差块类型兼容性:通过block参数支持基础块(浅层)和瓶颈块(深层),expansion自动适配通道逻辑(无需为不同块编写独立代码)。

下采样策略:每个 stage 的第一个块通过stride=2和1x1卷积调整通道 / 尺寸,保证残差连接维度匹配。

计算效率:瓶颈块通过1x1卷积降维减少 3x3 卷积计算量,使深层网络(如 ResNet-152)训练可行。

5. 代码关键作用

模块化构建:通过make_layer和残差块组合,快速搭建不同深度的 ResNet(如 50 层、101 层)。

特征提取流程:从浅层边缘检测到深层语义特征,逐层抽象,适应图像分类任务。

维度匹配:自动处理残差连接的通道和尺寸对齐,避免手动计算错误。

2.1.5 实例化resnet_xx网络

实例化resnet50

def _resnet(block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)
    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

实例化resnet_xx网络代码分析

ps:代码中‘ return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,

pretrained, resnet50_ckpt, 2048)’

为什么能跨函数识别到 pretrained参数?

在 Python 中,这是因为作用域的规则 。在resnet50函数中,pretrained是该函数的参数,属于局部作用域。当调用_resnet函数时,pretrained作为参数传递给_resnet函数,所以_resnet函数能够识别并使用这个参数

1. 通用模型构建函数 _resnet

def _resnet(block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)
    return model

功能:通用 ResNet 模型构建接口,通过参数化残差块类型、层数、分类数等,灵活生成不同配置的 ResNet 模型。

参数解析:block:残差块类型(ResidualBlockBase基础块或ResidualBlock瓶颈块)。

layers:各 stage 的残差块数量列表(如[3,4,6,3]对应 ResNet-50 的四个 stage)。

num_classes:分类任务的类别数(如中药材的 12 类)。

pretrained:是否加载预训练权重(布尔值,True表示加载)。

pretrained_ckpt:预训练权重文件路径(如"./LoadPretrainedModel/resnet50_224_new.ckpt")。

input_channel:全连接层输入维度(由最后一个 stage 的输出通道决定,如 ResNet-50 为 2048)。

2. 特定模型:ResNet-50 的封装 resnet50

def resnet50(num_classes: int = 1000, pretrained: bool = False):
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

功能:直接生成 ResNet-50 模型,固定了 ResNet-50 的核心配置(瓶颈块、各 stage 块数、输入通道)。

关键参数固定值:block=ResidualBlock:使用瓶颈块(Bottleneck Block),适用于深层网络(ResNet-50/101/152)。

layers=[3,4,6,3]:ResNet-50 的标准配置(四个 stage 分别包含 3、4、6、3 个瓶颈块)。

input_channel=2048:最后一个 stage 输出通道数(512 基础通道 × 瓶颈块 expansion=4)。

预训练支持:resnet50_ckpt指定了预训练权重路径(如用户需要加载 ImageNet 预训练权重,可通过pretrained=True启用)。

3. 代码设计核心价值

模块化与复用性:

_resnet作为通用构建函数,通过参数化残差块类型和层数,可扩展生成 ResNet-18(基础块 +[2,2,2,2])、ResNet-101([3,4,23,3])等变体,避免重复代码。用户友好性:

resnet50函数封装了 ResNet-50 的具体配置,用户只需指定num_classes(分类数)和pretrained(是否加载预训练),即可快速获取模型,降低使用门槛。

2.2模型初始化

要求:

对定义的ResNet50模型进行实例化

实例化一个用于12分类的resnet50模型

# 要点2-2-1: 对定义的ResNet50模型进行实例化
network = resnet50(num_classes=12)
num_class = 12
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=num_class)
network.fc = fc

for param in network.get_parameters():
    param.requires_grad = True

模型初始化代码解析

1. 实例化 ResNet50 模型

network = resnet50(num_classes=12)

作用:调用resnet50函数创建 ResNet50 模型实例,指定分类数为 12(如中药材的 12 类)。

内部逻辑:

resnet50函数通过_resnet生成 ResNet50 模型,默认使用瓶颈块(ResidualBlock)和标准层数[3,4,6,3],并将原 1000 类的全连接层(fc 层)初始化为 12 类输出(但需后续调整,见下文)。2. 替换全连接层适配新任务

num_class = 12  # 新任务的分类数(如中药材的12类)
in_channel = network.fc.in_channels  # 获取原fc层的输入通道数(ResNet50为2048)
fc = nn.Dense(in_channels=in_channel, out_channels=num_class)  # 新建12类输出的全连接层
network.fc = fc  # 替换原模型的fc层

背景:预训练 ResNet50 的 fc 层通常输出 1000 类(ImageNet 任务),需替换为新任务的分类数(12 类)。

关键操作:获取原 fc 层输入维度(in_channel=2048,由 ResNet50 的全局平均池化输出决定)。

新建全连接层fc,输入维度保持 2048,输出维度改为 12。

替换原模型的 fc 层,完成模型输出适配。

3. 启用所有参数训练 -- 全量微调

for param in network.get_parameters():
    param.requires_grad = True

作用:将模型所有参数的梯度计算标志(requires_grad)设为True,允许训练时更新所有参数。

场景意义:若使用预训练模型(pretrained=True),此操作表示 “端到端微调”(所有层参数均参与训练),适合新数据集与预训练数据分布差异较大的场景(如中药材分类 vs ImageNet 通用分类)。

若未使用预训练(pretrained=False),则模型从头开始训练,所有参数自然需要梯度更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2372126.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Spring Boot】Spring Boot + Thymeleaf搭建mvc项目

Spring Boot Thymeleaf搭建mvc项目 1. 创建Spring Boot项目2. 配置pom.xml3. 配置Thymeleaf4. 创建Controller5. 创建Thymeleaf页面6. 创建Main启动类7. 运行项目8. 测试结果扩展:添加静态资源 1. 创建Spring Boot项目 打开IntelliJ IDEA → New Project → 选择M…

学习spring boot-拦截器Interceptor,过滤器Filter

目录 拦截器Interceptor 过滤器Filter 关于过滤器的前置知识可以参考: 过滤器在springboot项目的应用 一,使用WebfilterServletComponentScan 注解 1 创建过滤器类实现Filter接口 2 在启动类中添加 ServletComponentScan 注解 二,创建…

雷赛伺服L7-EC

1电子齿轮比: 电机圈脉冲1万 (pa11的值 x 4倍频) 2电机刚性: pa003 或者 0x2003 // 立即生效的 3LED显示: PA5.28 1 电机速度 4精度: PA14 //默认30,超过3圈er18…

阅文集团C++面试题及参考答案

能否不使用锁保证多线程安全? 在多线程编程中,锁(如互斥锁、信号量)是实现线程同步的传统方式,但并非唯一方式。不使用锁保证多线程安全的核心思路是避免共享状态、使用原子操作或采用线程本地存储。以下从几个方面详…

AVL树:保持平衡的高效二叉搜索树

目录 一、AVL树的概念 1. 二叉搜索树的局限性 2. AVL树的定义 二、AVL树节点结构 三、AVL树的插入操作 1. 插入流程 2. 代码实现片段 四、AVL树的旋转调整 1. 左单旋(RR型) 2. 右单旋(LL型) 3. 左右双旋(LR型…

Webpack基本用法学习总结

Webpack 基本使用核心概念处理样式资源步骤: 处理图片资源修改图片输出文件目录 自动清空上次打包的内容EslintBabel处理HTML资源搭建开发服务器生产模式提取css文件为单独文件问题: Css压缩HTML压缩 小结1高级SourceMap开发模式生产模式 HMROneOfInclud…

阿里云服务器数据库故障排查指南?

阿里云服务器数据库故障排查指南? 以下是针对阿里云服务器(如ECS自建数据库或阿里云RDS等托管数据库)的故障排查指南,涵盖常见问题的定位与解决方案: 一、数据库连接失败 检查网络连通性 ECS自建数据库 确认安全组规则放行数据库…

数图闪耀2025深圳CCFA中国零售博览会:AI+零售数字化解决方案引发现场热潮

展会时间:2025年5月8日—10日 地点:深圳国际会展中心(宝安新馆) 【深圳讯】5月8日,亚洲规模最大的零售行业盛会——2025 CCFA中国零售博览会在深圳盛大开幕。本届展会汇聚全球25个国家和地区的900余家参展商&#xff…

LeetCode 1722. 执行交换操作后的最小汉明距离 题解

示例: 输入:source [1,2,3,4], target [2,1,4,5], allowedSwaps [[0,1],[2,3]] 输出:1 解释:source 可以按下述方式转换: - 交换下标 0 和 1 指向的元素:source [2,1,3,4] - 交换下标 2 和 3 指向的元…

linux ptrace 图文详解(八) gdb跟踪被调试程序的子线程、子进程

目录 一、gdb跟踪被调试程序的fork、pthread_create操作 二、实现原理 三、代码实现 四、总结 (代码:linux 6.3.1,架构:arm64) One look is worth a thousand words. —— Tess Flanders 相关链接: …

游戏:用python写梦幻西游脚本(谢苏)

《梦幻西游》是一款受欢迎的网络游戏,许多玩家希望通过脚本来增强游戏体验,比如自动打怪、自动治疗等。本文将为您展示一个用Python编写简单《梦幻西游》自动打怪脚本的方案。 需求分析 1.1 具体问题 在《梦幻西游》中,玩家需要频繁与怪物进行…

Spring Boot 3.x集成SaToken使用swagger3+knife4j 4.X生成接口文档

说一说Spring Boot 3.X集成SaToken使用swagger3并使用第三方的knife4j踩过的坑&#xff0c;废话不多说直接上正题&#xff0c;SaToken的我就不贴了 第一步当然是要先导入相关的依赖&#xff0c;包括swagger和knife4j&#xff0c;如下 <dependency><groupId>com.gi…

用Python监控金价并实现自动提醒!附完整源码

&#x1f482; 个人网站:【 摸鱼游戏】【神级代码资源网站】【星海网址导航】&#x1f4bb;香港大宽带-4H4G 20M只要36/月&#x1f449; 点此查看详情 在日常投资中&#xff0c;很多朋友喜欢在一些平台买点黄金&#xff0c;低买高卖赚点小差价。但黄金价格实时波动频繁&#xf…

ChatTempMail - AI驱动的免费临时邮箱服务

在当今数字世界中&#xff0c;保护在线隐私的需求日益增长。ChatTempMail应运而生&#xff0c;作为一款融合人工智能技术的新一代临时邮箱服务&#xff0c;它不仅提供传统临时邮箱的基本功能&#xff0c;还通过AI技术大幅提升了用户体验。 核心功能与特性 1. AI驱动的智能邮件…

掌握单元测试:提升软件质量的关键步骤

介绍 测试&#xff1a;是一种用来促进鉴定软件的正确性、完整性、安全性和质量的过程。 阶段划分&#xff1a;单元测试、集成测试、系统测试、验收测试。 测试方法&#xff1a;白盒测试、黑盒测试及灰盒测试。 单元测试&#xff1a;就是针对最小的功能单元&#xff08;方法&…

YOLOv1模型架构、损失值、NMS极大值抑制

文章目录 前言一、YOLO系列v11、核心思想2、流程解析 二、损失函数1、位置误差2、置信度误差3、类别概率损失 三、NMS&#xff08;非极大值抑制&#xff09;总结YOLOv1的优缺点 前言 YOLOv1&#xff08;You Only Look Once: Unified, Real-Time Object Detection&#xff09;由…

【论文阅读】——Articulate AnyMesh: Open-Vocabulary 3D Articulated Objects Modeling

文章目录 摘要一、介绍二、相关工作2.1. 铰接对象建模2.2. 部件感知3D生成 三、方法3.1. 概述3.2. 通过VLM助手进行可移动部件分割3.3. 通过几何感知视觉提示的发音估计3.4. 通过随机关节状态进行细化 四、实验4.1. 定量实验发音估计设置: 4.2. 应用程序 五、结论六、思考 摘要…

HarmonyOS基本的应用的配置

鸿蒙HarmonyOS组建页面 1、创建ets文件并配置2、修改main_pages.json文件3、修改EntryAbility.ets文件&#xff08;启动时加载的页面&#xff09; 1、创建ets文件并配置 Index.ets是创建项目自动构建生成的&#xff0c;我们可以将其删除掉&#xff0c;并重新在page文件夹下创建…

【redis】集群模式

Redis Cluster是Redis官方推出的分布式解决方案&#xff0c;旨在通过数据分片、高可用和动态扩展能力满足大规模数据存储与高并发访问的需求。其核心机制基于虚拟槽分区&#xff0c;将16384个哈希槽均匀分配给集群中的主节点&#xff0c;每个键通过CRC16哈希算法映射到特定槽位…

DeepSeek实战--微调

1.为什么是微调 &#xff1f; 微调LLM&#xff08;Fine-tuning Large Language Models&#xff09; 是指基于预训练好的大型语言模型&#xff08;如GPT、LLaMA、PaLM等&#xff09;&#xff0c;通过特定领域或任务的数据进一步训练&#xff0c;使其适应具体需求的过程。它是将…