用PyTorch复现FCN语义分割:从VGG16预训练到FCN-8s实战,附完整代码与避坑指南
用PyTorch实现FCN-8s语义分割:从VGG16迁移学习到工业级部署全流程

当我们需要让计算机理解图像中每个像素的语义时,传统的分类网络就显得力不从心了。想象一下:自动驾驶汽车需要识别道路上的行人、车辆和交通标志,或者医疗影像分析需要精确勾勒出肿瘤边界——这些场景都需要像素级的理解能力。全卷积网络(FCN)正是为解决这类问题而生的革命性架构。

## 1. 环境准备与数据预处理

在开始构建模型之前,我们需要确保开发环境配置正确。推荐使用 Python 3.8 和 PyTorch 1.10 版本,这些版本在稳定性和功能支持上都有良好表现。对于 GPU 加速,CUDA 11.3 是目前最兼容的版本。

```bash
conda create -n fcn python=3.8
conda activate fcn
pip install torch torchvision opencv-python matplotlib
```

### 1.1 数据集处理实战

语义分割数据集通常包含原始图像和对应的标注掩码。以 Cityscapes 数据集为例,我们需要特别处理标注图像:

```python
class CityscapesDataset(Dataset):
    def __init__(self, root, split='train', transform=None):
        self.images_dir = os.path.join(root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(root, 'gtFine', split)
        self.transform = transform
        self.images = []
        self.targets = []
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                self.images.append(os.path.join(img_dir, file_name))
                target_name = file_name.replace('leftImg8bit', 'gtFine_labelIds')
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index):
        image = cv2.cvtColor(cv2.imread(self.images[index]), cv2.COLOR_BGR2RGB)
        target = cv2.imread(self.targets[index], cv2.IMREAD_GRAYSCALE)
        if self.transform:
            augmented = self.transform(image=image, mask=target)
            image, target = augmented['image'], augmented['mask']
        return image, target
```

注意:处理标注时需要考虑类别不平衡问题。例如在城市街景中,天空和道路像素可能远多于行人像素,这会影响模型训练效果。

## 2.
VGG16骨干网络改造

FCN-8s使用VGG16作为特征提取器,但需要对其进行关键改造:

- 移除最后的全连接层
- 保留卷积层和池化层的特征提取能力
- 添加1x1卷积层替代原始分类头

```python
class VGG16FeatureExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        vgg = models.vgg16(pretrained=pretrained).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        self.slice5 = nn.Sequential()
        for x in range(5):        # conv1_2
            self.slice1.add_module(str(x), vgg[x])
        for x in range(5, 10):    # conv2_2
            self.slice2.add_module(str(x), vgg[x])
        for x in range(10, 17):   # conv3_3
            self.slice3.add_module(str(x), vgg[x])
        for x in range(17, 24):   # conv4_3
            self.slice4.add_module(str(x), vgg[x])
        for x in range(24, 31):   # conv5_3
            self.slice5.add_module(str(x), vgg[x])
        if pretrained:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h = self.slice1(x)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        return h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3
```

## 3.
FCN-8s网络架构实现

FCN-8s的核心创新在于多尺度特征融合:

- 32倍上采样路径:直接从conv7输出上采样
- 16倍上采样路径:融合pool4特征
- 8倍上采样路径:进一步融合pool3特征

```python
class FCN8s(nn.Module):
    def __init__(self, n_class=21):
        super().__init__()
        self.features = VGG16FeatureExtractor(pretrained=True)
        # 1x1卷积替代全连接
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)
        self.drop6 = nn.Dropout2d()
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        self.drop7 = nn.Dropout2d()
        self.score_fr = nn.Conv2d(4096, n_class, kernel_size=1)
        # 跳级连接
        self.score_pool4 = nn.Conv2d(512, n_class, kernel_size=1)
        self.score_pool3 = nn.Conv2d(256, n_class, kernel_size=1)
        # 上采样
        self.upscore2 = nn.ConvTranspose2d(
            n_class, n_class, kernel_size=4, stride=2, padding=1)
        self.upscore8 = nn.ConvTranspose2d(
            n_class, n_class, kernel_size=16, stride=8, padding=4)
        self.upscore_pool4 = nn.ConvTranspose2d(
            n_class, n_class, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        _, _, h_pool3, h_pool4, h_pool5 = self.features(x)
        # 主路径处理
        h = self.drop6(F.relu(self.conv6(h_pool5)))
        h = self.drop7(F.relu(self.conv7(h)))
        h = self.score_fr(h)
        h = self.upscore2(h)  # 16x上采样
        # 融合pool4特征
        upscore_pool4 = self.score_pool4(h_pool4)
        h = h[:, :, 1:1 + upscore_pool4.size(2), 1:1 + upscore_pool4.size(3)]
        h = h + upscore_pool4
        h = self.upscore_pool4(h)  # 8x上采样
        # 融合pool3特征
        upscore_pool3 = self.score_pool3(h_pool3)
        h = h[:, :, 1:1 + upscore_pool3.size(2), 1:1 + upscore_pool3.size(3)]
        h = h + upscore_pool3
        return self.upscore8(h)  # 最终8x上采样
```

提示:特征融合时要注意张量尺寸对齐。FCN论文中采用crop方式处理边界不匹配问题,这是实现时容易出错的关键点。

## 4.
训练策略与性能优化

### 4.1 损失函数选择

语义分割常用的损失函数对比:

| 损失函数 | 优点 | 缺点 | 适用场景 |
| --- | --- | --- | --- |
| CrossEntropy | 分类标准选择 | 忽略类别不平衡 | 均衡数据集 |
| DiceLoss | 处理类别不平衡 | 训练不稳定 | 医学图像 |
| FocalLoss | 关注难样本 | 超参敏感 | 目标检测 |
| Lovász-Softmax | 直接优化mIoU | 计算复杂 | 竞赛场景 |

对于城市街景分割,推荐使用组合损失:

```python
class MixedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        return self.alpha * self.ce(pred, target) + (1 - self.alpha) * self.dice(pred, target)
```

### 4.2 学习率调度策略

采用warmup+余弦退火组合策略:

```python
def get_lr_scheduler(optimizer, n_iter_per_epoch, args):
    def lr_lambda(current_iter):
        # Warmup阶段
        if current_iter < args.warmup_epochs * n_iter_per_epoch:
            return float(current_iter) / float(max(1, args.warmup_epochs * n_iter_per_epoch))
        # 余弦退火阶段
        return 0.5 * (1. + math.cos(
            math.pi * (current_iter - args.warmup_epochs * n_iter_per_epoch)
            / ((args.epochs - args.warmup_epochs) * n_iter_per_epoch)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```

### 4.3 训练过程可视化

使用TensorBoard记录关键指标:

```python
writer = SummaryWriter(log_dir='runs/fcn8s_experiment')
for epoch in range(epochs):
    model.train()
    for i, (images, masks) in enumerate(train_loader):
        outputs = model(images)
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 记录训练指标
        writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + i)
        # 验证集评估
        if i % 100 == 0:
            model.eval()
            val_loss = 0
            mIoU = 0
            with torch.no_grad():
                for val_images, val_masks in val_loader:
                    val_outputs = model(val_images)
                    val_loss += criterion(val_outputs, val_masks).item()
                    mIoU += mean_iou(val_outputs, val_masks)
            writer.add_scalar('Loss/val', val_loss / len(val_loader), epoch * len(train_loader) + i)
            writer.add_scalar('mIoU/val', mIoU / len(val_loader), epoch * len(train_loader) + i)
            model.train()
```

## 5.
模型部署与性能优化

### 5.1 模型量化加速

使用PyTorch的量化工具减小模型体积:

```python
model = FCN8s(n_class=21).eval()
# 量化配置
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
quantized_model = torch.quantization.prepare(model, inplace=False)
quantized_model = torch.quantization.convert(quantized_model, inplace=False)
# 测试量化效果
with torch.no_grad():
    quantized_output = quantized_model(torch.rand(1, 3, 512, 512))
print(f"量化模型输出尺寸: {quantized_output.shape}")
```

### 5.2 ONNX格式导出

```python
dummy_input = torch.randn(1, 3, 512, 512)
torch.onnx.export(model, dummy_input, 'fcn8s.onnx',
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
```

### 5.3 TensorRT优化

```bash
# 使用trtexec转换ONNX到TensorRT引擎
trtexec --onnx=fcn8s.onnx --saveEngine=fcn8s.engine --fp16 --workspace=2048
```

## 6. 实际应用中的挑战与解决方案

### 6.1 小目标分割难题

当处理小物体时,FCN-8s可能表现不佳。解决方案包括:

- 多尺度训练:在训练时随机缩放输入图像
- 注意力机制:在跳级连接处添加注意力模块
- 高分辨率分支:保留更多底层特征

```python
class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        query = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        key = self.key(x).view(batch_size, -1, H * W)
        energy = torch.bmm(query, key)
        attention = F.softmax(energy, dim=-1)
        value = self.value(x).view(batch_size, -1, H * W)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        return self.gamma * out + x
```

### 6.2 实时性优化

对于实时应用,可以通过以下方式优化:

- 通道剪枝:移除不重要的卷积通道
- 知识蒸馏:用大模型指导小模型训练
- 架构搜索:自动寻找高效网络结构

```python
def channel_prune(model, prune_percent=0.3):
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            parameters_to_prune.append((module, 'weight'))
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=prune_percent,
    )
    for module, _ in parameters_to_prune:
        prune.remove(module, 'weight')
    return model
```

## 7.
进阶技巧与最新改进

### 7.1 深度监督训练

在中间层添加辅助损失函数:

```python
class FCN8sWithDS(nn.Module):
    def __init__(self, n_class=21):
        super().__init__()
        # ...原有初始化代码...
        # 深度监督分支
        self.ds_conv1 = nn.Conv2d(256, n_class, kernel_size=1)
        self.ds_conv2 = nn.Conv2d(512, n_class, kernel_size=1)

    def forward(self, x):
        _, _, h_pool3, h_pool4, h_pool5 = self.features(x)
        # 深度监督输出
        ds1 = self.ds_conv1(h_pool3)
        ds2 = self.ds_conv2(h_pool4)
        # 主路径处理...
        return main_output, ds1, ds2
```

### 7.2 自注意力增强

```python
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        proj_query = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        proj_key = self.key(x).view(batch_size, -1, H * W)
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value(x).view(batch_size, -1, H * W)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        return self.gamma * out + x
```
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2543172.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!