背景
Focal loss是最初由何恺明提出的,最初用于图像领域解决数据不平衡造成的模型性能问题。本文试图从交叉熵损失函数出发,分析数据不平衡问题,focal loss与交叉熵损失函数的对比,给出focal loss有效性的解释。
交叉熵损失函数
 
      
       
        
        
          L 
         
        
          o 
         
        
          s 
         
        
          s 
         
        
          = 
         
        
          L 
         
        
          ( 
         
        
          y 
         
        
          , 
         
         
         
           p 
          
         
           ^ 
          
         
        
          ) 
         
        
          = 
         
        
          − 
         
        
          y 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
         
         
           p 
          
         
           ^ 
          
         
        
          ) 
         
        
          − 
         
        
          ( 
         
        
          1 
         
        
          − 
         
        
          y 
         
        
          ) 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
        
          1 
         
        
          − 
         
         
         
           p 
          
         
           ^ 
          
         
        
          ) 
         
        
       
         Loss = L(y, \hat{p})=-ylog(\hat{p}) - (1 - y)log(1-\hat{p}) 
        
       
     Loss=L(y,p^)=−ylog(p^)−(1−y)log(1−p^)
 其中, 
     
      
       
        
        
          p 
         
        
          ^ 
         
        
       
      
        \hat{p} 
       
      
    p^为预测概率,y为真实label,二分类对应0、1。
 对二分类中换种写法如下
  
      
       
        
         
         
           L 
          
          
          
            c 
           
          
            e 
           
          
         
        
          ( 
         
        
          y 
         
        
          , 
         
         
         
           p 
          
         
           ^ 
          
         
        
          ) 
         
        
          = 
         
         
         
           { 
          
          
           
            
             
              
              
                − 
               
              
                l 
               
              
                o 
               
              
                g 
               
              
                ( 
               
               
               
                 p 
                
               
                 ^ 
                
               
              
                ) 
               
              
                , 
               
              
                        
               
              
                i 
               
              
                f 
               
              
                   
               
              
                y 
               
              
                = 
               
              
                1 
               
              
             
            
            
             
              
             
            
           
           
            
             
              
              
                − 
               
              
                l 
               
              
                o 
               
              
                g 
               
              
                ( 
               
              
                1 
               
              
                − 
               
               
               
                 p 
                
               
                 ^ 
                
               
              
                ) 
               
              
                , 
               
              
                i 
               
              
                f 
               
              
                   
               
              
                y 
               
              
                = 
               
              
                0 
               
              
             
            
            
             
              
             
            
           
          
         
        
          ( 
         
        
          1 
         
        
          ) 
         
        
       
         L_{ce}(y, \hat{p}) = \left\{\begin{matrix} -log(\hat{p}), \ \ \ \ \ \ \ if \ \ y = 1&\\ -log(1-\hat{p}), if \ \ y = 0 & \end{matrix}\right. (1) 
        
       
     Lce(y,p^)={−log(p^),       if  y=1−log(1−p^),if  y=0(1)
样本不均衡问题
对于所有样本,损失函数为
  
      
       
        
        
          L 
         
        
          = 
         
         
         
           1 
          
         
           N 
          
         
         
         
           ∑ 
          
          
          
            i 
           
          
            = 
           
          
            1 
           
          
         
           N 
          
         
        
          l 
         
        
          ( 
         
         
         
           y 
          
         
           i 
          
         
        
          , 
         
         
          
          
            p 
           
          
            ^ 
           
          
         
           i 
          
         
        
          ) 
         
        
       
         L=\frac{1}{N}\sum_{i=1}^{N}l(y_i, \hat{p}_i) 
        
       
     L=N1i=1∑Nl(yi,p^i)
 对于二分类问题,损失函数具体化为
  
      
       
        
        
          L 
         
        
          = 
         
         
         
           1 
          
         
           N 
          
         
        
          ( 
         
         
         
           ∑ 
          
          
           
           
             y 
            
           
             i 
            
           
          
            = 
           
          
            1 
           
          
         
           m 
          
         
        
          − 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
         
         
           p 
          
         
           ^ 
          
         
        
          ) 
         
        
          + 
         
         
         
           ∑ 
          
          
           
           
             y 
            
           
             i 
            
           
          
            = 
           
          
            0 
           
          
         
           n 
          
         
        
          − 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
        
          1 
         
        
          − 
         
         
         
           p 
          
         
           ^ 
          
         
        
          ) 
         
        
          ) 
         
        
       
         L=\frac{1}{N}(\sum_{y_i=1}^{m}-log(\hat{p}) + \sum_{y_i=0}^{n}-log(1-\hat{p})) 
        
       
     L=N1(yi=1∑m−log(p^)+yi=0∑n−log(1−p^))
 其中m为正样本个数,n为负样本个数,N为样本总数,m+n=N。
当样本分布失衡时,在损失函数L的分布也会发生倾斜,如m<<n时,负样本就会在损失函数占据主导地位。由于损失函数的倾斜,模型训练过程中会倾向于样本多的类别,造成模型对少样本类别的性能较差。
平衡交叉熵函数(balanced cross entropy)
基于样本非平衡造成的损失函数倾斜,一个直观的做法就是在损失函数中添加权重因子,提高少数类别在损失函数中的权重,平衡损失函数的分布。如在上述二分类问题中,添加权重参数 
     
      
       
       
         α 
        
       
         ∈ 
        
       
         [ 
        
       
         0 
        
       
         , 
        
       
         1 
        
       
         ] 
        
       
      
        \alpha \in [0, 1] 
       
      
    α∈[0,1] 和  
     
      
       
       
         1 
        
       
         − 
        
       
         α 
        
       
      
        1-\alpha 
       
      
    1−α
  
      
       
        
        
          L 
         
        
          = 
         
         
         
           1 
          
         
           N 
          
         
        
          ( 
         
         
         
           ∑ 
          
          
           
           
             y 
            
           
             i 
            
           
          
            = 
           
          
            1 
           
          
         
           m 
          
         
        
          − 
         
        
          α 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
         
         
           p 
          
         
           ^ 
          
         
        
          ) 
         
        
          + 
         
         
         
           ∑ 
          
          
           
           
             y 
            
           
             i 
            
           
          
            = 
           
          
            0 
           
          
         
           n 
          
         
        
          − 
         
        
          ( 
         
        
          1 
         
        
          − 
         
        
          α 
         
        
          ) 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
        
          1 
         
        
          − 
         
         
         
           p 
          
         
           ^ 
          
         
        
          ) 
         
        
          ) 
         
        
       
         L=\frac{1}{N}(\sum_{y_i=1}^{m}-\alpha log(\hat{p}) + \sum_{y_i=0}^{n}-(1- \alpha)log(1-\hat{p})) 
        
       
     L=N1(yi=1∑m−αlog(p^)+yi=0∑n−(1−α)log(1−p^))
 其中 
     
      
       
        
        
          α 
         
         
         
           1 
          
         
           − 
          
         
           α 
          
         
        
       
         = 
        
        
        
          n 
         
        
          m 
         
        
       
      
        \frac {\alpha}{1- \alpha} = \frac {n}{m} 
       
      
    1−αα=mn,即正负样本损失权重的根据正负样本的比例分布进行设置,具体里说,成反比关系。
focal loss
focal loss也是针对样本不均衡问题,从loss角度提供的另外一种解决方法。
 focal loss的具体形式为:
  
      
       
        
         
         
           L 
          
          
          
            f 
           
          
            l 
           
          
         
        
          = 
         
         
         
           { 
          
          
           
            
             
              
              
                − 
               
              
                ( 
               
              
                1 
               
              
                − 
               
               
               
                 p 
                
               
                 ^ 
                
               
               
               
                 ) 
                
               
                 γ 
                
               
              
                l 
               
              
                o 
               
              
                g 
               
              
                ( 
               
               
               
                 p 
                
               
                 ^ 
                
               
              
                ) 
               
              
                   
               
              
                i 
               
              
                f 
               
              
                  
               
              
                y 
               
              
                = 
               
              
                1 
               
              
             
            
            
             
              
             
            
           
           
            
             
              
              
                − 
               
               
               
                 p 
                
               
                 ^ 
                
               
              
                l 
               
              
                o 
               
              
                g 
               
              
                ( 
               
              
                1 
               
              
                − 
               
               
               
                 p 
                
               
                 ^ 
                
               
              
                ) 
               
              
                        
               
              
                i 
               
              
                f 
               
              
                  
               
              
                y 
               
              
                = 
               
              
                0 
               
              
             
            
            
             
              
             
            
           
          
         
        
            
         
        
          ( 
         
        
          2 
         
        
          ) 
         
        
       
         L_{fl} = \left\{\begin{matrix} -(1-\hat{p})^{\gamma}log(\hat{p}) \ \ if \ y=1 &\\ -\hat{p}log(1-\hat{p}) \ \ \ \ \ \ \ if \ y=0 & \end{matrix}\right. \ (2) 
        
       
     Lfl={−(1−p^)γlog(p^)  if y=1−p^log(1−p^)       if y=0 (2)
 令 
     
      
       
        
        
          p 
         
        
          t 
         
        
       
         = 
        
        
        
          { 
         
         
          
           
            
             
              
              
                p 
               
              
                ^ 
               
              
             
                    
              
             
               i 
              
             
               f 
              
             
                 
              
             
               y 
              
             
               = 
              
             
               1 
              
             
            
           
           
            
             
            
           
          
          
           
            
             
             
               1 
              
             
               − 
              
              
              
                p 
               
              
                ^ 
               
              
             
                 
              
             
               o 
              
             
               t 
              
             
               h 
              
             
               e 
              
             
               r 
              
             
               w 
              
             
               i 
              
             
               s 
              
             
               e 
              
             
            
           
           
            
             
            
           
          
         
        
       
      
        p_t = \left\{\begin{matrix} \hat{p} \ \ \ \ if \ y = 1&\\ 1-\hat{p} \ otherwise & \end{matrix}\right. 
       
      
    pt={p^    if y=11−p^ otherwise
 将focal loss表达式(2)统一为一个表达式:
  
      
       
        
         
         
           L 
          
         
           f 
          
         
        
          l 
         
        
          = 
         
        
          − 
         
        
          ( 
         
        
          1 
         
        
          − 
         
         
         
           p 
          
         
           t 
          
         
         
         
           ) 
          
         
           γ 
          
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
         
         
           p 
          
         
           t 
          
         
        
          ) 
         
        
               
         
        
          ( 
         
        
          3 
         
        
          ) 
         
        
       
         L_fl=-(1-p_t)^{\gamma}log(p_t) \ \ \ \ (3) 
        
       
     Lfl=−(1−pt)γlog(pt)    (3)
 同理可将交叉熵表达式(1)统一为一个表达式:
  
      
       
        
         
         
           L 
          
          
          
            c 
           
          
            e 
           
          
         
        
          = 
         
        
          − 
         
        
          l 
         
        
          o 
         
        
          g 
         
        
          ( 
         
         
         
           p 
          
         
           t 
          
         
        
          ) 
         
        
       
         L_{ce} = -log(p_t) 
        
       
     Lce=−log(pt)
  
     
      
       
        
        
          p 
         
        
          t 
         
        
       
      
        p_t 
       
      
    pt 反映了与ground truth即类别  
     
      
       
       
         y 
        
       
      
        y 
       
      
    y 的接近程度,  
     
      
       
        
        
          p 
         
        
          t 
         
        
       
      
        p_t 
       
      
    pt越大说明越接近类别  
     
      
       
       
         y 
        
       
      
        y 
       
      
    y,即分类越准确。
  
     
      
       
       
         γ 
        
       
      
        \gamma 
       
      
    γ 为可调节因子。
对比表达式(3)和(4), focal loss相比交叉熵多了一个modulating factor即 ( 1 − p t ) γ (1-p_t)^{\gamma} (1−pt)γ。对于分类准确的样本 p t → 1 p_t \rightarrow 1 pt→1,modulating factor趋近于0。对于分类不准确的样本 1 − p t → 1 1-p_t \rightarrow 1 1−pt→1,modulating factor趋近于1。即相比交叉熵损失,focal loss对于分类不准确的样本,损失没有改变,对于分类准确的样本,损失会变小。 整体而言,相当于增加了分类不准确样本在损失函数中的权重。
 
     
      
       
        
        
          p 
         
        
          t 
         
        
       
      
        p_t 
       
      
    pt也反应了分类的难易程度, 
     
      
       
        
        
          p 
         
        
          t 
         
        
       
      
        p_t 
       
      
    pt 越大,说明分类的置信度越高,代表样本越易分; 
     
      
       
        
        
          p 
         
        
          t 
         
        
       
      
        p_t 
       
      
    pt 越小,分类的置信度越低,代表样本越难分。因此focal loss相当于增加了难分样本在损失函数的权重,使得损失函数倾向于难分的样本,有助于提高难分样本的准确度。focal loss与交叉熵的对比,可见下图:
 
focal loss vs balanced cross entropy
focal loss相比balanced cross entropy而言,二者都是试图解决样本不平衡带来的模型训练问题,后者从样本分布角度对损失函数添加权重因子,前者从样本分类难易程度出发,使loss聚焦于难分样本。
focal loss为什么有效
focal loss从样本难易分类角度出发,解决样本非平衡带来的模型训练问题。
相信很多人会在这里有一个疑问,样本难易分类角度怎么能够解决样本非平衡的问题,直觉上来讲样本非平衡造成的问题就是样本数少的类别分类难度较高。因此从样本难易分类角度出发,使得loss聚焦于难分样本,解决了样本少的类别分类准确率不高的问题,当然难分样本不限于样本少的类别,也就是focal loss不仅仅解决了样本非平衡的问题,同样有助于模型的整体性能提高。
要想使模型训练过程中聚焦难分类样本,仅仅使得Loss倾向于难分类样本还不够,因为训练过程中模型参数更新取决于Loss的梯度。
  
      
       
        
        
          w 
         
        
          = 
         
        
          w 
         
        
          − 
         
        
          α 
         
         
          
          
            ∂ 
           
          
            L 
           
          
          
          
            ∂ 
           
          
            w 
           
          
         
        
             
         
        
          ( 
         
        
          5 
         
        
          ) 
         
        
       
         w = w - \alpha \frac{ \partial {L}}{ \partial {w}} \ \ (5) 
        
       
     w=w−α∂w∂L  (5)
 如果Loss中难分类样本权重较高,但是难分类样本的Loss的梯度为0,难分类样本不会影响模型学习过程。
对于梯度问题,在focal loss的文章中,也给出了答案。如下图所示为focal loss的梯度示意图。其中 
     
      
       
        
        
          x 
         
        
          t 
         
        
       
         = 
        
       
         y 
        
       
         x 
        
       
      
        x_t = yx 
       
      
    xt=yx ,其中  
     
      
       
       
         y 
        
       
         ∈ 
        
        
        
          { 
         
        
          0 
         
        
          , 
         
        
          1 
         
        
       
         } 
        
       
      
        y \in {\{0, 1}\} 
       
      
    y∈{0,1} 为类别,  
     
      
       
        
        
          p 
         
        
          t 
         
        
       
         = 
        
       
         σ 
        
       
         ( 
        
        
        
          x 
         
        
          t 
         
        
       
         ) 
        
       
      
        p_t=\sigma(x_t) 
       
      
    pt=σ(xt) ,对于易分样本,  
     
      
       
        
        
          x 
         
        
          t 
         
        
       
         > 
        
       
         0 
        
       
      
        x_t > 0 
       
      
    xt>0,即  
     
      
       
        
        
          p 
         
        
          t 
         
        
       
         > 
        
       
         0.5 
        
       
      
        p_t>0.5 
       
      
    pt>0.5 。由图中可以看出,对于focal loss而言,在  
     
      
       
        
        
          x 
         
        
          t 
         
        
       
         > 
        
       
         0 
        
       
      
        x_t>0 
       
      
    xt>0 时,导数很小,趋近于0。因此对于focal loss,导数中易是难分类样本占主导,因此学习过程更加聚焦正在难分类样本。
 
一点小思考
难分类样本与易分类样本其实是一个动态概念,也就是说 p t p_t pt 会随着训练过程而变化。原先易分类样本即 p t p_t pt 大的样本,可能随着训练过程变化为难训练样本即 p t p_t pt 小的样本。
上面讲到,由于Loss梯度中,难训练样本起主导作用,即参数的变化主要是朝着优化难训练样本的方向改变。当参数变化后,可能会使原先易训练的样本 p t p_t pt 发生变化,即可能变为难训练样本。当这种情况发生时,可能会造成模型收敛速度慢,正如苏剑林在他的文章中提到的那样。
为了防止难易样本的频繁变化,应当选取小的学习率。防止学习率过大,造成 w w w 变化较大从而引起 p t p_t pt 的巨大变化,造成难易样本的改变。



















