用PyTorch手把手实现BoTNet:把ResNet50的3x3卷积换成MHSA到底有多简单?
用PyTorch手把手实现BoTNet把ResNet50的3x3卷积换成MHSA到底有多简单如果你正在寻找一种既能保留CNN局部特征提取能力又能引入全局注意力机制的方法BoTNet可能是最优雅的解决方案之一。这个将ResNet中3x3卷积替换为多头自注意力(MHSA)的改动看似简单却效果显著。本文将用可运行的代码展示这一转换过程让你在10分钟内掌握核心实现技巧。1. 环境准备与基础理解在开始编码前我们需要明确几个关键概念。BoTNet的核心思想是在ResNet的Bottleneck块中用MHSA替代传统的3x3卷积。这种设计保留了CNN的层次结构同时在深层网络引入全局注意力机制。准备环境只需标准的PyTorch环境import torch import torch.nn as nn import torch.nn.functional as F为什么选择Bottleneck进行改造ResNet的Bottleneck结构天然适合插入注意力机制先通过1x1卷积降维减少计算量中间层处理核心特征这里是替换点最后1x1卷积恢复维度这种结构恰好与Transformer中的扩展-注意力-压缩流程相似。2. 实现核心MHSA模块让我们先构建最关键的MHSA层。与标准Transformer不同这里的实现需要处理2D特征图class MHSA(nn.Module): def __init__(self, n_dims, width14, height14, heads4): super().__init__() self.heads heads # 使用1x1卷积实现QKV投影 self.query nn.Conv2d(n_dims, n_dims, kernel_size1) self.key nn.Conv2d(n_dims, n_dims, kernel_size1) self.value nn.Conv2d(n_dims, n_dims, kernel_size1) # 相对位置编码参数 self.rel_h nn.Parameter(torch.randn([1, heads, n_dims//heads, 1, height])) self.rel_w nn.Parameter(torch.randn([1, heads, n_dims//heads, width, 1])) self.softmax nn.Softmax(dim-1) def forward(self, x): n_batch, C, width, height x.size() # 投影到QKV空间 q self.query(x).view(n_batch, self.heads, C//self.heads, -1) k self.key(x).view(n_batch, self.heads, C//self.heads, -1) v self.value(x).view(n_batch, self.heads, C//self.heads, -1) # 内容注意力 content_content torch.matmul(q.permute(0,1,3,2), k) # 位置注意力 content_position (self.rel_h self.rel_w).view(1, self.heads, C//self.heads, -1) content_position torch.matmul(content_position, q) # 合并注意力 energy content_content content_position attention self.softmax(energy) # 输出重构 out torch.matmul(v, attention.permute(0,1,3,2)) out out.view(n_batch, C, width, height) return out这段代码有几个关键设计点使用1x1卷积而非线性层实现QKV投影保留空间结构采用分解的相对位置编码将H×W的编码简化为(HW)形式注意力计算同时考虑内容相关性和位置相关性3. 改造Bottleneck模块现在我们可以改造标准的ResNet Bottleneck将中间的3x3卷积替换为MHSAclass Bottleneck(nn.Module): expansion 4 def __init__(self, in_planes, planes, stride1, heads4, mhsaFalse, resolutionNone): super().__init__() # 第一个1x1卷积降维 self.conv1 nn.Conv2d(in_planes, planes, kernel_size1, biasFalse) self.bn1 nn.BatchNorm2d(planes) # 核心修改点3x3卷积或MHSA if not mhsa: self.conv2 nn.Conv2d(planes, planes, kernel_size3, padding1, stridestride, biasFalse) else: self.conv2 nn.ModuleList([ MHSA(planes, widthint(resolution[0]), heightint(resolution[1]), headsheads) ]) if stride 2: # 处理下采样 self.conv2.append(nn.AvgPool2d(2, 2)) self.conv2 nn.Sequential(*self.conv2) self.bn2 nn.BatchNorm2d(planes) # 第三个1x1卷积升维 self.conv3 nn.Conv2d(planes, self.expansion*planes, kernel_size1, biasFalse) self.bn3 nn.BatchNorm2d(self.expansion*planes) # 捷径连接 self.shortcut nn.Sequential() if stride ! 1 or in_planes ! self.expansion*planes: self.shortcut nn.Sequential( nn.Conv2d(in_planes, self.expansion*planes, kernel_size1, stridestride), nn.BatchNorm2d(self.expansion*planes) ) def forward(self, x): out F.relu(self.bn1(self.conv1(x))) out F.relu(self.bn2(self.conv2(out))) out self.bn3(self.conv3(out)) out self.shortcut(x) out F.relu(out) return out改造时需特别注意MHSA不支持下采样需额外添加平均池化层保持原有的残差连接结构不变维持BatchNorm和ReLU的配置位置4. 构建完整BoTNet模型现在我们可以组装完整的BoTNet架构。通常只在最后几个阶段使用MHSA以平衡计算量和性能class BoTNet(nn.Module): def __init__(self, block, num_blocks, num_classes1000, resolution(224,224), heads4): super().__init__() self.in_planes 64 self.resolution list(resolution) # 初始卷积层 self.conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) self.bn1 nn.BatchNorm2d(64) self.relu nn.ReLU(inplaceTrue) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) # 更新分辨率信息 for op in [self.conv1, self.maxpool]: if op.stride[0] 2: self.resolution[0] / 2 if len(op.stride) 1 and op.stride[1] 2: self.resolution[1] / 2 # 构建四个阶段 self.layer1 self._make_layer(block, 64, num_blocks[0], stride1) self.layer2 self._make_layer(block, 128, num_blocks[1], stride2) self.layer3 self._make_layer(block, 256, num_blocks[2], stride2) self.layer4 self._make_layer(block, 512, num_blocks[3], stride2, headsheads, mhsaTrue) # 分类头 self.avgpool nn.AdaptiveAvgPool2d((1,1)) self.fc nn.Sequential( nn.Dropout(0.3), nn.Linear(512*block.expansion, num_classes) ) def _make_layer(self, block, planes, num_blocks, stride1, heads4, mhsaFalse): strides [stride] [1]*(num_blocks-1) layers [] for stride in strides: layers.append(block(self.in_planes, planes, stride, heads, mhsa, self.resolution)) if stride 2: self.resolution [r//2 for r in self.resolution] self.in_planes planes * block.expansion return nn.Sequential(*layers) def forward(self, x): out self.relu(self.bn1(self.conv1(x))) out self.maxpool(out) out self.layer1(out) out self.layer2(out) out self.layer3(out) out self.layer4(out) out self.avgpool(out) out torch.flatten(out, 1) out self.fc(out) return out def BoTNet50(num_classes1000, resolution(224,224), heads4): return BoTNet(Bottleneck, [3,4,6,3], num_classesnum_classes, resolutionresolution, headsheads)关键设计选择仅在layer4最后阶段使用MHSA模块动态跟踪特征图分辨率变化保持与标准ResNet相同的宏观结构5. 训练技巧与性能对比将ResNet改造为BoTNet后训练策略需要相应调整学习率调整初始学习率可以比标准ResNet稍小约小2-5倍使用带warmup的学习率调度optimizer torch.optim.SGD(model.parameters(), lr0.01, momentum0.9) scheduler torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr0.01, steps_per_epochlen(train_loader), epochs100 )性能对比表格指标ResNet50BoTNet50变化幅度参数量(M)25.524.7↓3.1%FLOPs(G)4.15.8↑41.5%ImageNet Top-176.1%77.3%↑1.2%COCO mAP38.040.2↑2.2注意实际性能提升取决于具体任务和数据。在小规模数据集上可能需要减少MHSA的使用比例以避免过拟合。实际部署建议从最后阶段开始逐步替换先替换1个block观察效果对于小分辨率输入224x224可能不需要MHSA可以使用混合精度训练加速scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs model(inputs) loss criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()通过以上步骤我们完成了从ResNet到BoTNet的改造。实际项目中这种改造通常能在目标检测和语义分割任务中获得更显著的提升因为这类任务更需要全局上下文信息。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2540776.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!