Diffusion Models
Links: https://theaisummer.com/diffusion-models/
Markovian Hierarchical VAE
Random variables:
- data: $x_0$
- representation: $x_T$
Model: $(p(x_0,x_1,\cdots,x_T),\ q(x_1,\cdots,x_T\mid x_0))$, where $x_1,\cdots,x_T$ are unobservable, and
- generative model/backward trajectory:
 $$p(x_0,x_1,\cdots,x_T)=p(x_T)\prod_{t=1}^{T}p(x_{t-1}\mid x_t)$$
- forward trajectory (Markov process):
 $$q(x_1,\cdots,x_T\mid x_0)=\prod_{t=1}^{T}q(x_t\mid x_{t-1})$$
ELBO (a lower bound on $\log p(x_0)$):
$$\begin{aligned}\mathrm{ELBO}:=&\int q(x_T\mid x_0)\log\frac{p(x_T)}{q(x_T\mid x_0)}\,\mathrm{d}x_T\\&+\sum_{t=2}^{T}\int q(x_{t-1},x_t\mid x_0)\log\frac{p(x_{t-1}\mid x_t)}{q(x_{t-1}\mid x_t,x_0)}\,\mathrm{d}x_{t-1}\,\mathrm{d}x_t\\&+\int q(x_1\mid x_0)\log p(x_0\mid x_1)\,\mathrm{d}x_1\end{aligned}$$
Loss
$$\begin{aligned}\mathrm{Loss}:=-\mathrm{ELBO}=&\,D_{KL}(q(x_T\mid x_0)\,\|\,p(x_T))\\&+\sum_{t=2}^{T}\mathbb{E}_{q(x_t\mid x_0)}\big[D_{KL}(q(x_{t-1}\mid x_t,x_0)\,\|\,p(x_{t-1}\mid x_t))\big]\\&-\int q(x_1\mid x_0)\log p(x_0\mid x_1)\,\mathrm{d}x_1\end{aligned}$$
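- prior matching term: the KL divergence between $q(x_T\mid x_0)$ and the prior $p(x_T)$
- denoising matching terms: the expected KL divergences in the sum over $t=2,\cdots,T$
- reconstruction term: the final negative log-likelihood term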
Diffusion Models
Basic assumptions
- tractable distribution: $p(x_T)$
- forward trajectory (Markov process): $q(x_t\mid x_{t-1})$ is fixed (it has no learned parameters)
Definition (Diffusion Model)
- tractable distribution: $p(x_T)\sim\mathcal{N}(0,1)$
- generative model/backward trajectory: $p(x_{t-1}\mid x_t)\sim\mathcal{N}(\mu(t),\Sigma(t))$
- forward trajectory (Gaussian diffusion): $q(x_t\mid x_{t-1})\sim\mathcal{N}(\sqrt{1-\beta_t}\,x_{t-1},\ \beta_t)$
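The forward trajectory can be simulated straight from this definition. The sketch below is a minimal illustration, assuming scalar data and a user-supplied list `betas` holding $\beta_1,\cdots,\beta_T$; the function name `forward_trajectory` is hypothetical, not taken from any library.

```python
import math
import random

def forward_trajectory(x0, betas, rng=random):
    """Sample [x_0, x_1, ..., x_T] from the Markov chain q(x_t | x_{t-1}).

    Scalar data is assumed; betas[t - 1] holds beta_t, so len(betas) == T.
    """
    xs = [x0]
    for beta in betas:
        # x_t ~ N(sqrt(1 - beta_t) * x_{t-1}, beta_t)
        mean = math.sqrt(1.0 - beta) * xs[-1]
        xs.append(mean + math.sqrt(beta) * rng.gauss(0.0, 1.0))
    return xs

# Usage: 50 steps with a constant, illustrative beta_t = 0.1.
traj = forward_trajectory(2.0, [0.1] * 50, random.Random(0))
```

With every $\beta_t=0$ the chain never moves; larger $\beta_t$ shrink the signal and inject more noise per step.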
Parameters:
- $\beta_t=1-\alpha_t$, or equivalently $\bar{\alpha}_t:=\prod_{s=1}^{t}\alpha_s$: noise schedule, where $\beta_t$ is small (so $\alpha_t$ is close to 1)
- $\sqrt{\bar{\alpha}_t}$: signal rate
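As a concrete noise schedule, the sketch below computes $\beta_t$, $\bar\alpha_t$ and the signal rate $\sqrt{\bar\alpha_t}$. The linear schedule with $T=1000$ and $\beta_t$ rising from $10^{-4}$ to $0.02$ is an assumption here (it is the setting reported in the DDPM paper); the helper names are hypothetical, and any increasing schedule with small $\beta_t$ is handled the same way.

```python
import math

def linear_betas(T=1000, beta_start=1e-4, beta_end=0.02):
    """beta_1, ..., beta_T spaced linearly (assumed schedule)."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

def cumulative_alphas(betas):
    """alpha_bar_t = prod_{s=1}^{t} (1 - beta_s) for t = 1, ..., T."""
    abar, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        abar.append(prod)
    return abar

betas = linear_betas()
abar = cumulative_alphas(betas)
signal_rate = [math.sqrt(a) for a in abar]
print(f"signal rate at t=1: {signal_rate[0]:.5f}, at t=T: {signal_rate[-1]:.5f}")
```

The signal rate starts just below 1 and decays to nearly 0 by $t=T$, which is what makes $q(x_T\mid x_0)$ close to $\mathcal{N}(0,1)$.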

Fact.
- $q(x_t\mid x_0)\sim\mathcal{N}(\sqrt{\bar{\alpha}_t}\,x_0,\ 1-\bar{\alpha}_t)$
- $q(x_{t-1}\mid x_t,x_0)\sim\mathcal{N}(\mu_q(x_t,x_0),\sigma^2(t))$, where
 $$\mu_q(x_t,x_0):=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})x_t+\sqrt{\bar\alpha_{t-1}}(1-\alpha_t)x_0}{1-\bar\alpha_t}=\frac{1}{\sqrt{\alpha_t}}x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\epsilon_0,$$
 with $\epsilon_0:=(x_t-\sqrt{\bar\alpha_t}\,x_0)/\sqrt{1-\bar\alpha_t}$ the noise that produced $x_t$ from $x_0$, and
 $$\sigma^2(t):=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t.$$
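Both facts can be sanity-checked numerically for scalar data. The sketch below, using an arbitrary small schedule chosen only for illustration, propagates the mean and variance of $x_t$ through the chain $q(x_t\mid x_{t-1})$ and compares them with the closed form of $q(x_t\mid x_0)$, then confirms that the $x_0$ form and the $\epsilon_0$ form of $\mu_q(x_t,x_0)$ agree. All function names are hypothetical.

```python
import math

def q_xt_given_x0(x0, abar_t):
    """Mean and variance of q(x_t | x_0) = N(sqrt(abar_t) x_0, 1 - abar_t)."""
    return math.sqrt(abar_t) * x0, 1.0 - abar_t

def posterior_mean_x0(xt, x0, alpha_t, abar_t, abar_prev):
    """mu_q(x_t, x_0) written in terms of x_0."""
    return (math.sqrt(alpha_t) * (1.0 - abar_prev) * xt
            + math.sqrt(abar_prev) * (1.0 - alpha_t) * x0) / (1.0 - abar_t)

def posterior_mean_eps(xt, eps0, alpha_t, abar_t):
    """mu_q(x_t, x_0) written in terms of eps_0."""
    beta_t = 1.0 - alpha_t
    return xt / math.sqrt(alpha_t) - beta_t / (math.sqrt(1.0 - abar_t) * math.sqrt(alpha_t)) * eps0

# 1) Push mean/variance through x_t = sqrt(alpha_t) x_{t-1} + sqrt(beta_t) * noise.
x0 = 1.3
betas = [0.01 * (t + 1) for t in range(10)]  # illustrative schedule
mean, var, abar = x0, 0.0, 1.0
for beta in betas:
    mean, var = math.sqrt(1.0 - beta) * mean, (1.0 - beta) * var + beta
    abar *= 1.0 - beta
closed_mean, closed_var = q_xt_given_x0(x0, abar)
assert math.isclose(mean, closed_mean) and math.isclose(var, closed_var)

# 2) The two forms of mu_q agree once eps_0 is the noise that produced x_t.
alpha_t, abar_prev = 0.9, 0.5
abar_t = abar_prev * alpha_t
eps0 = -0.4
xt = math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * eps0
assert math.isclose(posterior_mean_x0(xt, x0, alpha_t, abar_t, abar_prev),
                    posterior_mean_eps(xt, eps0, alpha_t, abar_t))
```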
Design I (predict $x_0$): $p(x_{t-1}\mid x_t)\sim\mathcal{N}(\mu(t),\Sigma(t))$ with
 $$\mu(t)=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})x_t+\beta_t\sqrt{\bar\alpha_{t-1}}\,\hat{x}(x_t,t)}{1-\bar\alpha_t},\qquad\Sigma(t)=\sigma^2(t),$$
 i.e. $\mu(t)=\mu_q(x_t,\hat{x}(x_t,t))$, where the network $\hat{x}(x_t,t)$ predicts $x_0$.
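A minimal sketch of one reverse step under Design I, assuming scalar data and treating `x0_hat` as the output of a hypothetical network $\hat{x}(x_t,t)$ (the function names are also hypothetical). If the prediction is perfect ($\hat{x}=x_0$) and $x_t$ was produced with zero noise, the mean should reduce to $\sqrt{\bar\alpha_{t-1}}\,x_0$, which the usage lines illustrate.

```python
import math
import random

def design1_mean(xt, x0_hat, alpha_t, abar_t, abar_prev):
    """mu(t) under Design I: the posterior mean mu_q with x_0 replaced by x0_hat."""
    beta_t = 1.0 - alpha_t
    return (math.sqrt(alpha_t) * (1.0 - abar_prev) * xt
            + beta_t * math.sqrt(abar_prev) * x0_hat) / (1.0 - abar_t)

def design1_step(xt, x0_hat, alpha_t, abar_t, abar_prev, rng=random):
    """Sample x_{t-1} ~ N(mu(t), sigma^2(t)) given the prediction x0_hat."""
    beta_t = 1.0 - alpha_t
    sigma2 = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
    mu = design1_mean(xt, x0_hat, alpha_t, abar_t, abar_prev)
    return mu + math.sqrt(sigma2) * rng.gauss(0.0, 1.0)

# Usage with illustrative values alpha_t = 0.9 and alpha_bar_{t-1} = 0.5.
alpha_t, abar_prev = 0.9, 0.5
abar_t = abar_prev * alpha_t
x0 = 2.0
xt = math.sqrt(abar_t) * x0  # x_t generated with eps_0 = 0
mu = design1_mean(xt, x0, alpha_t, abar_t, abar_prev)  # equals sqrt(abar_prev) * x0 up to rounding
x_prev = design1_step(xt, x0, alpha_t, abar_t, abar_prev, random.Random(0))
```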
Design II (predict the noise): $p(x_{t-1}\mid x_t)\sim\mathcal{N}(\mu(t),\Sigma(t))$ with
 $$\mu(t)=\frac{1}{\sqrt{\alpha_t}}x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\hat{\epsilon}(x_t,t),\qquad\Sigma(t)=\sigma^2(t),$$
 where the network $\hat{\epsilon}(x_t,t)$ predicts $\epsilon_0$.
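Sampling from the generative model chains Design II steps from $x_T\sim\mathcal{N}(0,1)$ down to $x_0$. The sketch below assumes scalar data and a hypothetical noise predictor `eps_model(x_t, t)` standing in for $\hat{\epsilon}(x_t,t)$. As a check it uses a toy data distribution that is a point mass at `c`: there the exact noise $\epsilon_0=(x_t-\sqrt{\bar\alpha_t}\,c)/\sqrt{1-\bar\alpha_t}$ can be computed, and since $\sigma^2(1)=0$ (with $\bar\alpha_0=1$) the last step lands exactly on `c`.

```python
import math
import random

def sample_design2(eps_model, betas, rng=random):
    """Ancestral sampling x_T -> x_0 with p(x_{t-1} | x_t) = N(mu(t), sigma^2(t)), Design II.

    eps_model(x_t, t) is a hypothetical stand-in for eps_hat(x_t, t); scalar data assumed.
    """
    T = len(betas)
    abar = [1.0]  # abar[t] = alpha_bar_t, with alpha_bar_0 = 1 (empty product)
    for beta in betas:
        abar.append(abar[-1] * (1.0 - beta))
    x = rng.gauss(0.0, 1.0)  # x_T ~ p(x_T) = N(0, 1)
    for t in range(T, 0, -1):
        beta_t = betas[t - 1]
        alpha_t = 1.0 - beta_t
        mu = (x - beta_t / math.sqrt(1.0 - abar[t]) * eps_model(x, t)) / math.sqrt(alpha_t)
        sigma2 = (1.0 - abar[t - 1]) / (1.0 - abar[t]) * beta_t  # sigma^2(t), zero at t = 1
        x = mu + math.sqrt(sigma2) * rng.gauss(0.0, 1.0)
    return x

# Usage: toy data distribution = point mass at c, so the exact eps_0 is known.
c = 0.8
betas = [0.02] * 50  # illustrative constant schedule
abar = [1.0]
for beta in betas:
    abar.append(abar[-1] * (1.0 - beta))

def oracle_eps(x_t, t):
    """Exact eps_0 = (x_t - sqrt(alpha_bar_t) c) / sqrt(1 - alpha_bar_t)."""
    return (x_t - math.sqrt(abar[t]) * c) / math.sqrt(1.0 - abar[t])

x0 = sample_design2(oracle_eps, betas, random.Random(1))  # recovers c up to rounding
```

With a trained network in place of `oracle_eps`, the same loop produces approximate samples from the learned model.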
Fact.
Under Design I:
 $$\begin{aligned}D_{KL}(q(x_{t-1}\mid x_t,x_0)\,\|\,p_\theta(x_{t-1}\mid x_t))&=\frac{1}{2\sigma^2(t)}\frac{\bar{\alpha}_{t-1}\beta_t^2}{(1-\bar{\alpha}_t)^2}\|\hat{x}(x_t,t)-x_0\|^2\\&=\frac{1}{2}\Big(\frac{1}{1-\bar{\alpha}_{t-1}}-\frac{1}{1-\bar{\alpha}_t}\Big)\|\hat{x}(x_t,t)-x_0\|^2\end{aligned}$$
Under Design II:
 $$D_{KL}(q(x_{t-1}\mid x_t,x_0)\,\|\,p_\theta(x_{t-1}\mid x_t))=\frac{1}{2\sigma^2(t)}\frac{\beta_t^2}{(1-\bar{\alpha}_t)\alpha_t}\|\hat{\epsilon}(x_t,t)-\epsilon_0\|^2$$
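Both closed forms can be checked against the general KL divergence between univariate Gaussians, $D_{KL}(\mathcal{N}(m_1,v_1)\,\|\,\mathcal{N}(m_2,v_2))=\frac{1}{2}\big(\log\frac{v_2}{v_1}+\frac{v_1+(m_1-m_2)^2}{v_2}-1\big)$. The sketch assumes scalar data, arbitrary schedule values, and arbitrary stand-ins `x0_hat` and `eps_hat` for the network outputs; it also confirms that the two expressions for the Design I weight coincide.

```python
import math

def gaussian_kl(m1, v1, m2, v2):
    """D_KL(N(m1, v1) || N(m2, v2)) for univariate Gaussians (v1, v2 are variances)."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

# Illustrative schedule values and data.
alpha_t, abar_prev = 0.9, 0.5
abar_t, beta_t = abar_prev * alpha_t, 1.0 - alpha_t
sigma2 = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
x0, eps0 = 0.7, -0.4
xt = math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * eps0

def mean_from_x0(x):
    """Posterior mean (Design I form) with x in place of x_0."""
    return (math.sqrt(alpha_t) * (1.0 - abar_prev) * xt
            + math.sqrt(abar_prev) * beta_t * x) / (1.0 - abar_t)

def mean_from_eps(e):
    """Posterior mean (Design II form) with e in place of eps_0."""
    return (xt - beta_t / math.sqrt(1.0 - abar_t) * e) / math.sqrt(alpha_t)

x0_hat, eps_hat = 0.5, 0.1  # arbitrary stand-ins for network outputs

# Design I: both weight expressions, and agreement with the direct KL.
w1_a = abar_prev * beta_t ** 2 / (2.0 * sigma2 * (1.0 - abar_t) ** 2)
w1_b = 0.5 * (1.0 / (1.0 - abar_prev) - 1.0 / (1.0 - abar_t))
kl1 = gaussian_kl(mean_from_x0(x0), sigma2, mean_from_x0(x0_hat), sigma2)
assert math.isclose(w1_a, w1_b)
assert math.isclose(kl1, w1_a * (x0_hat - x0) ** 2)

# Design II: weight beta_t^2 / (2 sigma^2(t) (1 - alpha_bar_t) alpha_t).
w2 = beta_t ** 2 / (2.0 * sigma2 * (1.0 - abar_t) * alpha_t)
kl2 = gaussian_kl(mean_from_eps(eps0), sigma2, mean_from_eps(eps_hat), sigma2)
assert math.isclose(kl2, w2 * (eps_hat - eps0) ** 2)
```

Because both Gaussians share the variance $\sigma^2(t)$, the KL reduces to a squared distance between means divided by $2\sigma^2(t)$, which is why each design gives a weighted squared-error loss.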
Algorithm
Loss:
 $$L=\sum_{t}L_t,\qquad L_t\approx\mathbb{E}_{\epsilon\sim\mathcal{N}(0,1)}\big[\|\epsilon-\hat{\epsilon}(x_t,t)\|^2\big],\quad(1\le t\le T)$$
 where $x_t:=\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon$. The approximation drops the $t$-dependent weight of the Design II KL term, and the expectation is estimated by averaging over sampled $\epsilon$ (and data $x_0$).
Train the network $\hat\epsilon$ by regression on the dataset $\{((x_t(x_{0,i},\epsilon_{il}),t),\ \epsilon_{il}) : \epsilon_{il}\sim\mathcal{N}(0,1),\ i=1,\cdots,N,\ l=1,\cdots,L\}$ of size $NL$ for each $t$: the input is $(x_t,t)$ and the target is $\epsilon_{il}$.
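The Monte Carlo objective can be written out directly. The sketch below assumes scalar data and a hypothetical predictor `eps_model(x_t, t)`; for clarity it sums $L_t$ over every $t=1,\cdots,T$, whereas practical implementations (DDPM among them) usually sample $t$ uniformly per example, and it omits the neural network and gradient steps entirely. `num_noise` plays the role of $L$, the number of noise draws per data point.

```python
import math
import random

def simple_loss(eps_model, data, abar, num_noise=4, rng=random):
    """Monte Carlo estimate of L = sum_t L_t, L_t ~ E ||eps - eps_hat(x_t, t)||^2.

    data: scalar samples x_{0,i} (N of them); abar[t] = alpha_bar_t for t = 1..T,
    with abar[0] unused; num_noise noise draws per sample play the role of L;
    eps_model(x_t, t) is a hypothetical noise predictor.
    """
    T = len(abar) - 1
    total = 0.0
    for t in range(1, T + 1):
        sq_err = 0.0
        for x0 in data:
            for _ in range(num_noise):
                eps = rng.gauss(0.0, 1.0)
                xt = math.sqrt(abar[t]) * x0 + math.sqrt(1.0 - abar[t]) * eps
                sq_err += (eps - eps_model(xt, t)) ** 2
        total += sq_err / (len(data) * num_noise)  # average over the N * L pairs at step t
    return total

# Usage: a predictor that always outputs 0 scores about 1 per step, i.e. about T in total.
abar = [1.0]
for beta in [0.05] * 10:  # illustrative schedule, T = 10
    abar.append(abar[-1] * (1.0 - beta))
baseline = simple_loss(lambda x_t, t: 0.0, [0.8] * 8, abar, num_noise=50, rng=random.Random(0))
```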
Exercise
- Given a latent variable model $p(x,z)$ with variational distribution $q(z\mid x)$, let $q(x)$ be the data distribution and $q(x,z)=q(z\mid x)q(x)$. Show that
 $$\int q(x)L_x\,\mathrm{d}x=\int q(x,z)\log\frac{p(x,z)}{q(z\mid x)}\,\mathrm{d}x\,\mathrm{d}z=-D_{KL}(q(x,z)\,\|\,p(x,z))+\text{const},$$
 where $L_x$ is the ELBO for $x$ and the constant does not depend on $p$ or $q(z\mid x)$, so maximizing the average ELBO is equivalent to minimizing $D_{KL}(q(x,z)\,\|\,p(x,z))$.
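The identity can be checked numerically on a tiny discrete model before proving it. The sketch below uses binary $x$ and $z$ with arbitrary made-up tables for $q(x)$, $q(z\mid x)$ and $p(x,z)$, and compares the average ELBO with $-D_{KL}(q(x,z)\,\|\,p(x,z))$; their difference comes out as $\sum_x q(x)\log q(x)$, which involves only the data distribution.

```python
import math

q_x = [0.3, 0.7]                          # data distribution q(x), x in {0, 1}
q_z_given_x = [[0.6, 0.4], [0.2, 0.8]]    # q(z | x), rows indexed by x
p_xz = [[0.1, 0.2], [0.3, 0.4]]           # model joint p(x, z), sums to 1

# Average ELBO: sum_x q(x) * sum_z q(z|x) log(p(x,z) / q(z|x)).
avg_elbo = sum(q_x[x] * q_z_given_x[x][z] * math.log(p_xz[x][z] / q_z_given_x[x][z])
               for x in range(2) for z in range(2))

# KL divergence between the joints q(x,z) = q(z|x) q(x) and p(x,z).
kl = sum(q_x[x] * q_z_given_x[x][z] * math.log(q_x[x] * q_z_given_x[x][z] / p_xz[x][z])
         for x in range(2) for z in range(2))

neg_entropy = sum(qx * math.log(qx) for qx in q_x)  # sum_x q(x) log q(x)
assert math.isclose(avg_elbo, -kl + neg_entropy)
```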
References
- Jonathan Ho, Ajay Jain, Pieter Abbeel. Denoising Diffusion Probabilistic Models. 2020.
- Calvin Luo. Understanding Diffusion Models: A Unified Perspective. 2022.