Pytorch图像去噪实战(十三):DDIM加速扩散模型采样,让去噪从1000步降到50步
Pytorch图像去噪实战十三DDIM加速扩散模型采样让去噪从1000步降到50步一、问题场景DDPM效果能看但采样实在太慢上一篇我们把 DDPM 图像去噪工程搭起来了。训练流程跑通后很快会遇到一个非常现实的问题采样太慢。DDPM一般需要从 T1000 一步步反向去噪x1000 - x999 - ... - x0如果只是做实验还可以接受。但在真实项目中比如用户上传图片实时去噪批量修复图片OCR预处理在线图片增强1000步采样基本不可接受。这时就需要 DDIM。二、DDIM解决什么问题DDIM的核心价值是用更少的采样步数完成近似去噪。比如把1000步减少到50步甚至20步虽然可能牺牲一点质量但速度提升非常明显。三、DDPM和DDIM的工程区别DDPM采样每一步都加入随机噪声随机反向过程DDIM可以使用确定性采样确定性反向过程这意味着采样更快结果更稳定可以跳步采样更适合工程部署四、项目结构ddim_denoise/ ├── diffusion/ │ ├── ddpm.py │ └── ddim.py ├── models/ │ └── unet.py ├── dataset.py ├── train.py ├── sample_ddpm.py └── sample_ddim.pyDDIM不需要重新训练模型可以复用DDPM训练好的噪声预测网络。五、DDIM采样器实现diffusion/ddim.pyimporttorchclassDDIMSampler:def__init__(self,ddpm,ddim_steps50):self.ddpmddpm self.ddim_stepsddim_steps self.time_stepstorch.linspace(ddpm.timesteps-1,0,ddim_steps).long().to(ddpm.device)torch.no_grad()defsample(self,model,shape):deviceself.ddpm.device xtorch.randn(shape).to(device)foriinrange(len(self.time_steps)-1):tself.time_steps[i]t_nextself.time_steps[i1]batch_ttorch.full((shape[0],),t,devicedevice,dtypetorch.long)pred_noisemodel(x,batch_t)alpha_bar_tself.ddpm.alpha_bars[t]alpha_bar_nextself.ddpm.alpha_bars[t_next]pred_x0(x-torch.sqrt(1-alpha_bar_t)*pred_noise)/torch.sqrt(alpha_bar_t)pred_x0torch.clamp(pred_x0,0.0,1.0)xtorch.sqrt(alpha_bar_next)*pred_x0torch.sqrt(1-alpha_bar_next)*pred_noisereturnx六、DDIM采样脚本sample_ddim.pyimporttorchimporttorchvision.utilsasvutilsfromconfigs.train_configimportTrainConfigfromdiffusion.ddpmimportDDPMfromdiffusion.ddimimportDDIMSamplerfrommodels.unetimportDDPMUNettorch.no_grad()defsample_ddim():cfgTrainConfig()devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)modelDDPMUNet(channelscfg.channels).to(device)model.load_state_dict(torch.load(checkpoints/ddpm_epoch_100.pth,map_locationdevice))model.eval()ddpmDDPM(timestepscfg.timesteps,beta_startcfg.beta_start,beta_endcfg.beta_end,devicedevice)samplerDDIMSampler(ddpm,ddim_steps50)samplessampler.sample(model,shape(16,cfg.channels,cfg.image_size,cfg.image_size))samplestorch.clamp(samples,0.0,1.0)vutils.save_image(samples.cpu(),ddim_samples.png,nrow4)if__name____main__:sample_ddim()七、为什么DDIM可以跳步DDPM严格按照马尔可夫链逐步反推。DDIM则使用一种非马尔可夫形式的采样路径。工程上可以这样理解DDIM不是每一步都重新随机采样而是根据当前预测的x0和噪声方向直接跳到更早的时间步。所以它可以从1000 - 999 - 998变成1000 - 980 - 960这就是速度提升的核心。八、采样步数怎么选实际建议快速预览ddim_steps20适合训练中间快速看效果。平衡质量和速度ddim_steps50这是比较常用的设置。更高质量ddim_steps100速度慢一些但质量更稳。九、加入eta控制随机性DDIM可以设置 eta 控制是否加入随机性。简化理解eta 0确定性采样eta 0加入随机性入门建议先用eta0因为结果更稳定方便对比实验。十、推理速度对比实际工程中采样速度差距非常明显。方法采样步数速度质量DDPM1000慢稳DDIM100快很多较稳DDIM50推荐平衡DDIM20很快略差十一、踩坑记录坑1time_steps顺序写反DDIM采样必须从大时间步到小时间步T - 0如果写成 0 到 T结果会完全错。坑2pred_x0不做clamp预测出的 x0 可能超出 0~1。建议pred_x0torch.clamp(pred_x0,0.0,1.0)否则容易出现过曝或发黑。坑3步数太少导致结构崩20步速度快但质量不一定稳定。建议先用50步作为默认值。十二、适合收藏总结DDIM加速流程训练DDPM噪声预测模型构建DDIMSampler从1000步中均匀选择少量时间步根据预测noise估计x0跳步完成采样避坑清单时间步顺序必须反向pred_x0建议clamp20步适合预览50步更稳DDIM不需要重新训练模型采样器要和DDPM参数一致十三、优化建议可以继续优化加eta参数使用非均匀时间步加EMA权重改进UNet结构用条件输入做真实图像去噪结尾总结DDIM解决的是扩散模型工程落地中最实际的问题DDPM质量可以但太慢。通过DDIM我们可以在不重新训练模型的情况下把采样速度提升一个数量级。如果你准备把Diffusion用于图像去噪项目DDIM几乎是必学内容。下一篇预告Pytorch图像去噪实战十四条件扩散模型图像去噪让Diffusion根据带噪图恢复干净图
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2569888.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!