文章目录
- KL散度
- 前向 vs 反向 KL
- 前向KL
- 反向KL
- 可视化
 
 
- 问题描述
- 变分推理
- ELBO: Evidence Lower Bound
- 参考
 
此篇博文主要介绍什么是变分推理(Variational Inference , VI),以及它的数学推导公式。变分推理,是机器学习中一种流行的方式,使用优化的技术估计复杂概率密度。变分推理的工作原理: 首先选择一系列概率密度函数,然后采用KL散度作为优化度量找到最接近于概率密度的函数。引入evidence lower bound的方法更容易计算近似概率。
KL散度
KL散度是两个分布之间的相对熵,量化概率分布 P ( X ) P \left( X \right) P(X)与候选分布 Q ( X ) Q\left( X \right) Q(X)的相似程度。对于一个离散的随机变量 X X X,概率分布 P P P和分布 Q Q Q之间的KL散度的计算公式如下定义:

其中 H ( P ) = − Σ x ∈ X P ( x ) l o g P ( x ) \mathbb{H}\left( P \right) = -\Sigma_{x \in X} P \left( x \right)log P \left( x \right) H(P)=−Σx∈XP(x)logP(x)是分布 P P P的熵, H ( P ) = − Σ x ∈ X P ( x ) l o g Q ( x ) \mathbb{H}\left( P \right) = -\Sigma_{x \in X} P\left( x \right)logQ\left( x \right) H(P)=−Σx∈XP(x)logQ(x)是分布 P P P和分布 Q Q Q的交叉熵。
KL散度具有如下性质:1. 非负性;2. 非对称性;3. 当KL散度的取值位于 ( 0 , ∞ ) (0,\infty) (0,∞),越接近于0,说明分布 P P P和分布 Q Q Q越匹配。
此外,概率分布 
     
      
       
       
         P 
        
       
      
        P 
       
      
    P和分布 
     
      
       
       
         Q 
        
       
      
        Q 
       
      
    Q之间的KL散度还可以表示为两个概率密度函数 
     
      
       
       
         p 
        
       
      
        p 
       
      
    p和 
     
      
       
       
         q 
        
       
      
        q 
       
      
    q之间对数差的期望。假设随机变量 
     
      
       
       
         x 
        
       
      
        x 
       
      
    x为概率分布函数 
     
      
       
       
         P 
        
       
      
        P 
       
      
    P的一个概率值, 
     
      
       
       
         E 
        
       
      
        \mathbb{E} 
       
      
    E为期望,那么KL公式还可如下定义:
 
前向 vs 反向 KL
KL散度是非对称的,那也就是说 D K L ( P ∥ Q ) ≠ D K L ( Q ∥ P ) D_{KL} \left( P \| Q \right) \neq D_{KL} \left( Q \| P \right) DKL(P∥Q)=DKL(Q∥P),因此根据分布 P P P和分布 Q Q Q的位置,可分为前向KL和后向KL。
前向KL
前向KL的公式定义如下。只要近似值不能够覆盖实际概率分布,KL散度将会变得很大,用公式表示就是 
     
      
       
        
         
         
           lim 
          
         
            
          
         
         
         
           q 
          
          
          
            ( 
           
          
            x 
           
          
            ) 
           
          
         
           → 
          
         
           0 
          
         
        
        
         
         
           p 
          
          
          
            ( 
           
          
            x 
           
          
            ) 
           
          
         
         
         
           q 
          
          
          
            ( 
           
          
            x 
           
          
            ) 
           
          
         
        
       
         → 
        
       
         ∞ 
        
       
         , 
        
       
         p 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         > 
        
       
         0 
        
       
      
        \lim_{q\left(x\right) \to 0} \frac{p\left(x\right)}{q\left(x\right)} \rightarrow \infty , p\left(x\right) > 0 
       
      
    limq(x)→0q(x)p(x)→∞,p(x)>0,当 
     
      
       
       
         p 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         > 
        
       
         0 
        
       
         , 
        
       
         q 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         → 
        
       
         0 
        
       
      
        p\left(x\right) > 0, q\left(x\right) \to 0 
       
      
    p(x)>0,q(x)→0时, 
     
      
       
        
         
         
           p 
          
          
          
            ( 
           
          
            x 
           
          
            ) 
           
          
         
         
         
           q 
          
          
          
            ( 
           
          
            x 
           
          
            ) 
           
          
         
        
       
      
        \frac{p\left(x\right)}{q\left(x\right)} 
       
      
    q(x)p(x)的极限值将为 
     
      
       
       
         ∞ 
        
       
      
        \infty 
       
      
    ∞。因此,当 
     
      
       
       
         p 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         > 
        
       
         0 
        
       
      
        p\left(x\right) > 0 
       
      
    p(x)>0时,必须选择一个概率密度确保 
     
      
       
       
         q 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         > 
        
       
         0 
        
       
      
        q \left(x\right) > 0 
       
      
    q(x)>0。这种特殊的情况被称为"zero avoiding",直观理解就是 
     
      
       
       
         q 
        
       
      
        q 
       
      
    q高估 
     
      
       
       
         p 
        
       
      
        p 
       
      
    p。
 
反向KL
反向KL的公式定义如下,其中 
     
      
       
        
         
         
           lim 
          
         
            
          
         
         
         
           p 
          
          
          
            ( 
           
          
            x 
           
          
            ) 
           
          
         
           → 
          
         
           0 
          
         
        
        
         
         
           q 
          
          
          
            ( 
           
          
            x 
           
          
            ) 
           
          
         
         
         
           p 
          
          
          
            ( 
           
          
            x 
           
          
            ) 
           
          
         
        
       
         → 
        
       
         ∞ 
        
       
         , 
        
       
         q 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         > 
        
       
         0 
        
       
      
        \lim_{p \left(x\right) \to 0} \frac{q\left(x\right)}{p\left(x\right)} \rightarrow \infty , q\left(x\right) > 0 
       
      
    limp(x)→0p(x)q(x)→∞,q(x)>0,当 
     
      
       
       
         p 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         = 
        
       
         0 
        
       
      
        p \left(x\right) = 0 
       
      
    p(x)=0时,迫使 
     
      
       
       
         q 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         = 
        
       
         0 
        
       
      
        q \left(x\right) = 0 
       
      
    q(x)=0,不然KL散度值将会很大。这种被称为“zero forcing”,直观理解就是 
     
      
       
       
         q 
        
       
      
        q 
       
      
    q低估 
     
      
       
       
         p 
        
       
      
        p 
       
      
    p。
 
可视化
下图展示了双峰分布上的正向和反向KL散度。蓝色轮廓表示实际概率密度 p p p,红色轮廓表示单峰近似 q q q。左一显示正向KL散度最小化, q q q倾向于覆盖 p p p。中间和右一显示了反向KL散度最小化, q q q倾向于锁定到两种模式中的其中一个。

问题描述

假设有两个随机变量 X X X和 Z Z Z,其中 X X X为观测变量, Z Z Z为潜在变量。 X X X和 Z Z Z的关系如上图所示,观测变量 X X X依赖于潜在变量 Z Z Z,从 Z Z Z到 X X X的箭头表示条件概率密度 p ( X ∣ Z ) p\left( X | Z \right) p(X∣Z)。依据贝叶斯公式,可计算后验概率密度 p ( Z ∣ X ) p\left( Z| X \right) p(Z∣X)。
p ( Z ∣ X ) = p ( X ∣ Z ) p ( Z ) p ( X ) p\left( Z| X \right) = \frac{p\left(X|Z\right)p\left( Z \right)}{p\left(X\right)} p(Z∣X)=p(X)p(X∣Z)p(Z)
其中,分母 p ( X ) p\left( X \right) p(X)的计算公式为 p ( X ) = ∫ z ∈ Z p ( Z ∣ z ) p ( z ) d z p\left( X \right) = \int_{z \in Z} p \left( Z | z \right) p\left( z \right)dz p(X)=∫z∈Zp(Z∣z)p(z)dz, z z z为样本空间 Z Z Z中的一个实例。 p ( Z ) p\left( Z \right) p(Z)为先验,它捕获了 Z Z Z的先验信息。
观察的边缘概率密度(marginal probability density) p ( X ) p\left( X \right) p(X)被成为evidence,对于很多模型,evidence的积分依赖于所选模型,要么在闭合形式下不可用,要么需要指数时间计算。
变分推理的目的是为潜在变量的统计推断提供后验概率密度 p ( Z ∣ X ) p\left( Z| X \right) p(Z∣X)的近似解析,它从可处理的概率密度族中选择潜在变量 Z Z Z的概率密度函数 q q q解决近似问题。变分推理能够有效地计算边缘概率密度(或者evidence)的下界,其基本思想是:一个更高的边缘相似性指示所选统计模型更好地拟合观察到的数据。
变分推理
变分推理VI的目的是从可处理的概率密度族 Q \mathcal{Q} Q中选择一个近似的概率密度 q q q。潜在变量 Z Z Z的每一个在 Q \mathcal{Q} Q中的概率密度 q ( Z ) ∈ Q q\left( Z \right) \in \mathcal{Q} q(Z)∈Q都是后验的一个近似候选,VI的目的就是从这些候选中选择最优的那一个。依据KL散度的性质,两个分布的KL值越小,两个分布越匹配。假设近似概率密度于观测变量于观测变量条件不相关,那么推理问题就可以看作一个优化问题,公式如下所示。

优化上述公式,就可从所选的概率家族中得到后验的最佳近似值 q ∗ ( ⋅ ) q^{*}\left( \cdot \right) q∗(⋅),优化的复杂性取决于概率密度族的选择。计算上述公式中的KL散度,需要知道后验 P P P,但是后验的计算是棘手的。
一个替代的方案是用反向KL散度,这样后验和近似的平均交叉熵可以通过期望计算。因此上述公式可以重新被定义为如下公式。

 然而,由于仍然需要知道后验 
     
      
       
       
         P 
        
       
      
        P 
       
      
    P,优化反向KL仍然是不可行的。但是可以最小化一个等于它的函数直到一个常数,这就是evidence lower bound,ELBO。
ELBO: Evidence Lower Bound
设上述公式中的KL散度为 
     
      
       
       
         D 
        
       
      
        D 
       
      
    D,依据下述推导可得到ELBO的公式。
  
      
       
        
         
          
           
            
            
              D 
             
            
           
           
            
             
             
               = 
              
              
              
                D 
               
               
               
                 K 
                
               
                 L 
                
               
              
              
              
                ( 
               
              
                Q 
               
               
               
                 ( 
                
               
                 Z 
                
               
                 ) 
                
               
              
                ∥ 
               
              
                P 
               
               
               
                 ( 
                
               
                 Z 
                
               
                 ∣ 
                
               
                 X 
                
               
                 ) 
                
               
              
                ) 
               
              
             
               = 
              
              
              
                E 
               
               
               
                 z 
                
               
                 ∈ 
                
               
                 Q 
                
                
                
                  ( 
                 
                
                  Z 
                 
                
                  ) 
                 
                
               
              
             
               l 
              
             
               o 
              
             
               g 
              
              
               
               
                 q 
                
                
                
                  ( 
                 
                
                  z 
                 
                
                  ) 
                 
                
               
               
               
                 p 
                
                
                
                  ( 
                 
                
                  z 
                 
                
                  ∣ 
                 
                
                  x 
                 
                
                  ) 
                 
                
               
              
             
            
           
          
          
           
            
             
            
           
           
            
             
             
               = 
              
             
               E 
              
             
               [ 
              
             
               l 
              
             
               o 
              
             
               g 
              
             
               q 
              
              
              
                ( 
               
              
                z 
               
              
                ) 
               
              
             
               ] 
              
             
               − 
              
             
               E 
              
             
               [ 
              
             
               l 
              
             
               o 
              
             
               g 
              
             
               p 
              
              
              
                ( 
               
              
                z 
               
              
                ∣ 
               
              
                x 
               
              
                ) 
               
              
             
               ] 
              
              
              
                  
             
            
           
          
          
           
            
             
            
           
           
            
             
             
               = 
              
             
               E 
              
             
               [ 
              
             
               l 
              
             
               o 
              
             
               g 
              
             
               q 
              
              
              
                ( 
               
              
                z 
               
              
                ) 
               
              
             
               ] 
              
             
               − 
              
             
               E 
              
             
               [ 
              
             
               l 
              
             
               o 
              
             
               g 
              
             
               p 
              
              
              
                ( 
               
              
                z 
               
              
                , 
               
              
                x 
               
              
                ) 
               
              
             
               ] 
              
             
               + 
              
             
               E 
              
             
               [ 
              
             
               l 
              
             
               o 
              
             
               g 
              
             
               p 
              
              
              
                ( 
               
              
                x 
               
              
                ) 
               
              
             
               ] 
                    
             
            
           
          
          
           
            
             
            
           
           
            
             
             
               = 
              
             
               E 
              
             
               [ 
              
             
               l 
              
             
               o 
              
             
               g 
              
             
               q 
              
              
              
                ( 
               
              
                z 
               
              
                ) 
               
              
             
               ] 
              
             
               − 
              
             
               E 
              
             
               [ 
              
             
               l 
              
             
               o 
              
             
               g 
              
             
               q 
              
              
              
                ( 
               
              
                z 
               
              
                , 
               
              
                x 
               
              
                ) 
               
              
             
               ] 
              
             
               + 
              
             
               l 
              
             
               o 
              
             
               g 
              
             
               p 
              
              
              
                ( 
               
              
                x 
               
              
                ) 
               
              
              
             
            
           
          
         
         
        
          ⇓ 
         
         
        
          − 
         
        
          D 
         
        
          + 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          p 
         
         
         
           ( 
          
         
           x 
          
         
           ) 
          
         
        
          = 
         
        
          E 
         
        
          [ 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          q 
         
         
         
           ( 
          
         
           z 
          
         
           , 
          
         
           x 
          
         
           ) 
          
         
        
          ] 
         
        
          − 
         
        
          E 
         
        
          [ 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          q 
         
         
         
           ( 
          
         
           z 
          
         
           ) 
          
         
        
          ] 
         
        
          = 
         
        
          E 
         
        
          L 
         
        
          B 
         
        
          Q 
         
         
         
           ( 
          
         
           Q 
          
         
           ) 
          
         
        
       
         \begin{matrix} D &= D_{KL} \left( Q\left( Z \right) \| P \left( Z | X \right)\right) = \mathbb{E}_{z \in Q\left( Z \right) } log \frac{q\left( z \right)}{p \left( z | x \right)}\\ &= \mathbb{E} [ log q \left( z \right)] - \mathbb{E} [ log p \left( z | x \right)] \qquad \qquad \qquad \; \; \\ &= \mathbb{E} [ log q \left( z \right)] - \mathbb{E} [ log p \left( z , x \right)] + \mathbb{E} [ log p \left( x \right)] \; \; \; \\ &= \mathbb{E} [ log q \left( z \right)] - \mathbb{E} [ log q \left( z , x \right)] + log p \left( x \right) \qquad \end{matrix} \\ \Downarrow \\ -D + log p \left( x \right) = \mathbb{E} [ log q \left( z , x \right)] - \mathbb{E} [ log q \left( z \right)] = ELBQ\left( Q \right) 
        
       
     D=DKL(Q(Z)∥P(Z∣X))=Ez∈Q(Z)logp(z∣x)q(z)=E[logq(z)]−E[logp(z∣x)]=E[logq(z)]−E[logp(z,x)]+E[logp(x)]=E[logq(z)]−E[logq(z,x)]+logp(x)⇓−D+logp(x)=E[logq(z,x)]−E[logq(z)]=ELBQ(Q)
ELBO等于KL散度的负值于常量 
     
      
       
       
         l 
        
       
         o 
        
       
         g 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
      
        log\left(x \right) 
       
      
    log(x)的和。从上述公式可以看出,最大化ELBO等价于最小化KL散度。依据贝叶斯概率 
     
      
       
       
         p 
        
        
        
          ( 
         
        
          z 
         
        
          , 
         
        
          x 
         
        
          ) 
         
        
       
         = 
        
       
         p 
        
        
        
          ( 
         
        
          z 
         
        
          ) 
         
        
       
         ⋅ 
        
       
         p 
        
        
        
          ( 
         
        
          z 
         
        
          ∣ 
         
        
          x 
         
        
          ) 
         
        
       
         = 
        
       
         p 
        
        
        
          ( 
         
        
          x 
         
        
          ) 
         
        
       
         ⋅ 
        
       
         p 
        
        
        
          ( 
         
        
          x 
         
        
          ∣ 
         
        
          z 
         
        
          ) 
         
        
       
      
        p\left(z, x \right) = p\left(z \right) \cdot p\left(z | x \right) = p\left(x \right) \cdot p\left(x | z \right) 
       
      
    p(z,x)=p(z)⋅p(z∣x)=p(x)⋅p(x∣z),ELBO公式又可做如下推导。
 
 从上述公式可以看出,ELBO是数据的对数似然期望与先验和近似后验概率密度的KL散度之和。对数似然期望描述了所选统计模型与数据的拟合程度。KL散度促使变分概率密度接近于先验,因此,ELBO可看作对数据的正则拟合。
使用Jensen不等式( f ( E [ x ] ) ≥ E [ f ( X ) ] f\left( E[x] \right) \ge E[f\left( X \right)] f(E[x])≥E[f(X)])可推到出ELBO和 p ( x ) 的关系, p\left( x \right)的关系, p(x)的关系,ELBO值是要低于 l o g p ( x ) log p\left( x\right) logp(x)。问题描述中,我们也提到evidence的积分依赖于所选模型,要么在闭合形式下不可用,要么需要指数时间计算。ELBO和 l o g p ( x ) log p\left( x\right) logp(x)的这种关系,促使研究人员使用变分下界作为模型选择的标准。

参考
- An Introduction to Variational Inference
- Variational Inference



















