用PyTorch复现SRCNN:三行代码理解深度学习超分的起点(附完整训练脚本)
用PyTorch复现SRCNN三行代码理解深度学习超分的起点附完整训练脚本当你第一次看到低分辨率的老照片时是否想过用技术手段让它重获新生这就是图像超分辨率技术的魅力所在。SRCNN作为深度学习在该领域的开山之作用仅三层的卷积网络架构开启了端到端学习的新范式。本文将带你从零开始用PyTorch完整复现这一经典模型通过代码级解析揭示其精妙设计。1. 环境准备与数据加载1.1 快速搭建PyTorch环境推荐使用conda创建专属Python环境避免依赖冲突conda create -n srcnn python3.8 conda activate srcnn pip install torch torchvision pillow matplotlib对于GPU加速用户建议安装CUDA 11.3对应的PyTorch版本。可以通过nvidia-smi查看显卡驱动版本然后到PyTorch官网获取对应安装命令。1.2 数据集处理技巧SRCNN原始论文使用91-image数据集但我们可以用更易获取的DIV2K数据集from torchvision import transforms train_transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean[0.4488], std[0.1953]) # 基于DIV2K的统计值 ]) class DIV2KDataset(Dataset): def __init__(self, hr_dir, scale2): self.hr_images [os.path.join(hr_dir, f) for f in os.listdir(hr_dir)] self.scale scale self.transform train_transform def __getitem__(self, idx): hr_img Image.open(self.hr_images[idx]).convert(YCbCr) w, h hr_img.size lr_img hr_img.resize((w//self.scale, h//self.scale), Image.BICUBIC) lr_img lr_img.resize((w, h), Image.BICUBIC) if self.transform: hr_img self.transform(hr_img.split()[0]) # 仅使用Y通道 lr_img self.transform(lr_img.split()[0]) return lr_img, hr_img注意Y通道包含主要的亮度信息对视觉质量影响最大因此超分任务通常只处理Y通道。2. 模型架构深度解析2.1 三卷积层的设计哲学SRCNN的精妙之处在于用三个卷积层对应传统方法的三个阶段网络层核尺寸通道数对应传统步骤conv19×964特征提取与表示conv25×532非线性特征映射conv35×51高分辨率重建用PyTorch实现仅需15行代码import torch.nn as nn class SRCNN(nn.Module): def __init__(self, in_channels1): super().__init__() self.feature_extraction nn.Sequential( nn.Conv2d(in_channels, 64, 9, padding4), nn.ReLU(inplaceTrue) ) self.nonlinear_mapping nn.Sequential( nn.Conv2d(64, 32, 5, padding2), nn.ReLU(inplaceTrue) ) self.reconstruction nn.Conv2d(32, in_channels, 5, padding2) def forward(self, x): x self.feature_extraction(x) x self.nonlinear_mapping(x) return self.reconstruction(x)2.2 关键参数的选择依据9×9大卷积核第一层需要捕获足够大的感受野来提取patch特征通道数递减64→32→1的通道设计符合特征提取→映射→重建的信息流无池化层保持空间分辨率不降低这对超分任务至关重要3. 训练策略与调参技巧3.1 损失函数的选择对比在超分任务中常用的损失函数有MSEL2损失criterion nn.MSELoss()优点训练稳定PSNR指标高缺点可能产生过度平滑的结果MAEL1损失criterion nn.L1Loss()优点保留更多高频细节缺点训练收敛较慢感知损失需预训练VGGvgg torchvision.models.vgg16(pretrainedTrue).features[:16] def perceptual_loss(pred, target): return F.mse_loss(vgg(pred), vgg(target))3.2 学习率动态调整实战采用分阶段学习率策略能显著提升模型性能optimizer torch.optim.Adam(model.parameters(), lr1e-4) scheduler torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones[50, 100], gamma0.1 ) for epoch in range(150): for lr, hr in dataloader: pred model(lr) loss criterion(pred, hr) optimizer.zero_grad() loss.backward() optimizer.step() scheduler.step() print(fEpoch {epoch}: LR{scheduler.get_last_lr()[0]:.2e})提示初始学习率设为1e-4在50和100epoch时分别降低10倍4. 结果可视化与性能评估4.1 定量指标计算方法除了常用的PSNR和SSIM还可以计算def psnr(pred, target, max_val1.0): mse torch.mean((pred - target) ** 2) return 10 * torch.log10(max_val**2 / mse) def ssim(pred, target): # 使用官方实现或手动计算 return torchmetrics.functional.ssim(pred, target, data_range1.0)典型评估结果对比方法Set5 PSNRSet14 PSNR参数量Bicubic28.4226.00-SRCNN30.4827.5057KVDSR31.3528.02665K4.2 可视化对比技巧使用matplotlib制作专业对比图def plot_comparison(lr, hr, pred): plt.figure(figsize(12, 4)) plt.subplot(1, 3, 1) plt.imshow(lr[0].cpu().numpy(), cmapgray) plt.title(Low Resolution) plt.subplot(1, 3, 2) plt.imshow(pred[0].detach().cpu().numpy(), cmapgray) plt.title(SRCNN Output) plt.subplot(1, 3, 3) plt.imshow(hr[0].cpu().numpy(), cmapgray) plt.title(Ground Truth) plt.show()在实际测试中发现SRCNN对文字和边缘的重建效果尤为突出但在复杂纹理区域会出现轻微的模糊现象。这与其简单的网络结构有关也启示我们可以在后续改进中增加网络深度或引入注意力机制。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2527504.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!