别再只调参了!手把手教你用EfficientNet-B0的MBConv和SENet模块,在PyTorch里复现一个轻量级分类网络
从零构建EfficientNet-B0核心模块MBConv与SENet的PyTorch实战指南当你第一次看到EfficientNet论文中那些复杂的结构图时是否感到无从下手作为计算机视觉领域的重要里程碑EfficientNet系列模型以其出色的性能与效率平衡著称。但大多数教程止步于理论介绍或简单调用预训练模型很少深入探讨如何从零实现其核心架构。本文将带你用PyTorch亲手构建EfficientNet-B0的两个关键模块——MBConv和SENet最终组装成一个完整的轻量级分类网络。1. 环境准备与基础架构在开始编码前我们需要搭建好开发环境并理解EfficientNet-B0的整体架构。不同于直接调用torchvision.models.efficientnet我们将从最基础的卷积层开始构建。首先确保已安装最新版PyTorch和torchvisionpip install torch torchvision matplotlibEfficientNet-B0由以下几个主要部分组成初始卷积层Stem Convolution16个MBConv模块核心部分顶部卷积层全局平均池化和全连接分类层让我们先定义网络的基本骨架import torch import torch.nn as nn class EfficientNetB0(nn.Module): def __init__(self, num_classes10): super(EfficientNetB0, self).__init__() # Stem卷积层 self.stem nn.Sequential( nn.Conv2d(3, 32, kernel_size3, stride2, padding1, biasFalse), nn.BatchNorm2d(32), nn.SiLU() # Swish激活函数 ) # MBConv模块将在这里添加 self.blocks nn.Sequential() # 顶部卷积层 self.top nn.Sequential( nn.Conv2d(320, 1280, kernel_size1, biasFalse), nn.BatchNorm2d(1280), nn.SiLU() ) # 分类器 self.classifier nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(1280, num_classes) ) def forward(self, x): x self.stem(x) x self.blocks(x) x self.top(x) x self.classifier(x) return x提示Swish激活函数SiLU是EfficientNet中的重要组件定义为x * sigmoid(x)在PyTorch中可直接使用nn.SiLU()2. MBConv模块的深度解析与实现MBConvMobile Inverted Bottleneck Convolution是EfficientNet的核心构建块它结合了深度可分离卷积和残差连接。与MobileNetV2的MBConv不同EfficientNet的版本还加入了SENet注意力机制。2.1 MBConv的结构分解一个标准的MBConv模块包含以下层次结构1×1扩展卷积当expand_ratio1时深度可分离卷积Depthwise ConvolutionSENet注意力模块1×1投影卷积残差连接当满足条件时让我们先实现深度可分离卷积这是MBConv的关键部分class DepthwiseSeparableConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride1): super().__init__() padding (kernel_size - 1) // 2 self.depthwise nn.Conv2d( in_channels, in_channels, kernel_size, stridestride, paddingpadding, groupsin_channels, biasFalse ) self.bn1 nn.BatchNorm2d(in_channels) self.pointwise nn.Conv2d( in_channels, out_channels, kernel_size1, biasFalse ) self.bn2 nn.BatchNorm2d(out_channels) self.act nn.SiLU() def forward(self, x): x self.depthwise(x) x self.bn1(x) x self.act(x) x self.pointwise(x) x self.bn2(x) return x2.2 完整MBConv的实现现在我们可以构建完整的MBConv模块注意处理expand_ratio和残差连接的条件class MBConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size3, stride1, expand_ratio1, se_ratio0.25): super().__init__() self.use_residual (in_channels out_channels) and (stride 1) hidden_dim in_channels * expand_ratio layers [] # 扩展阶段 if expand_ratio ! 1: layers.extend([ nn.Conv2d(in_channels, hidden_dim, 1, biasFalse), nn.BatchNorm2d(hidden_dim), nn.SiLU() ]) # 深度可分离卷积 layers.extend([ nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stridestride, padding(kernel_size-1)//2, groupshidden_dim, biasFalse), nn.BatchNorm2d(hidden_dim), nn.SiLU() ]) # 添加SENet模块 layers.append(SEModule(hidden_dim, se_ratio)) # 投影阶段 layers.extend([ nn.Conv2d(hidden_dim, out_channels, 1, biasFalse), nn.BatchNorm2d(out_channels) ]) self.conv nn.Sequential(*layers) def forward(self, x): if self.use_residual: return x self.conv(x) return self.conv(x)注意expand_ratio控制着通道扩展的程度当expand_ratio1时表示不进行通道扩展。se_ratio控制SENet模块中压缩的比例。3. SENet注意力机制的实现与集成SENetSqueeze-and-Excitation Network是MBConv中的重要组成部分它通过学习通道间的关系来自适应地调整各通道的权重。3.1 SENet的工作原理SENet包含两个主要操作Squeeze全局平均池化将空间维度压缩为1×1Excitation两个全连接层形成瓶颈结构学习通道间的相关性实现代码如下class SEModule(nn.Module): def __init__(self, channels, se_ratio0.25): super().__init__() self.avg_pool nn.AdaptiveAvgPool2d(1) reduced_channels max(1, int(channels * se_ratio)) self.fc nn.Sequential( nn.Conv2d(channels, reduced_channels, 1, biasTrue), nn.SiLU(), nn.Conv2d(reduced_channels, channels, 1, biasTrue), nn.Sigmoid() ) def forward(self, x): y self.avg_pool(x) y self.fc(y) return x * y3.2 SENet在MBConv中的位置在MBConv中SENet模块位于深度可分离卷积之后、投影层之前。这种位置安排使得网络可以先提取空间特征然后通过注意力机制重新校准通道重要性最后再进行降维。为了验证我们的实现是否正确可以对比有无SENet模块的性能差异# 测试MBConv模块 mbconv_with_se MBConv(32, 16, expand_ratio6, se_ratio0.25) mbconv_without_se MBConv(32, 16, expand_ratio6, se_ratioNone) x torch.randn(1, 32, 224, 224) print(Output with SENet:, mbconv_with_se(x).shape) print(Output without SENet:, mbconv_without_se(x).shape)4. 完整网络组装与训练技巧现在我们已经实现了所有关键组件接下来需要按照EfficientNet-B0的架构将它们组装起来。4.1 构建完整的网络块EfficientNet-B0包含7个阶段每个阶段有特定的配置阶段操作重复次数输入通道输出通道扩展比例核大小步长SE比例1MBConv1321613x310.252MBConv2162463x320.253MBConv2244065x520.254MBConv3408063x320.255MBConv38011265x510.256MBConv411219265x520.257MBConv119232063x310.25根据上表配置网络def make_blocks(): block_configs [ # (num_repeat, in_channels, out_channels, expand_ratio, kernel_size, stride, se_ratio) (1, 32, 16, 1, 3, 1, 0.25), (2, 16, 24, 6, 3, 2, 0.25), (2, 24, 40, 6, 5, 2, 0.25), (3, 40, 80, 6, 3, 2, 0.25), (3, 80, 112, 6, 5, 1, 0.25), (4, 112, 192, 6, 5, 2, 0.25), (1, 192, 320, 6, 3, 1, 0.25) ] blocks [] for config in block_configs: num_repeat, in_c, out_c, expand_ratio, kernel_size, stride, se_ratio config # 第一个块可能有不同的stride blocks.append(MBConv(in_c, out_c, kernel_size, stride, expand_ratio, se_ratio)) # 重复的块保持通道数不变stride1 for _ in range(1, num_repeat): blocks.append(MBConv(out_c, out_c, kernel_size, 1, expand_ratio, se_ratio)) return nn.Sequential(*blocks)4.2 训练技巧与超参数设置在CIFAR-10这样的小数据集上训练EfficientNet时需要注意以下几点学习率调度使用余弦退火学习率数据增强RandAugment或AutoAugment效果很好优化器选择使用带有权重衰减的AdamW标签平滑有助于防止过拟合示例训练代码片段from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR model EfficientNetB0(num_classes10) optimizer AdamW(model.parameters(), lr1e-3, weight_decay1e-4) scheduler CosineAnnealingLR(optimizer, T_max100) criterion nn.CrossEntropyLoss(label_smoothing0.1) # 训练循环 for epoch in range(100): for inputs, targets in train_loader: outputs model(inputs) loss criterion(outputs, targets) optimizer.zero_grad() loss.backward() optimizer.step() scheduler.step()4.3 模型压缩与部署优化虽然EfficientNet已经是轻量级模型但在边缘设备上部署时还可以进一步优化量化使用PyTorch的量化工具减少模型大小剪枝移除不重要的连接TensorRT加速转换模型以获得更好的推理性能量化示例model EfficientNetB0().eval() quantized_model torch.quantization.quantize_dynamic( model, {nn.Linear, nn.Conv2d}, dtypetorch.qint8 ) torch.save(quantized_model.state_dict(), efficientnet_quantized.pth)在实际项目中我发现MBConv模块中的深度可分离卷积对计算效率提升最大而SENet虽然增加了少量计算量但带来的精度提升通常值得这些额外开销。特别是在处理细粒度分类任务时通道注意力机制能显著改善模型对细微特征的识别能力。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2600824.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!