文章目录
- 12 VI——变分推断
- 12.1 背景介绍
- 12.2 Classical VI
- 12.2.1 公式导出
- 12.2.2 坐标上升法
 
- 12.3 SGVI——随机梯度变分推断
- 12.3.1 一般化MC方法
- 12.3.2 降方差——Variance Reduction
 
 
12 VI——变分推断
12.1 背景介绍
变分推断的作用就是在概率图模型中进行参数估计,是参数估计的一种确定性近似的方法。下图给出了VI在机器学习中的地位:
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-d8OyecRa-1686303181209)(assets/12 VI——变分推断/image-20230606193325497.png)]](https://img-blog.csdnimg.cn/d390c07847584f7595a446ac865db268.png)
12.2 Classical VI
12.2.1 公式导出
首先第一个问题,变分推断中的变分是什么?我们曾在EM算法的公式导出中得到过这样一个公式:
  
      
       
        
        
          log 
         
        
           
         
        
          P 
         
        
          ( 
         
        
          X 
         
        
          ) 
         
        
          = 
         
         
         
           ( 
          
          
          
            ∫ 
           
          
            Z 
           
          
         
           q 
          
         
           ( 
          
         
           Z 
          
         
           ) 
          
         
           log 
          
         
            
          
          
           
           
             P 
            
           
             ( 
            
           
             X 
            
           
             , 
            
           
             Z 
            
           
             ) 
            
           
           
           
             q 
            
           
             ( 
            
           
             Z 
            
           
             ) 
            
           
          
         
           d 
          
         
           Z 
          
         
           ) 
          
         
        
          + 
         
         
         
           ( 
          
         
           − 
          
          
          
            ∫ 
           
          
            Z 
           
          
         
           q 
          
         
           ( 
          
         
           Z 
          
         
           ) 
          
         
           log 
          
         
            
          
          
           
           
             P 
            
           
             ( 
            
           
             Z 
            
           
             ∣ 
            
           
             X 
            
           
             ) 
            
           
           
           
             q 
            
           
             ( 
            
           
             Z 
            
           
             ) 
            
           
          
         
           d 
          
         
           Z 
          
         
           ) 
          
         
        
       
         \log P(X) = \left( \int_Z q(Z) \log \frac{P(X, Z)}{q(Z)} {\rm d}Z \right) + \left( -\int_Z q(Z) \log \frac{P(Z|X)}{q(Z)} {\rm d}Z \right) 
        
       
     logP(X)=(∫Zq(Z)logq(Z)P(X,Z)dZ)+(−∫Zq(Z)logq(Z)P(Z∣X)dZ)
 其中前半部分被叫做ELBO(Evidence Lower Bound),后半部分是KL公式,所以可以简化写成:
  
      
       
        
        
          log 
         
        
           
         
        
          P 
         
        
          ( 
         
        
          X 
         
        
          ) 
         
        
          = 
         
        
          L 
         
        
          ( 
         
        
          q 
         
        
          ) 
         
        
          + 
         
        
          K 
         
        
          L 
         
        
          ( 
         
        
          q 
         
        
          ∣ 
         
        
          ∣ 
         
        
          p 
         
        
          ) 
         
        
          , 
         
         
        
          L 
         
        
          ( 
         
        
          q 
         
        
          ) 
         
        
          = 
         
        
          E 
         
        
          L 
         
        
          B 
         
        
          O 
         
        
       
         \log P(X) = {\mathcal L}(q) + KL(q||p), \quad {\mathcal L}(q) = ELBO 
        
       
     logP(X)=L(q)+KL(q∣∣p),L(q)=ELBO
 其中的ELBO也就是EM算法中的变分。
变分推断的一个具体作用就是在EM算法中,通过近似推断求解出 q ( z ) q(z) q(z)的分布。
若要使变分最大,自然是: q ^ = a r g max  L ( q ) ⟹ q ^ ≈ P ( Z ∣ X ) {\hat q} = arg\max {\mathcal L}(q) \implies {\hat q} \approx P(Z|X) q^=argmaxL(q)⟹q^≈P(Z∣X),但在EM算法章节中我们也说了,由于 q ^ = P ( Z ∣ X ) {\hat q} = P(Z|X) q^=P(Z∣X)实际上大多情况都难以求解,所以需要通过别的办法实现。
而变分推断使用了Mean Theory: q ( Z ) = ∏ i = 1 M q i ( Z i ) q(Z) = \prod_{i=1}^M q_i(Z_i) q(Z)=∏i=1Mqi(Zi)。其中 M M M表示 q ( Z ) q(Z) q(Z)被切分成了 M M M个维度,其中每个维度表示为 q i ( Z i ) q_i(Z_i) qi(Zi)。这样通过固定 i = 1 , … , j − 1 , j + 1 , … , M i = 1, \dots, j-1, j+1, \dots, M i=1,…,j−1,j+1,…,M的项求出 q j ( Z j ) q_j(Z_j) qj(Zj),由Mean Theory定理的公式就可以求出 q ( Z ) q(Z) q(Z)
所以下面我们来分析变分 
     
      
       
       
         L 
        
       
         ( 
        
       
         q 
        
       
         ) 
        
       
      
        {\mathcal L}(q) 
       
      
    L(q):
  
      
       
        
        
          L 
         
        
          ( 
         
        
          q 
         
        
          ) 
         
        
          = 
         
         
         
           ∫ 
          
         
           Z 
          
         
        
          q 
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          log 
         
        
           
         
         
         
           P 
          
         
           ( 
          
         
           X 
          
         
           , 
          
         
           Z 
          
         
           ) 
          
         
        
          d 
         
        
          Z 
         
        
          − 
         
         
         
           ∫ 
          
         
           Z 
          
         
        
          q 
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          log 
         
        
           
         
         
         
           q 
          
         
           ( 
          
         
           Z 
          
         
           ) 
          
         
        
          d 
         
        
          Z 
         
        
       
         {\mathcal L}(q) = \int_Z q(Z) \log {P(X, Z)} {\rm d}Z - \int_Z q(Z) \log {q(Z)} {\rm d}Z 
        
       
     L(q)=∫Zq(Z)logP(X,Z)dZ−∫Zq(Z)logq(Z)dZ
 若我们将Mean Theory代入左式:
-  可得公式: 
 l e f t = ∫ Z ∏ i = 1 M q i ( Z i ) ⋅ log  P ( X , Z ) d Z = ∫ Z j q j ( Z j ) ⋅ [ ∫ Z − Z j ∏ i ≠ j M q i ( Z i ) log  P ( X , Z ) d Z − Z j ] d Z j = ∫ Z j q j ( Z j ) ⋅ E ∏ i ≠ j M q i ( Z i ) [ log  P ( X , Z ) ] d Z j \begin{align} left &= \int_Z \prod_{i=1}^M q_i(Z_i) \cdot \log P(X, Z) {\rm d}_Z \\ &= \int_{Z_j}q_j(Z_j) \cdot \left[ \int_{Z - Z_j} \prod_{i \neq j}^M q_i(Z_i) \log P(X, Z) {\rm d}_{Z-Z_j} \right] {\rm d}_{Z_j} \\ &= \int_{Z_j}q_j(Z_j) \cdot E_{\prod_{i \neq j}^M q_i(Z_i)} \left[ \log P(X, Z) \right] {\rm d}_{Z_j} \\ \end{align} left=∫Zi=1∏Mqi(Zi)⋅logP(X,Z)dZ=∫Zjqj(Zj)⋅ ∫Z−Zji=j∏Mqi(Zi)logP(X,Z)dZ−Zj dZj=∫Zjqj(Zj)⋅E∏i=jMqi(Zi)[logP(X,Z)]dZj
-  此时我们强行将 E ∏ i ≠ j M q i ( Z i ) [ log  P ( X , Z ) ] E_{\prod_{i \neq j}^M q_i(Z_i)} \left[ \log P(X, Z) \right] E∏i=jMqi(Zi)[logP(X,Z)]定义为 log  P ^ ( X , Z j ) \log {\hat P}(X, Z_j) logP^(X,Zj),就能得到: 
 l e f t = ∫ Z j q j ( Z j ) ⋅ log  P ^ ( X , Z j ) d Z j \begin{align} left = \int_{Z_j}q_j(Z_j) \cdot \log {\hat P}(X, Z_j) {\rm d}_{Z_j} \\ \end{align} left=∫Zjqj(Zj)⋅logP^(X,Zj)dZj
若将Mean Theory代入右式:
-  可得公式: 
 r i g h t = ∫ Z ∏ i = 1 M q i ( Z i ) ⋅ ∑ k = 1 M log  q k ( Z k ) d Z = ∫ Z ∏ i = 1 M q i ( Z i ) ⋅ [ log  q 1 ( Z 1 ) + log  q 2 ( Z 2 ) + ⋯ + log  q M ( Z M ) ] d Z \begin{align} right &= \int_Z \prod_{i=1}^M q_i(Z_i) \cdot \sum_{k=1}^M \log q_k(Z_k) {\rm d}_Z \\ &= \int_Z \prod_{i=1}^M q_i(Z_i) \cdot [ \log q_1(Z_1) + \log q_2(Z_2) + \dots + \log q_M(Z_M) ] {\rm d}_Z \\ \end{align} right=∫Zi=1∏Mqi(Zi)⋅k=1∑Mlogqk(Zk)dZ=∫Zi=1∏Mqi(Zi)⋅[logq1(Z1)+logq2(Z2)+⋯+logqM(ZM)]dZ
-  其中出了第 j j j项,我们都固定了,可以视为常数。将第 j j j项提出来可以得到: 
 j − t h = ∫ Z ∏ i = 1 M q i ( Z i ) ⋅ log  q j ( Z j ) d Z = ∫ Z 1 q 1 ( Z 1 ) d Z 1  ⋯ ∫ Z j q j ( Z j ) ⋅ log  q j ( Z j ) d Z j  ⋯ ∫ Z M q M ( Z M ) d Z M = ∫ Z j q j ( Z j ) ⋅ log  q j ( Z j ) d Z j \begin{align} j-th &= \int_Z \prod_{i=1}^M q_i(Z_i) \cdot \log q_j(Z_j) {\rm d}_Z \\ &= \int_{Z_1} q_1(Z_1) {\rm d}_{Z_1} \dots \int_{Z_j} q_j(Z_j) \cdot \log q_j(Z_j) {\rm d}_{Z_j} \dots \int_{Z_M} q_M(Z_M) {\rm d}_{Z_M} \\ &= \int_{Z_j} q_j(Z_j) \cdot \log q_j(Z_j) {\rm d}_{Z_j} \end{align} j−th=∫Zi=1∏Mqi(Zi)⋅logqj(Zj)dZ=∫Z1q1(Z1)dZ1⋯∫Zjqj(Zj)⋅logqj(Zj)dZj⋯∫ZMqM(ZM)dZM=∫Zjqj(Zj)⋅logqj(Zj)dZj
-  所以可得: 
 r i g h t = ∫ Z j q j ( Z j ) ⋅ log  q j ( Z j ) d Z j + C \begin{align} right = \int_{Z_j} q_j(Z_j) \cdot \log q_j(Z_j) {\rm d}_{Z_j} + C \end{align} right=∫Zjqj(Zj)⋅logqj(Zj)dZj+C
综合一下上述的公式可得:
  
      
       
        
         
          
          
           
            
            
              L 
             
            
              ( 
             
            
              q 
             
            
              ) 
             
            
           
          
          
           
            
             
            
              = 
             
            
              l 
             
            
              e 
             
            
              f 
             
            
              t 
             
            
              − 
             
            
              r 
             
            
              i 
             
            
              g 
             
            
              h 
             
            
              t 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∫ 
              
              
              
                Z 
               
              
                j 
               
              
             
             
             
               q 
              
             
               j 
              
             
            
              ( 
             
             
             
               Z 
              
             
               j 
              
             
            
              ) 
             
            
              ⋅ 
             
            
              log 
             
            
               
             
             
             
               P 
              
             
               ^ 
              
             
            
              ( 
             
            
              X 
             
            
              , 
             
             
             
               Z 
              
             
               j 
              
             
            
              ) 
             
             
             
               d 
              
              
              
                Z 
               
              
                j 
               
              
             
            
              − 
             
             
             
               ∫ 
              
              
              
                Z 
               
              
                j 
               
              
             
             
             
               q 
              
             
               j 
              
             
            
              ( 
             
             
             
               Z 
              
             
               j 
              
             
            
              ) 
             
            
              ⋅ 
             
            
              log 
             
            
               
             
             
             
               q 
              
             
               j 
              
             
            
              ( 
             
             
             
               Z 
              
             
               j 
              
             
            
              ) 
             
             
             
               d 
              
              
              
                Z 
               
              
                j 
               
              
             
            
              − 
             
            
              C 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∫ 
              
              
              
                Z 
               
              
                j 
               
              
             
             
             
               q 
              
             
               j 
              
             
            
              ( 
             
             
             
               Z 
              
             
               j 
              
             
            
              ) 
             
            
              ⋅ 
             
            
              log 
             
            
               
             
             
              
               
               
                 P 
                
               
                 ^ 
                
               
              
                ( 
               
              
                X 
               
              
                , 
               
               
               
                 Z 
                
               
                 j 
                
               
              
                ) 
               
              
              
               
               
                 q 
                
               
                 j 
                
               
              
                ( 
               
               
               
                 Z 
                
               
                 j 
                
               
              
                ) 
               
              
             
             
             
               d 
              
              
              
                Z 
               
              
                j 
               
              
             
            
              − 
             
            
              C 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              − 
             
            
              K 
             
            
              L 
             
            
              ( 
             
             
              
              
                P 
               
              
                ^ 
               
              
             
               ( 
              
             
               X 
              
             
               , 
              
              
              
                Z 
               
              
                j 
               
              
             
               ) 
              
             
            
              ∥ 
             
             
              
              
                q 
               
              
                j 
               
              
             
               ( 
              
              
              
                Z 
               
              
                j 
               
              
             
               ) 
              
             
            
              ) 
             
             
             
               d 
              
              
              
                Z 
               
              
                j 
               
              
             
            
              − 
             
            
              C 
             
            
           
          
          
          
         
        
       
         \begin{align} {\mathcal L}(q) &= left - right \\ &= \int_{Z_j}q_j(Z_j) \cdot \log {\hat P}(X, Z_j) {\rm d}_{Z_j} - \int_{Z_j} q_j(Z_j) \cdot \log q_j(Z_j) {\rm d}_{Z_j} - C \\ &= \int_{Z_j}q_j(Z_j) \cdot \log \frac{{\hat P}(X, Z_j)}{q_j(Z_j)} {\rm d}_{Z_j} - C \\ &= -KL({{\hat P}(X, Z_j)} \Vert {q_j(Z_j)}) {\rm d}_{Z_j} - C \\ \end{align} 
        
       
     L(q)=left−right=∫Zjqj(Zj)⋅logP^(X,Zj)dZj−∫Zjqj(Zj)⋅logqj(Zj)dZj−C=∫Zjqj(Zj)⋅logqj(Zj)P^(X,Zj)dZj−C=−KL(P^(X,Zj)∥qj(Zj))dZj−C
 若要得到最大的$ {\mathcal L}(q) $,可得:
  
      
       
        
        
          { 
         
         
          
           
            
             
              
               
               
                 q 
                
               
                 j 
                
               
              
                ( 
               
               
               
                 Z 
                
               
                 j 
                
               
              
                ) 
               
              
             
               = 
              
              
               
               
                 P 
                
               
                 ^ 
                
               
              
                ( 
               
              
                X 
               
              
                , 
               
               
               
                 Z 
                
               
                 j 
                
               
              
                ) 
               
              
             
            
           
          
          
           
            
             
             
               log 
              
             
                
              
              
              
                P 
               
              
                ^ 
               
              
             
               ( 
              
             
               X 
              
             
               , 
              
              
              
                Z 
               
              
                j 
               
              
             
               ) 
              
             
               = 
              
              
              
                E 
               
               
                
                
                  ∏ 
                 
                 
                 
                   i 
                  
                 
                   ≠ 
                  
                 
                   j 
                  
                 
                
                  M 
                 
                
                
                
                  q 
                 
                
                  i 
                 
                
               
                 ( 
                
                
                
                  Z 
                 
                
                  i 
                 
                
               
                 ) 
                
               
              
              
              
                [ 
               
              
                log 
               
              
                 
               
              
                P 
               
              
                ( 
               
              
                X 
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
                ] 
               
              
             
            
           
          
         
        
       
         \begin{cases} {q_j(Z_j)} = {{\hat P}(X, Z_j)} \\ \log {\hat P}(X, Z_j) = E_{\prod_{i \neq j}^M q_i(Z_i)} \left[ \log P(X, Z) \right] \end{cases} 
        
       
     {qj(Zj)=P^(X,Zj)logP^(X,Zj)=E∏i=jMqi(Zi)[logP(X,Z)]
12.2.2 坐标上升法
至此,我们已经得到了 q j ( Z j ) q_j(Z_j) qj(Zj)的求解公式,接下来我们只要能求出 q 1 ( Z 1 ) , … , q M ( Z M ) q_1(Z_1), \dots, q_M(Z_M) q1(Z1),…,qM(ZM),就可以通过Mean Theory求解出 q ( Z ) q(Z) q(Z)了。
我们根据上面获得的变分最大条件进行分析:
  
      
       
        
        
          log 
         
        
           
         
         
          
          
            q 
           
          
            j 
           
          
         
           ( 
          
          
          
            Z 
           
          
            j 
           
          
         
           ) 
          
         
        
          = 
         
         
         
           E 
          
          
           
           
             ∏ 
            
            
            
              i 
             
            
              ≠ 
             
            
              j 
             
            
           
             M 
            
           
           
           
             q 
            
           
             i 
            
           
          
            ( 
           
           
           
             Z 
            
           
             i 
            
           
          
            ) 
           
          
         
         
         
           [ 
          
         
           log 
          
         
            
          
         
           P 
          
         
           ( 
          
         
           X 
          
         
           , 
          
         
           Z 
          
         
           ) 
          
         
           ] 
          
         
        
       
         \log {q_j(Z_j)} = E_{\prod_{i \neq j}^M q_i(Z_i)} \left[ \log P(X, Z) \right] 
        
       
     logqj(Zj)=E∏i=jMqi(Zi)[logP(X,Z)]
 我们将这个条件展开可以得到:
  
      
       
        
        
          log 
         
        
           
         
         
          
          
            q 
           
          
            j 
           
          
         
           ( 
          
          
          
            Z 
           
          
            j 
           
          
         
           ) 
          
         
        
          = 
         
         
         
           ∫ 
          
          
          
            q 
           
          
            1 
           
          
          
        
          ⋯ 
         
         
         
           ∫ 
          
          
          
            q 
           
           
           
             j 
            
           
             − 
            
           
             1 
            
           
          
         
         
         
           ∫ 
          
          
          
            q 
           
           
           
             j 
            
           
             + 
            
           
             1 
            
           
          
          
        
          ⋯ 
         
         
         
           ∫ 
          
          
          
            q 
           
          
            M 
           
          
         
         
         
           q 
          
         
           1 
          
         
        
          , 
         
        
          … 
         
        
          , 
         
         
         
           q 
          
          
          
            j 
           
          
            − 
           
          
            1 
           
          
         
        
          , 
         
         
         
           q 
          
          
          
            j 
           
          
            + 
           
          
            1 
           
          
         
        
          , 
         
        
          … 
         
        
          , 
         
         
         
           q 
          
         
           M 
          
         
        
          ⋅ 
         
        
          log 
         
        
           
         
        
          P 
         
        
          ( 
         
        
          X 
         
        
          , 
         
        
          Z 
         
        
          ) 
         
        
          d 
         
         
         
           q 
          
         
           1 
          
         
        
          … 
         
        
          d 
         
         
         
           q 
          
          
          
            j 
           
          
            − 
           
          
            1 
           
          
         
        
          d 
         
         
         
           q 
          
          
          
            j 
           
          
            + 
           
          
            1 
           
          
         
        
          … 
         
        
          d 
         
         
         
           q 
          
         
           M 
          
         
        
       
         \log {q_j(Z_j)} = \int_{q_1} \dots \int_{q_{j-1}} \int_{q_{j+1}} \dots \int_{q_M} q_1, \dots, q_{j-1}, q_{j+1}, \dots, q_M \cdot \log P(X, Z) {\rm d}{q_1} \dots {\rm d}{q_{j-1}} {\rm d}{q_{j+1}} \dots {\rm d}{q_M} 
        
       
     logqj(Zj)=∫q1⋯∫qj−1∫qj+1⋯∫qMq1,…,qj−1,qj+1,…,qM⋅logP(X,Z)dq1…dqj−1dqj+1…dqM
 已知该公式,我们可以采用坐标上升法迭代求解 
     
      
       
        
         
         
           q 
          
         
           1 
          
         
        
          ( 
         
         
         
           Z 
          
         
           1 
          
         
        
          ) 
         
        
       
         , 
        
       
         … 
        
       
         , 
        
        
         
         
           q 
          
         
           M 
          
         
        
          ( 
         
         
         
           Z 
          
         
           M 
          
         
        
          ) 
         
        
       
      
        {q_1(Z_1)}, \dots, {q_M(Z_M)} 
       
      
    q1(Z1),…,qM(ZM):
  
      
       
        
        
          { 
         
         
          
           
            
             
             
               log 
              
             
                
              
              
               
                
                
                  q 
                 
                
                  1 
                 
                
               
                 ^ 
                
               
              
                ( 
               
               
               
                 Z 
                
               
                 1 
                
               
              
                ) 
               
              
             
               = 
              
              
              
                ∫ 
               
               
               
                 q 
                
               
                 2 
                
               
               
             
               ⋯ 
              
              
              
                ∫ 
               
               
               
                 q 
                
               
                 M 
                
               
              
              
              
                q 
               
              
                2 
               
              
             
               , 
              
             
               … 
              
             
               , 
              
              
              
                q 
               
              
                M 
               
              
             
               ⋅ 
              
             
               log 
              
             
                
              
             
               P 
              
             
               ( 
              
             
               X 
              
             
               , 
              
             
               Z 
              
             
               ) 
              
             
               d 
              
              
              
                q 
               
              
                2 
               
              
             
               … 
              
             
               d 
              
              
              
                q 
               
              
                M 
               
              
             
            
           
          
          
           
            
             
             
               log 
              
             
                
              
              
               
                
                
                  q 
                 
                
                  2 
                 
                
               
                 ^ 
                
               
              
                ( 
               
               
               
                 Z 
                
               
                 2 
                
               
              
                ) 
               
              
             
               = 
              
              
              
                ∫ 
               
               
                
                
                  q 
                 
                
                  1 
                 
                
               
                 ^ 
                
               
              
              
              
                ∫ 
               
               
               
                 q 
                
               
                 3 
                
               
               
             
               ⋯ 
              
              
              
                ∫ 
               
               
               
                 q 
                
               
                 M 
                
               
              
              
               
               
                 q 
                
               
                 1 
                
               
              
                ^ 
               
              
             
               , 
              
              
              
                q 
               
              
                3 
               
              
             
               , 
              
             
               … 
              
             
               , 
              
              
              
                q 
               
              
                M 
               
              
             
               ⋅ 
              
             
               log 
              
             
                
              
             
               P 
              
             
               ( 
              
             
               X 
              
             
               , 
              
             
               Z 
              
             
               ) 
              
             
               d 
              
              
               
               
                 q 
                
               
                 1 
                
               
              
                ^ 
               
              
             
               d 
              
              
              
                q 
               
              
                3 
               
              
             
               … 
              
             
               d 
              
              
              
                q 
               
              
                M 
               
              
             
            
           
          
          
           
            
            
              … 
             
            
           
          
          
           
            
             
             
               log 
              
             
                
              
              
               
                
                
                  q 
                 
                
                  M 
                 
                
               
                 ^ 
                
               
              
                ( 
               
               
               
                 Z 
                
               
                 M 
                
               
              
                ) 
               
              
             
               = 
              
              
              
                ∫ 
               
               
                
                
                  q 
                 
                
                  1 
                 
                
               
                 ^ 
                
               
               
             
               ⋯ 
              
              
              
                ∫ 
               
               
                
                
                  q 
                 
                 
                 
                   M 
                  
                 
                   − 
                  
                 
                   1 
                  
                 
                
               
                 ^ 
                
               
              
              
               
               
                 q 
                
               
                 1 
                
               
              
                ^ 
               
              
             
               , 
              
             
               … 
              
             
               , 
              
              
               
               
                 q 
                
                
                
                  M 
                 
                
                  − 
                 
                
                  1 
                 
                
               
              
                ^ 
               
              
             
               ⋅ 
              
             
               log 
              
             
                
              
             
               P 
              
             
               ( 
              
             
               X 
              
             
               , 
              
             
               Z 
              
             
               ) 
              
             
               d 
              
              
               
               
                 q 
                
               
                 1 
                
               
              
                ^ 
               
              
             
               … 
              
             
               d 
              
              
               
               
                 q 
                
                
                
                  M 
                 
                
                  − 
                 
                
                  1 
                 
                
               
              
                ^ 
               
              
             
            
           
          
         
        
       
         \begin{cases} \log {\hat {q_1}(Z_1)} = \int_{q_2} \dots \int_{q_M} q_2, \dots, q_M \cdot \log P(X, Z) {\rm d}{q_2} \dots {\rm d}{q_M} \\ \log {\hat {q_2}(Z_2)} = \int_{\hat {q_1}} \int_{q_3} \dots \int_{q_M} {\hat {q_1}}, q_3, \dots, q_M \cdot \log P(X, Z) {\rm d}{\hat {q_1}} {\rm d}{q_3} \dots {\rm d}{q_M} \\ \dots \\ \log {\hat {q_M}(Z_M)} = \int_{\hat {q_1}} \dots \int_{\hat {q_{M-1}}} {\hat {q_1}}, \dots, {\hat {q_{M-1}}} \cdot \log P(X, Z) {\rm d}{\hat {q_1}} \dots {\rm d}{\hat {q_{M-1}}} \\ \end{cases} 
        
       
     ⎩ 
              ⎨ 
              ⎧logq1^(Z1)=∫q2⋯∫qMq2,…,qM⋅logP(X,Z)dq2…dqMlogq2^(Z2)=∫q1^∫q3⋯∫qMq1^,q3,…,qM⋅logP(X,Z)dq1^dq3…dqM…logqM^(ZM)=∫q1^⋯∫qM−1^q1^,…,qM−1^⋅logP(X,Z)dq1^…dqM−1^
 并且可以通过循环多次的迭代增加精度。
但Classical VI有缺点:
- Mean Fied的条件太强,很多模型不满足
- 若维度太高,会变成高维积分导致无法求解
12.3 SGVI——随机梯度变分推断
12.3.1 一般化MC方法
倘若我们使用 
     
      
       
       
         φ 
        
       
      
        \varphi 
       
      
    φ表示 
     
      
       
       
         q 
        
       
         ( 
        
       
         Z 
        
       
         ) 
        
       
      
        q(Z) 
       
      
    q(Z)的参数,同时下文普遍将 
     
      
       
       
         q 
        
       
         ( 
        
       
         Z 
        
       
         ) 
        
       
      
        q(Z) 
       
      
    q(Z)缩写为 
     
      
       
        
        
          q 
         
        
          φ 
         
        
       
      
        q_\varphi 
       
      
    qφ,所以我们可以将公式写为(ELBO对于 
     
      
       
        
        
          P 
         
        
          θ 
         
        
       
         ( 
        
       
         X 
        
       
         , 
        
       
         Z 
        
       
         ) 
        
       
      
        P_\theta(X, Z) 
       
      
    Pθ(X,Z)和 
     
      
       
        
        
          P 
         
        
          θ 
         
        
       
         ( 
        
        
        
          x 
         
        
          i 
         
        
       
         , 
        
       
         Z 
        
       
         ) 
        
       
      
        P_\theta(x_i, Z) 
       
      
    Pθ(xi,Z)都成立):
  
      
       
        
        
          L 
         
        
          ( 
         
        
          φ 
         
        
          ) 
         
        
          = 
         
        
          E 
         
        
          L 
         
        
          B 
         
        
          O 
         
        
          = 
         
         
         
           E 
          
          
          
            q 
           
          
            φ 
           
          
         
         
         
           [ 
          
         
           log 
          
         
            
          
          
           
            
            
              P 
             
            
              θ 
             
            
           
             ( 
            
            
            
              x 
             
            
              i 
             
            
           
             , 
            
           
             Z 
            
           
             ) 
            
           
           
           
             q 
            
           
             φ 
            
           
          
         
           ] 
          
         
        
          = 
         
         
         
           E 
          
          
          
            q 
           
          
            φ 
           
          
         
         
         
           [ 
          
         
           log 
          
         
            
          
          
           
           
             P 
            
           
             θ 
            
           
          
            ( 
           
           
           
             x 
            
           
             i 
            
           
          
            , 
           
          
            Z 
           
          
            ) 
           
          
         
           − 
          
         
           log 
          
         
            
          
          
          
            q 
           
          
            φ 
           
          
         
           ] 
          
         
        
       
         {\mathcal L}(\varphi) = ELBO = E_{q_\varphi } \left[ \log \frac{P_\theta(x_i, Z)}{q_\varphi} \right] = E_{q_\varphi } \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] 
        
       
     L(φ)=ELBO=Eqφ[logqφPθ(xi,Z)]=Eqφ[logPθ(xi,Z)−logqφ]
 我们当前的目标是求解 
     
      
       
        
        
          φ 
         
        
          ^ 
         
        
       
         = 
        
       
         a 
        
       
         r 
        
       
         g 
        
        
         
         
           max 
          
         
            
          
         
        
          φ 
         
        
       
         L 
        
       
         ( 
        
       
         φ 
        
       
         ) 
        
       
      
        {\hat \varphi} = arg\max_\varphi {\mathcal L}(\varphi) 
       
      
    φ^=argmaxφL(φ),为了求解,我们打算采用梯度上升法,而要想使用梯度上升法,就必须求解得到梯度方向 
     
      
       
        
        
          ∇ 
         
        
          φ 
         
        
       
         L 
        
       
         ( 
        
       
         φ 
        
       
         ) 
        
       
      
        \nabla_\varphi {\mathcal L}(\varphi) 
       
      
    ∇φL(φ):
  
      
       
        
         
         
           ∇ 
          
         
           φ 
          
         
        
          L 
         
        
          ( 
         
        
          φ 
         
        
          ) 
         
        
          = 
         
         
         
           ∇ 
          
         
           φ 
          
         
         
         
           E 
          
          
          
            q 
           
          
            φ 
           
          
         
         
         
           [ 
          
         
           log 
          
         
            
          
          
           
           
             P 
            
           
             θ 
            
           
          
            ( 
           
           
           
             x 
            
           
             i 
            
           
          
            , 
           
          
            Z 
           
          
            ) 
           
          
         
           − 
          
         
           log 
          
         
            
          
          
          
            q 
           
          
            φ 
           
          
         
           ] 
          
         
        
          = 
         
         
         
           ∇ 
          
         
           φ 
          
         
         
         
           ∫ 
          
         
           Z 
          
         
         
         
           q 
          
         
           φ 
          
         
        
          ⋅ 
         
         
         
           [ 
          
         
           log 
          
         
            
          
          
           
           
             P 
            
           
             θ 
            
           
          
            ( 
           
           
           
             x 
            
           
             i 
            
           
          
            , 
           
          
            Z 
           
          
            ) 
           
          
         
           − 
          
         
           log 
          
         
            
          
          
          
            q 
           
          
            φ 
           
          
         
           ] 
          
         
         
         
           d 
          
         
           Z 
          
         
        
       
         \nabla_\varphi {\mathcal L}(\varphi) = \nabla_\varphi E_{q_\varphi } \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] = \nabla_\varphi \int_Z q_\varphi \cdot \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] {\rm d}_Z 
        
       
     ∇φL(φ)=∇φEqφ[logPθ(xi,Z)−logqφ]=∇φ∫Zqφ⋅[logPθ(xi,Z)−logqφ]dZ
引入梯度变换公式:
∇ x ∫ z A ( x , z ) ⋅ B ( x , z ) d z = ∫ z ∇ x A ( x , z ) ⋅ B ( x , z ) d z + ∫ z A ( x , z ) ⋅ ∇ x B ( x , z ) d z \nabla_x \int_z A(x, z) \cdot B(x, z) {\rm d}z = \int_z \nabla_x A(x, z) \cdot B(x, z) {\rm d}z + \int_z A(x, z) \cdot \nabla_x B(x, z) {\rm d}z ∇x∫zA(x,z)⋅B(x,z)dz=∫z∇xA(x,z)⋅B(x,z)dz+∫zA(x,z)⋅∇xB(x,z)dz
可得:
  
      
       
        
         
         
           ∇ 
          
         
           φ 
          
         
        
          L 
         
        
          ( 
         
        
          φ 
         
        
          ) 
         
        
          = 
         
         
         
           ∫ 
          
         
           Z 
          
         
         
         
           ∇ 
          
         
           φ 
          
         
         
         
           q 
          
         
           φ 
          
         
        
          ⋅ 
         
         
         
           [ 
          
         
           log 
          
         
            
          
          
           
           
             P 
            
           
             θ 
            
           
          
            ( 
           
           
           
             x 
            
           
             i 
            
           
          
            , 
           
          
            Z 
           
          
            ) 
           
          
         
           − 
          
         
           log 
          
         
            
          
          
          
            q 
           
          
            φ 
           
          
         
           ] 
          
         
         
         
           d 
          
         
           Z 
          
         
        
          + 
         
         
         
           ∫ 
          
         
           Z 
          
         
         
         
           q 
          
         
           φ 
          
         
        
          ⋅ 
         
         
         
           ∇ 
          
         
           φ 
          
         
         
         
           [ 
          
         
           log 
          
         
            
          
          
           
           
             P 
            
           
             θ 
            
           
          
            ( 
           
           
           
             x 
            
           
             i 
            
           
          
            , 
           
          
            Z 
           
          
            ) 
           
          
         
           − 
          
         
           log 
          
         
            
          
          
          
            q 
           
          
            φ 
           
          
         
           ] 
          
         
         
         
           d 
          
         
           Z 
          
         
        
       
         \nabla_\varphi {\mathcal L}(\varphi) = \int_Z \nabla_\varphi q_\varphi \cdot \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] {\rm d}_Z + \int_Z q_\varphi \cdot \nabla_\varphi \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] {\rm d}_Z 
        
       
     ∇φL(φ)=∫Z∇φqφ⋅[logPθ(xi,Z)−logqφ]dZ+∫Zqφ⋅∇φ[logPθ(xi,Z)−logqφ]dZ
 这里主要看一下右边的公式:
  
      
       
        
         
          
          
           
            
            
              r 
             
            
              i 
             
            
              g 
             
            
              h 
             
            
              t 
             
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∫ 
              
             
               Z 
              
             
             
             
               q 
              
             
               φ 
              
             
            
              ⋅ 
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               [ 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ] 
              
             
             
             
               d 
              
             
               Z 
              
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∫ 
              
             
               Z 
              
             
             
             
               q 
              
             
               φ 
              
             
            
              ⋅ 
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               [ 
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ] 
              
             
             
             
               d 
              
             
               Z 
              
             
            
           
          
          
           
            
            
              —— 
             
            
              log 
             
            
               
             
             
              
              
                P 
               
              
                θ 
               
              
             
               ( 
              
              
              
                x 
               
              
                i 
               
              
             
               , 
              
             
               Z 
              
             
               ) 
              
             
            
              与 
             
            
              φ 
             
            
              无关 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              − 
             
             
             
               ∫ 
              
             
               Z 
              
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               q 
              
             
               φ 
              
             
             
             
               d 
              
             
               Z 
              
             
            
           
          
          
           
            
            
              —— 
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               [ 
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ] 
              
             
            
              = 
             
            
              − 
             
             
             
               1 
              
              
              
                q 
               
              
                φ 
               
              
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               q 
              
             
               φ 
              
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              − 
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               ∫ 
              
             
               Z 
              
             
             
             
               q 
              
             
               φ 
              
             
             
             
               d 
              
             
               Z 
              
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              − 
             
             
             
               ∇ 
              
             
               φ 
              
             
            
              1 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
            
              0 
             
            
           
          
          
          
         
        
       
         \begin{align} right &= \int_Z q_\varphi \cdot \nabla_\varphi \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] {\rm d}_Z \\ &= \int_Z q_\varphi \cdot \nabla_\varphi \left[- \log {q_\varphi} \right] {\rm d}_Z & ——\log {P_\theta(x_i, Z)}与\varphi无关 \\ &= - \int_Z \nabla_\varphi q_\varphi {\rm d}_Z & ——\nabla_\varphi \left[- \log {q_\varphi} \right]=-\frac{1}{q_\varphi}\nabla_\varphi q_\varphi \\ &= - \nabla_\varphi \int_Z q_\varphi {\rm d}_Z \\ &= - \nabla_\varphi 1 \\ &= 0 \end{align} 
        
       
     right=∫Zqφ⋅∇φ[logPθ(xi,Z)−logqφ]dZ=∫Zqφ⋅∇φ[−logqφ]dZ=−∫Z∇φqφdZ=−∇φ∫ZqφdZ=−∇φ1=0——logPθ(xi,Z)与φ无关——∇φ[−logqφ]=−qφ1∇φqφ
 所以可以将公式继续写为:
  
      
       
        
         
          
          
           
            
             
             
               ∇ 
              
             
               φ 
              
             
            
              L 
             
            
              ( 
             
            
              φ 
             
            
              ) 
             
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∫ 
              
             
               Z 
              
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               q 
              
             
               φ 
              
             
            
              ⋅ 
             
             
             
               [ 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ] 
              
             
             
             
               d 
              
             
               Z 
              
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∫ 
              
             
               Z 
              
             
             
             
               q 
              
             
               φ 
              
             
             
             
               ∇ 
              
             
               φ 
              
             
            
              log 
             
            
               
             
             
             
               q 
              
             
               φ 
              
             
            
              ⋅ 
             
             
             
               [ 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ] 
              
             
             
             
               d 
              
             
               Z 
              
             
            
           
          
          
           
            
            
              —— 
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               q 
              
             
               φ 
              
             
            
              = 
             
             
             
               q 
              
             
               φ 
              
             
             
             
               ∇ 
              
             
               φ 
              
             
            
              log 
             
            
               
             
             
             
               q 
              
             
               φ 
              
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               E 
              
              
              
                q 
               
              
                φ 
               
              
             
             
             
               [ 
              
              
              
                ∇ 
               
              
                φ 
               
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ⋅ 
              
              
              
                ( 
               
              
                log 
               
              
                 
               
               
                
                
                  P 
                 
                
                  θ 
                 
                
               
                 ( 
                
                
                
                  x 
                 
                
                  i 
                 
                
               
                 , 
                
               
                 Z 
                
               
                 ) 
                
               
              
                − 
               
              
                log 
               
              
                 
               
               
               
                 q 
                
               
                 φ 
                
               
              
                ) 
               
              
             
               ] 
              
             
            
           
          
          
          
         
        
       
         \begin{align} \nabla_\varphi {\mathcal L}(\varphi) &= \int_Z \nabla_\varphi q_\varphi \cdot \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] {\rm d}_Z \\ &= \int_Z q_\varphi \nabla_\varphi \log{q_\varphi} \cdot \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] {\rm d}_Z & ——\nabla_\varphi q_\varphi = q_\varphi \nabla_\varphi \log{q_\varphi} \\ &= E_{q_\varphi} \left[ \nabla_\varphi \log{q_\varphi} \cdot \left( \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right) \right] \\ \end{align} 
        
       
     ∇φL(φ)=∫Z∇φqφ⋅[logPθ(xi,Z)−logqφ]dZ=∫Zqφ∇φlogqφ⋅[logPθ(xi,Z)−logqφ]dZ=Eqφ[∇φlogqφ⋅(logPθ(xi,Z)−logqφ)]——∇φqφ=qφ∇φlogqφ
至此我们已经得到了公式:
  
      
       
        
         
          
          
           
            
             
             
               ∇ 
              
             
               φ 
              
             
            
              L 
             
            
              ( 
             
            
              φ 
             
            
              ) 
             
            
              = 
             
             
             
               E 
              
              
              
                q 
               
              
                φ 
               
              
             
             
             
               [ 
              
              
              
                ∇ 
               
              
                φ 
               
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ⋅ 
              
              
              
                ( 
               
              
                log 
               
              
                 
               
               
                
                
                  P 
                 
                
                  θ 
                 
                
               
                 ( 
                
                
                
                  x 
                 
                
                  i 
                 
                
               
                 , 
                
               
                 Z 
                
               
                 ) 
                
               
              
                − 
               
              
                log 
               
              
                 
               
               
               
                 q 
                
               
                 φ 
                
               
              
                ) 
               
              
             
               ] 
              
             
            
           
          
          
          
         
        
       
         \begin{align} \nabla_\varphi {\mathcal L}(\varphi) = E_{q_\varphi} \left[ \nabla_\varphi \log{q_\varphi} \cdot \left( \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right) \right] \\ \end{align} 
        
       
     ∇φL(φ)=Eqφ[∇φlogqφ⋅(logPθ(xi,Z)−logqφ)]
 通过该公式,我们就可以通过Monte Carlo方法进行采样估算:
  
      
       
        
         
         
           Z 
          
          
          
            ( 
           
          
            l 
           
          
            ) 
           
          
         
        
          ∽ 
         
         
         
           q 
          
         
           φ 
          
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          , 
         
        
          l 
         
        
          ∈ 
         
        
          [ 
         
        
          1 
         
        
          , 
         
        
          L 
         
        
          ] 
           
        
          ⟹ 
           
         
         
           ∇ 
          
         
           φ 
          
         
        
          L 
         
        
          ( 
         
        
          φ 
         
        
          ) 
         
        
          ≈ 
         
         
         
           1 
          
         
           L 
          
         
         
         
           ∑ 
          
          
          
            l 
           
          
            = 
           
          
            1 
           
          
         
           L 
          
         
         
         
           [ 
          
          
          
            ∇ 
           
          
            φ 
           
          
         
           log 
          
         
            
          
          
          
            q 
           
          
            φ 
           
          
         
           ( 
          
          
          
            Z 
           
           
           
             ( 
            
           
             l 
            
           
             ) 
            
           
          
         
           ) 
          
         
           ⋅ 
          
          
          
            ( 
           
          
            log 
           
          
             
           
           
            
            
              P 
             
            
              θ 
             
            
           
             ( 
            
            
            
              x 
             
            
              i 
             
            
           
             , 
            
            
            
              Z 
             
             
             
               ( 
              
             
               l 
              
             
               ) 
              
             
            
           
             ) 
            
           
          
            − 
           
          
            log 
           
          
             
           
           
           
             q 
            
           
             φ 
            
           
          
            ( 
           
           
           
             Z 
            
            
            
              ( 
             
            
              l 
             
            
              ) 
             
            
           
          
            ) 
           
          
            ) 
           
          
         
           ] 
          
         
        
       
         Z^{(l)} \backsim q_\varphi(Z), l \in [1, L] \implies \nabla_\varphi {\mathcal L}(\varphi) \approx \frac{1}{L} \sum_{l=1}^L \left[ \nabla_\varphi \log{q_\varphi}(Z^{(l)}) \cdot \left( \log {P_\theta(x_i, Z^{(l)})} - \log {q_\varphi}(Z^{(l)}) \right) \right] 
        
       
     Z(l)∽qφ(Z),l∈[1,L]⟹∇φL(φ)≈L1l=1∑L[∇φlogqφ(Z(l))⋅(logPθ(xi,Z(l))−logqφ(Z(l)))]
 但这个采样方法无法使用,因为 
     
      
       
        
        
          ∇ 
         
        
          φ 
         
        
       
         log 
        
       
          
        
        
        
          q 
         
        
          φ 
         
        
       
         ( 
        
        
        
          Z 
         
         
         
           ( 
          
         
           l 
          
         
           ) 
          
         
        
       
         ) 
        
       
      
        \nabla_\varphi \log{q_\varphi}(Z^{(l)}) 
       
      
    ∇φlogqφ(Z(l))在 
     
      
       
       
         ( 
        
       
         0 
        
       
         , 
        
       
         1 
        
       
         ) 
        
       
      
        (0, 1) 
       
      
    (0,1)的区间内波动太大( 
     
      
       
       
         log 
        
       
          
        
       
         在 
        
       
         ( 
        
       
         0 
        
       
         , 
        
       
         1 
        
       
         ) 
        
       
      
        \log在(0,1) 
       
      
    log在(0,1)中的取值范围太大),导致单次样本解的方差太大。若要解决就要增加采样的数据了,但这又太浪费时间,不满足现实应用。
12.3.2 降方差——Variance Reduction
为了降低方差,这里要用到重新参数化技巧(Reparametrization Trick):通过对随机化参数的重构,降低当前公式的求解方差。
由于当前的参数是 
     
      
       
       
         Z 
        
       
         ∽ 
        
        
        
          q 
         
        
          φ 
         
        
       
         ( 
        
       
         Z 
        
       
         ∣ 
        
        
        
          x 
         
        
          i 
         
        
       
         ) 
        
       
      
        Z \backsim q_\varphi(Z|x_i) 
       
      
    Z∽qφ(Z∣xi),为了将参数转换为方差没有那么大的参数,我们假设:
  
      
       
        
        
          Z 
         
        
          ∽ 
         
         
         
           g 
          
         
           φ 
          
         
        
          ( 
         
        
          ε 
         
        
          ∣ 
         
         
         
           x 
          
         
           i 
          
         
        
          ) 
         
        
          , 
         
        
          ε 
         
        
          ∽ 
         
        
          p 
         
        
          ( 
         
        
          ε 
         
        
          ) 
         
        
       
         Z \backsim g_\varphi(\varepsilon|x_i), \varepsilon \backsim p(\varepsilon) 
        
       
     Z∽gφ(ε∣xi),ε∽p(ε)
 通过上述方法,将Z随机样本的身份给了 
     
      
       
       
         ε 
        
       
      
        \varepsilon 
       
      
    ε,这样可以通过创建 
     
      
       
        
        
          ε 
         
         
         
           ( 
          
         
           l 
          
         
           ) 
          
         
        
       
      
        \varepsilon^{(l)} 
       
      
    ε(l)求出 
     
      
       
       
         Z 
        
       
      
        Z 
       
      
    Z,所以现在我们已知新旧的两个分布: 
     
      
       
       
         { 
        
        
         
          
           
            
            
              z 
             
            
              ∽ 
             
             
             
               q 
              
             
               φ 
              
             
            
              ( 
             
            
              Z 
             
            
              ∣ 
             
             
             
               x 
              
             
               i 
              
             
            
              ) 
             
            
           
          
         
         
          
           
            
            
              ε 
             
            
              ∽ 
             
            
              p 
             
            
              ( 
             
            
              ε 
             
            
              ) 
             
            
           
          
         
        
       
      
        \begin{cases} z \backsim q_\varphi(Z|x_i) \\ \varepsilon \backsim p(\varepsilon) \end{cases} 
       
      
    {z∽qφ(Z∣xi)ε∽p(ε),这两个分布是通过 
     
      
       
        
        
          g 
         
        
          φ 
         
        
       
      
        g_\varphi 
       
      
    gφ转换,可以得到 
     
      
       
       
         ∣ 
        
        
        
          q 
         
        
          φ 
         
        
       
         ( 
        
       
         Z 
        
       
         ∣ 
        
        
        
          x 
         
        
          i 
         
        
       
         ) 
        
       
         d 
        
       
         z 
        
       
         ∣ 
        
       
         = 
        
       
         ∣ 
        
       
         p 
        
       
         ( 
        
       
         ε 
        
       
         ) 
        
       
         d 
        
       
         ε 
        
       
         ∣ 
        
       
      
        |q_\varphi(Z|x_i) {\rm d}z| = |p(\varepsilon) {\rm d}\varepsilon| 
       
      
    ∣qφ(Z∣xi)dz∣=∣p(ε)dε∣(我也不知道为啥)。
所以我们可以得到以下推导:
  
      
       
        
         
          
          
           
            
             
             
               ∇ 
              
             
               φ 
              
             
            
              L 
             
            
              ( 
             
            
              φ 
             
            
              ) 
             
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               ∫ 
              
             
               Z 
              
             
             
             
               [ 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ] 
              
             
            
              ⋅ 
             
             
             
               q 
              
             
               φ 
              
             
             
             
               d 
              
             
               Z 
              
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               ∫ 
              
             
               Z 
              
             
             
             
               [ 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ] 
              
             
            
              ⋅ 
             
            
              p 
             
            
              ( 
             
            
              ε 
             
            
              ) 
             
            
              d 
             
            
              ε 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               ∇ 
              
             
               φ 
              
             
             
             
               E 
              
              
              
                p 
               
              
                ( 
               
              
                ε 
               
              
                ) 
               
              
             
             
             
               [ 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ] 
              
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               E 
              
              
              
                p 
               
              
                ( 
               
              
                ε 
               
              
                ) 
               
              
             
             
             
               [ 
              
              
              
                ∇ 
               
              
                φ 
               
              
             
               ( 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ) 
              
             
               ] 
              
             
            
           
          
          
           
            
            
              —— 
             
             
             
               ∇ 
              
             
               φ 
              
             
            
              与 
             
            
              p 
             
            
              ( 
             
            
              ε 
             
            
              ) 
             
            
              无关 
             
            
           
          
          
          
         
         
          
          
           
            
           
          
          
           
            
             
            
              = 
             
             
             
               E 
              
              
              
                p 
               
              
                ( 
               
              
                ε 
               
              
                ) 
               
              
             
             
             
               [ 
              
              
              
                ∇ 
               
              
                Z 
               
              
             
               ( 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ) 
              
             
               ⋅ 
              
              
              
                ∇ 
               
              
                φ 
               
              
              
              
                g 
               
              
                φ 
               
              
             
               ( 
              
             
               ε 
              
             
               ∣ 
              
              
              
                x 
               
              
                i 
               
              
             
               ) 
              
             
               ] 
              
             
            
           
          
          
           
           
             ——变量转换方法 
            
           
          
          
          
         
        
       
         \begin{align} \nabla_\varphi {\mathcal L}(\varphi) &= \nabla_\varphi \int_Z \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] \cdot q_\varphi {\rm d}_Z \\ &= \nabla_\varphi \int_Z \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] \cdot p(\varepsilon) {\rm d}\varepsilon \\ &= \nabla_\varphi E_{p(\varepsilon)} \left[ \log {P_\theta(x_i, Z)} - \log {q_\varphi} \right] \\ &= E_{p(\varepsilon)} \left[ \nabla_\varphi (\log {P_\theta(x_i, Z)} - \log {q_\varphi}) \right] & ——\nabla_\varphi与p(\varepsilon)无关 \\ &= E_{p(\varepsilon)} \left[ \nabla_Z (\log {P_\theta(x_i, Z)} - \log {q_\varphi}) \cdot \nabla_\varphi g_\varphi(\varepsilon|x_i) \right] & ——变量转换方法 \\ \end{align} 
        
       
     ∇φL(φ)=∇φ∫Z[logPθ(xi,Z)−logqφ]⋅qφdZ=∇φ∫Z[logPθ(xi,Z)−logqφ]⋅p(ε)dε=∇φEp(ε)[logPθ(xi,Z)−logqφ]=Ep(ε)[∇φ(logPθ(xi,Z)−logqφ)]=Ep(ε)[∇Z(logPθ(xi,Z)−logqφ)⋅∇φgφ(ε∣xi)]——∇φ与p(ε)无关——变量转换方法
通过以上变换我们得到了新的采样对象:
  
      
       
        
         
          
          
           
            
             
             
               ∇ 
              
             
               φ 
              
             
            
              L 
             
            
              ( 
             
            
              φ 
             
            
              ) 
             
            
              = 
             
             
             
               E 
              
              
              
                p 
               
              
                ( 
               
              
                ε 
               
              
                ) 
               
              
             
             
             
               [ 
              
              
              
                ∇ 
               
              
                Z 
               
              
             
               ( 
              
             
               log 
              
             
                
              
              
               
               
                 P 
                
               
                 θ 
                
               
              
                ( 
               
               
               
                 x 
                
               
                 i 
                
               
              
                , 
               
              
                Z 
               
              
                ) 
               
              
             
               − 
              
             
               log 
              
             
                
              
              
              
                q 
               
              
                φ 
               
              
             
               ( 
              
             
               Z 
              
             
               ∣ 
              
              
              
                x 
               
              
                i 
               
              
             
               ) 
              
             
               ) 
              
             
               ⋅ 
              
              
              
                ∇ 
               
              
                φ 
               
              
              
              
                g 
               
              
                φ 
               
              
             
               ( 
              
             
               ε 
              
             
               ∣ 
              
              
              
                x 
               
              
                i 
               
              
             
               ) 
              
             
               ] 
              
             
            
           
          
          
          
         
        
       
         \begin{align} \nabla_\varphi {\mathcal L}(\varphi) = E_{p(\varepsilon)} \left[ \nabla_Z (\log {P_\theta(x_i, Z)} - \log {q_\varphi}(Z|x_i)) \cdot \nabla_\varphi g_\varphi(\varepsilon|x_i) \right] \end{align} 
        
       
     ∇φL(φ)=Ep(ε)[∇Z(logPθ(xi,Z)−logqφ(Z∣xi))⋅∇φgφ(ε∣xi)]
 若通过MC对以下对象进行采样,就不会有问题:
  
      
       
        
         
         
           ε 
          
          
          
            ( 
           
          
            l 
           
          
            ) 
           
          
         
        
          ∽ 
         
        
          p 
         
        
          ( 
         
        
          ε 
         
        
          ) 
         
        
          , 
         
        
          l 
         
        
          ∈ 
         
        
          [ 
         
        
          1 
         
        
          , 
         
        
          L 
         
        
          ] 
           
        
          ⟹ 
           
         
         
           ∇ 
          
         
           φ 
          
         
        
          L 
         
        
          ( 
         
        
          φ 
         
        
          ) 
         
        
          ≈ 
         
         
         
           1 
          
         
           L 
          
         
         
         
           ∑ 
          
          
          
            l 
           
          
            = 
           
          
            1 
           
          
         
           L 
          
         
         
         
           [ 
          
          
          
            ∇ 
           
          
            Z 
           
          
         
           ( 
          
         
           log 
          
         
            
          
          
           
           
             P 
            
           
             θ 
            
           
          
            ( 
           
           
           
             x 
            
           
             i 
            
           
          
            , 
           
          
            Z 
           
          
            ) 
           
          
         
           − 
          
         
           log 
          
         
            
          
          
          
            q 
           
          
            φ 
           
          
         
           ( 
          
         
           Z 
          
         
           ∣ 
          
          
          
            x 
           
          
            i 
           
          
         
           ) 
          
         
           ) 
          
         
           ⋅ 
          
          
          
            ∇ 
           
          
            φ 
           
          
          
          
            g 
           
          
            φ 
           
          
         
           ( 
          
         
           ε 
          
         
           ∣ 
          
          
          
            x 
           
          
            i 
           
          
         
           ) 
          
         
           ] 
          
         
        
       
         \varepsilon^{(l)} \backsim p(\varepsilon), l \in [1, L] \implies \nabla_\varphi {\mathcal L}(\varphi) \approx \frac{1}{L} \sum_{l=1}^L \left[ \nabla_Z (\log {P_\theta(x_i, Z)} - \log {q_\varphi}(Z|x_i)) \cdot \nabla_\varphi g_\varphi(\varepsilon|x_i) \right] 
        
       
     ε(l)∽p(ε),l∈[1,L]⟹∇φL(φ)≈L1l=1∑L[∇Z(logPθ(xi,Z)−logqφ(Z∣xi))⋅∇φgφ(ε∣xi)]
 公式中的 
     
      
       
       
         Z 
        
       
      
        Z 
       
      
    Z都可以通过 
     
      
       
        
        
          g 
         
        
          φ 
         
        
       
         ( 
        
       
         ε 
        
       
         ∣ 
        
        
        
          x 
         
        
          i 
         
        
       
         ) 
        
       
      
        g_\varphi(\varepsilon|x_i) 
       
      
    gφ(ε∣xi)求得。
所以SGVI的核心方式还是通过梯度上升的方式进行迭代,但要使用参数重构方法降低计算难度:
  
      
       
        
        
          φ 
         
        
          = 
         
        
          φ 
         
        
          + 
         
        
          S 
         
        
          t 
         
        
          e 
         
        
          p 
         
        
          ⋅ 
         
         
         
           ∇ 
          
         
           φ 
          
         
        
          L 
         
        
          ( 
         
        
          φ 
         
        
          ) 
         
        
       
         \varphi = \varphi + Step \cdot \nabla_\varphi {\mathcal L}(\varphi) 
        
       
     φ=φ+Step⋅∇φL(φ)



















