简介
主页:https://github.com/Algolzw/image-restoration-sde
扩散模型终于在去噪、超分辨率等应用了。
这是一种基于随机微分方程的通用图像恢复方法,关键结构包括均值还原SDE,该SDE将高质量图像转换为具有固定高斯噪声的平均状态的降级对应图像,通过模拟相应的逆时SDE,可以在不依赖任何特定任务的先验知识的情况下恢复低质量图像的原点,所提出的均值回归SDE具有封闭形式的解决方案,允许计算真实时间相关分数并使用神经网络学习它
贡献点
- 提出了一种通用的图像恢复方法,使用均值恢复SDE直接模拟图像退化过程。公式有一个封闭形式的解决方案,使能够计算地面真值与时间相关的分数函数,并训练神经网络来估计它。
- 提出了一个简单的替代损失函数来训练神经网络,基于最大化逆时间轨迹的可能性。与普通分数匹配目标相比,损失被证明可以稳定训练并持续提高图像恢复性能。
- 通过将其应用于六种不同的图像恢复任务来证明提出的方法的一般适用性:图像去除、去模糊、去噪、超分辨率、上漆和去雾。
- 在图像去噪、去模糊和去噪的定量比较中实现了极具竞争力的恢复性能,在两个去噪数据集上达到了新的水平。
实现流程
背景知识-SDE
前向过程
 
 𝑓和𝑔分别为drift函数和dispersion函数,𝑤为标准Wiener过程, 
     
      
       
       
         x 
        
       
         ( 
        
       
         0 
        
       
         ) 
        
       
         ∈ 
        
        
        
          R 
         
        
          d 
         
        
       
      
        x(0)\in R^d 
       
      
    x(0)∈Rd为初始条件
 逆向过程
 
  
     
      
       
        
        
          w 
         
        
          ^ 
         
        
       
      
        \hat{w} 
       
      
    w^为逆时Wiener过程, 
     
      
       
        
        
          p 
         
        
          t 
         
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        p_t(x) 
       
      
    pt(x)为x(t)在𝑡时刻的边际概率密度函数。 
     
      
       
        
        
          ∇ 
         
        
          x 
         
        
       
         l 
        
       
         o 
        
       
         g 
        
        
        
          p 
         
        
          t 
         
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        \nabla_x log p_t(x) 
       
      
    ∇xlogpt(x)
方法

 使用均值回归SDE 
     
      
       
       
         d 
        
       
         x 
        
       
         = 
        
        
        
          θ 
         
        
          t 
         
        
       
         ( 
        
       
         μ 
        
       
         − 
        
       
         x 
        
       
         ) 
        
       
         d 
        
       
         t 
        
       
         + 
        
        
        
          σ 
         
        
          t 
         
        
       
      
        dx=\theta_t(\mu-x)dt+\sigma_t 
       
      
    dx=θt(μ−x)dt+σt进行图像恢复。SDE模型通过向低质量图像的带噪版本的扩散,来模拟从高质量图像的 
     
      
       
       
         x 
        
       
         ( 
        
       
         0 
        
       
         ) 
        
       
      
        x(0) 
       
      
    x(0)到低质量图像的 
     
      
       
       
         μ 
        
       
         + 
        
       
         ϵ 
        
       
      
        \mu+\epsilon 
       
      
    μ+ϵ的退化过程。通过模拟相应的逆时SDE,可以恢复高质量的图像。
核心思想是将均值恢复SDE与神经网络训练的最大似然目标相结合
分数函数定义为:
 
 其中 
     
      
       
        
        
          θ 
         
        
          t 
         
        
       
      
        \theta_t 
       
      
    θt和 
     
      
       
        
        
          σ 
         
        
          t 
         
        
       
      
        \sigma_t 
       
      
    σt是时间相关的正参数,分别表征平均回归和随机波动的速度
为了进行图像退化,分别让 x ( 0 ) x(0) x(0)和 μ \mu μ为地面真实高质量(HQ)图像和其退化的低质量(LQ)对应图像。
为了使的SDE(3)具有封闭形式的解,设置 σ t 2 / θ t = 2 λ 2 \sigma^2_t / \theta_t = 2\lambda^2 σt2/θt=2λ2,其中 λ 2 \lambda^2 λ2为平稳方差。
假设任意时刻t都满足  
     
      
       
        
        
          σ 
         
        
          t 
         
        
          2 
         
        
       
         / 
        
        
        
          θ 
         
        
          t 
         
        
       
         = 
        
       
         2 
        
        
        
          λ 
         
        
          2 
         
        
       
      
        \sigma_t^2/\theta_t = 2\lambda^2 
       
      
    σt2/θt=2λ2,给定任意时刻 x(s),其中s<t,有
 
 其中  
     
      
       
        
         
         
           θ 
          
         
           ˉ 
          
         
         
         
           s 
          
         
           : 
          
         
           t 
          
         
        
       
         : 
        
       
         = 
        
        
        
          ∫ 
         
        
          s 
         
        
          t 
         
        
        
        
          θ 
         
        
          z 
         
        
       
         d 
        
       
         z 
        
       
      
        \bar{\theta}_{s:t} := \int^t_s\theta_zdz 
       
      
    θˉs:t:=∫stθzdz是已知的,过度内核  
     
      
       
       
         p 
        
       
         ( 
        
       
         x 
        
       
         ( 
        
       
         t 
        
       
         ) 
        
       
         ∣ 
        
       
         x 
        
       
         ( 
        
       
         s 
        
       
         ) 
        
       
         ) 
        
       
         = 
        
       
         N 
        
       
         ( 
        
       
         x 
        
       
         ( 
        
       
         t 
        
       
         ) 
        
       
         ∣ 
        
        
        
          m 
         
         
         
           s 
          
         
           : 
          
         
           t 
          
         
        
       
         ( 
        
       
         x 
        
       
         ( 
        
       
         s 
        
       
         ) 
        
       
         ) 
        
       
         , 
        
        
        
          v 
         
         
         
           s 
          
         
           : 
          
         
           t 
          
         
        
       
         ) 
        
       
      
        p(x(t)|x(s)) = N(x(t)|m_{s:t}(x(s)),v_{s:t}) 
       
      
    p(x(t)∣x(s))=N(x(t)∣ms:t(x(s)),vs:t)是高斯分布,均值  
     
      
       
        
        
          m 
         
         
         
           s 
          
         
           : 
          
         
           t 
          
         
        
       
      
        m_{s:t} 
       
      
    ms:t 方差  
     
      
       
        
        
          v 
         
         
         
           s 
          
         
           : 
          
         
           t 
          
         
        
       
      
        v_{s:t} 
       
      
    vs:t表示为
 
 任意时刻的边缘分布
 
 当t-> 
     
      
       
       
         ∞ 
        
       
      
        \infty 
       
      
    ∞,均值 
     
      
       
        
        
          m 
         
        
          t 
         
        
       
      
        m_t 
       
      
    mt收敛于低质量图像 
     
      
       
       
         μ 
        
       
      
        \mu 
       
      
    μ,方差 
     
      
       
        
        
          v 
         
        
          t 
         
        
       
      
        v_t 
       
      
    vt收敛于平稳方差 
     
      
       
        
        
          λ 
         
        
          2 
         
        
       
      
        \lambda^2 
       
      
    λ2
前向SDE(3)将高质量图像扩散为高斯噪声固定的低质量图像
反向SDE可以推导为:
 
 唯一不知道的是分数函数  
     
      
       
        
        
          ∇ 
         
        
          x 
         
        
       
         l 
        
       
         o 
        
       
         g 
        
        
        
          p 
         
        
          t 
         
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        \nabla_xlog p_t(x) 
       
      
    ∇xlogpt(x)
由于在训练期间可以获得地面真实高质量图像 
     
      
       
       
         x 
        
       
         ( 
        
       
         0 
        
       
         ) 
        
       
      
        x(0) 
       
      
    x(0),因此可以训练神经网络来估计分数 
     
      
       
        
        
          ∇ 
         
        
          x 
         
        
       
         l 
        
       
         o 
        
       
         g 
        
        
        
          p 
         
        
          t 
         
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        \nabla_xlog p_t(x) 
       
      
    ∇xlogpt(x)。具体来说,可以使用(6)来计算地面真值得分为
 
 如果使  
     
      
       
       
         x 
        
       
         ( 
        
       
         t 
        
       
         ) 
        
       
         = 
        
        
        
          m 
         
        
          t 
         
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
         + 
        
        
         
         
           v 
          
         
           t 
          
         
        
        
        
          ϵ 
         
        
          t 
         
        
       
      
        x(t) = m_t(x)+\sqrt{v_t}\epsilon_t 
       
      
    x(t)=mt(x)+vtϵt,其中  
     
      
       
        
        
          ϵ 
         
        
          t 
         
        
       
      
        \epsilon_t 
       
      
    ϵt 是标准高斯噪声  
     
      
       
        
        
          ϵ 
         
        
          t 
         
        
       
           
        
       
         N 
        
       
         ( 
        
       
         0 
        
       
         , 
        
       
         I 
        
       
         ) 
        
       
      
        \epsilon_t ~ N(0,I) 
       
      
    ϵt N(0,I)
 
 使用一个条件时变神经网络,它将状态和时间作为输入和输出纯噪声
 
  
     
      
       
        
        
          γ 
         
        
          1 
         
        
       
         , 
        
        
        
          γ 
         
        
          2 
         
        
       
         , 
        
       
         ⋯ 
         
       
         , 
        
        
        
          γ 
         
        
          T 
         
        
       
      
        \gamma_1,\gamma_2,\cdots,\gamma_T 
       
      
    γ1,γ2,⋯,γT是positive weights, 
     
      
       
       
         { 
        
        
        
          x 
         
        
          i 
         
        
        
        
          } 
         
         
         
           i 
          
         
           = 
          
         
           0 
          
         
        
          T 
         
        
       
      
        \{x_i\}^T_{i=0} 
       
      
    {xi}i=0T表示扩散过程的离散化
当应用于图像恢复中遇到的复杂退化时,训练往往变得不稳定,这源于试图学习给定时间的瞬时噪声。
在给定高质量图像𝑥0的情况下,试图找到最优轨迹 
     
      
       
        
        
          x 
         
         
         
           1 
          
         
           : 
          
         
           T 
          
         
        
       
      
        x_{1:T} 
       
      
    x1:T,最小化似然函数  
     
      
       
       
         p 
        
       
         ( 
        
        
        
          x 
         
         
         
           1 
          
         
           : 
          
         
           T 
          
         
        
       
         ∣ 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
      
        p(x_{1:T}|x_0) 
       
      
    p(x1:T∣x0)
 
 其中 
     
      
       
       
         p 
        
       
         ( 
        
        
        
          x 
         
        
          T 
         
        
       
         ∣ 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
         = 
        
       
         N 
        
       
         ( 
        
        
        
          x 
         
        
          T 
         
        
       
         : 
        
        
        
          m 
         
        
          T 
         
        
       
         ( 
        
        
        
          x 
         
        
          0 
         
        
       
         ) 
        
       
         , 
        
        
        
          v 
         
        
          T 
         
        
       
         ) 
        
       
      
        p(x_T|x_0) = N(x_T:m_T(x_0),v_T) 
       
      
    p(xT∣x0)=N(xT:mT(x0),vT)是低质量图像
 
 最小化负对数似然的最优反向状态
 
 给定初始状态  
     
      
       
        
        
          x 
         
        
          0 
         
        
       
      
        x_0 
       
      
    x0,任意时刻状态  
     
      
       
        
        
          x 
         
        
          i 
         
        
       
      
        x_i 
       
      
    xi,i>0,从  
     
      
       
        
        
          x 
         
        
          i 
         
        
       
         − 
        
       
         > 
        
        
        
          x 
         
         
         
           i 
          
         
           − 
          
         
           1 
          
         
        
       
      
        x_i->x_{i-1} 
       
      
    xi−>xi−1的IR-SDE表示为

 优化噪声网络,使IR-SDE反向为最优轨迹,即
 
  
     
      
       
       
         ( 
        
       
         d 
        
       
         x 
        
        
        
          ) 
         
         
          
          
            ϵ 
           
          
            ˉ 
           
          
         
           Φ 
          
         
        
       
      
        (dx)_{\bar{\epsilon}\Phi} 
       
      
    (dx)ϵˉΦ 表示 公式7的反向SDE,分数由噪声模型 
     
      
       
        
         
         
           ϵ 
          
         
           ˉ 
          
         
        
          Φ 
         
        
       
      
        \bar{\epsilon}_\Phi 
       
      
    ϵˉΦ得到
其中期望 ∫ 0 t σ s d w ^ ( s ) \int_0^t\sigma_sd\hat{w}(s) ∫0tσsdw^(s)为零,那么只需要考虑 ( d x ) ϵ ^ Φ (dx)_{\hat{\epsilon}_\Phi} (dx)ϵ^Φdirft 部分
实验

 
 
 
 



















