保姆级教程:用PyTorch从零复现BIT变化检测模型(基于ResNet18+Transformer)
从零构建BIT变化检测模型基于PyTorch的遥感影像差异识别实战遥感影像变化检测一直是计算机视觉领域极具挑战性的任务。想象一下当你手头有两张同一区域不同时间拍摄的卫星图像如何快速准确地识别出新建的建筑物、消失的森林或是扩大的水域这正是BIT(Bitemporal Image Transformer)模型要解决的核心问题。不同于传统方法BIT创新性地结合了孪生网络与Transformer架构在保持高效计算的同时显著提升了变化检测的精度。1. 环境配置与基础准备在开始构建模型之前我们需要搭建合适的开发环境。推荐使用Python 3.8和PyTorch 1.10版本这些组合经过验证能够提供最佳的兼容性和性能表现。核心依赖安装pip install torch torchvision torchaudio pip install opencv-python numpy tqdm matplotlib对于GPU加速确保已安装对应版本的CUDA工具包。可以通过以下命令验证PyTorch是否正确识别了GPUimport torch print(torch.cuda.is_available()) # 应返回True print(torch.__version__) # 确认版本符合要求建议使用Anaconda创建独立环境避免依赖冲突。如果遇到维度不匹配等常见错误通常是由于PyTorch版本与CUDA版本不兼容导致可以通过重新安装指定版本解决。2. 数据准备与预处理变化检测模型的质量很大程度上取决于输入数据的质量。我们需要准备成对的遥感图像并确保它们已经过精确配准即两幅图像在空间上完全对齐。典型数据目录结构dataset/ ├── train/ │ ├── time1/ # 第一期影像 │ ├── time2/ # 第二期影像 │ └── label/ # 变化区域标注 └── test/ ├── time1/ ├── time2/ └── label/数据增强策略对提升模型泛化能力至关重要。以下是一个结合OpenCV的增强示例import cv2 import random def augment_pair(img1, img2, label): # 随机水平翻转 if random.random() 0.5: img1 cv2.flip(img1, 1) img2 cv2.flip(img2, 1) label cv2.flip(label, 1) # 随机旋转 angle random.uniform(-15, 15) h,w img1.shape[:2] M cv2.getRotationMatrix2D((w/2,h/2), angle, 1) img1 cv2.warpAffine(img1, M, (w,h)) img2 cv2.warpAffine(img2, M, (w,h)) label cv2.warpAffine(label, M, (w,h)) return img1, img2, label注意两期影像必须同步应用相同的增强变换否则会人为制造变化噪声。3. 改进型ResNet18骨干网络实现BIT采用改进的ResNet18作为特征提取骨干。与传统ResNet18相比主要做了两点关键修改取消最后两个stage的下采样保留更多空间信息添加双线性上采样和3x3卷积调整特征尺寸修改后的ResNet18实现关键代码import torch.nn as nn from torchvision.models import resnet18 class ModifiedResNet18(nn.Module): def __init__(self): super().__init__() original resnet18(pretrainedTrue) # 取前四个stage去除原第五个stage self.layer0 nn.Sequential( original.conv1, original.bn1, original.relu, original.maxpool) self.layer1 original.layer1 self.layer2 original.layer2 self.layer3 original.layer3 # 修改后的第四stage取消下采样 self.layer4 nn.Sequential( *[bottleneck for bottleneck in original.layer4[:1]], # 只取第一个bottleneck nn.Conv2d(256, 256, kernel_size3, padding1, stride1), # 取消下采样 nn.BatchNorm2d(256), nn.ReLU(inplaceTrue) ) # 上采样和调整卷积 self.upsample nn.Upsample(scale_factor2, modebilinear, align_cornersTrue) self.final_conv nn.Conv2d(256, 32, kernel_size3, padding1) def forward(self, x): x self.layer0(x) x self.layer1(x) x self.layer2(x) x self.layer3(x) x self.layer4(x) x self.upsample(x) x self.final_conv(x) return x实际应用中建议将这部分封装为孪生网络即两个图像共享同一组权重class SiameseBackbone(nn.Module): def __init__(self): super().__init__() self.backbone ModifiedResNet18() def forward(self, x1, x2): return self.backbone(x1), self.backbone(x2)4. Semantic Tokenizer设计与实现Semantic Tokenizer是BIT的核心创新之一它将特征图转换为紧凑的语义token。这个过程借鉴了NLP中tokenizer的思想但针对视觉任务进行了特殊设计。关键组件实现class SpatialAttention(nn.Module): 空间注意力机制用于聚焦重要区域 def __init__(self, in_channels): super().__init__() self.conv nn.Conv2d(in_channels, 1, kernel_size1) def forward(self, x): attn torch.sigmoid(self.conv(x)) # [B,1,H,W] return x * attn class SemanticTokenizer(nn.Module): def __init__(self, token_len4, in_channels32): super().__init__() self.token_len token_len self.spatial_attn SpatialAttention(in_channels) self.projection nn.Conv2d(in_channels, token_len, kernel_size1) def forward(self, x): # x形状: [B,C,H,W] x self.spatial_attn(x) # 空间注意力 tokens self.projection(x) # [B,token_len,H,W] tokens tokens.flatten(2).transpose(1,2) # [B,H*W,token_len] return tokens提示token_len是一个可调参数论文中默认使用4但在实际应用中可以根据任务复杂度适当增加。5. Transformer编码器-解码器架构BIT的Transformer部分采用标准的编码器-解码器结构但针对变化检测任务进行了特殊设计。多头注意力实现class MultiHeadAttention(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.embed_dim embed_dim self.num_heads num_heads self.head_dim embed_dim // num_heads self.qkv nn.Linear(embed_dim, embed_dim*3) self.proj nn.Linear(embed_dim, embed_dim) def forward(self, x): B, N, C x.shape qkv self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) q, k, v qkv.unbind(2) # 各[B,N,num_heads,head_dim] attn (q k.transpose(-2,-1)) * (self.head_dim**-0.5) attn attn.softmax(dim-1) out (attn v).transpose(1,2).reshape(B, N, C) return self.proj(out)完整的Transformer编码器层class TransformerEncoderLayer(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.attn MultiHeadAttention(embed_dim, num_heads) self.norm1 nn.LayerNorm(embed_dim) self.mlp nn.Sequential( nn.Linear(embed_dim, embed_dim*4), nn.GELU(), nn.Linear(embed_dim*4, embed_dim) ) self.norm2 nn.LayerNorm(embed_dim) def forward(self, x): x x self.attn(self.norm1(x)) x x self.mlp(self.norm2(x)) return x6. 模型训练与调优技巧将上述组件组合成完整的BIT模型后我们需要设计合适的训练策略。变化检测任务通常面临严重的类别不平衡问题变化区域往往只占图像的很小部分。加权交叉熵损失实现class WeightedBCELoss(nn.Module): def __init__(self, pos_weight5.0): super().__init__() self.pos_weight pos_weight def forward(self, pred, target): loss - (self.pos_weight * target * torch.log(pred 1e-8) (1-target) * torch.log(1-pred 1e-8)) return loss.mean()训练循环关键代码def train_epoch(model, loader, optimizer, device): model.train() total_loss 0 criterion WeightedBCELoss() for x1, x2, y in loader: x1, x2, y x1.to(device), x2.to(device), y.to(device) optimizer.zero_grad() pred model(x1, x2) loss criterion(pred, y) loss.backward() optimizer.step() total_loss loss.item() return total_loss / len(loader)注意学习率预热(learning rate warmup)对Transformer训练非常重要可以使用线性或余弦预热策略。7. 结果可视化与性能评估训练完成后我们需要对模型性能进行定量和定性评估。常用的评估指标包括指标名称计算公式意义精确率(Precision)TP/(TPFP)预测为变化的区域中真实变化的比例召回率(Recall)TP/(TPFN)真实变化区域中被正确预测的比例F1分数2PrecisionRecall/(PrecisionRecall)精确率和召回率的调和平均IoUTP/(TPFPFN)预测与真实变化区域的重叠度可视化对比函数import matplotlib.pyplot as plt def visualize_comparison(x1, x2, pred, gt): plt.figure(figsize(15,5)) plt.subplot(141) plt.imshow(x1) plt.title(Time 1) plt.subplot(142) plt.imshow(x2) plt.title(Time 2) plt.subplot(143) plt.imshow(pred 0.5, cmapgray) plt.title(Prediction) plt.subplot(144) plt.imshow(gt, cmapgray) plt.title(Ground Truth) plt.show()在实际项目中我们发现BIT模型在建筑物变化检测上表现尤为出色能够准确识别新建、拆除的建筑但对植被变化的敏感度相对较低。通过调整token长度和增加解码器层数可以进一步提升对细微变化的检测能力。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2520449.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!