本文记录大模型推理阶段 KV Cache 的原理及显存占用情况。
Self-Attention 与 KV Cache
如图,当新生成的 token x 进到模型计算 Attention 时,先分别乘上参数矩阵  
     
      
       
        
        
          W 
         
        
          q 
         
        
       
      
        W_q 
       
      
    Wq、 
     
      
       
        
        
          W 
         
        
          k 
         
        
       
      
        W_k 
       
      
    Wk、 
     
      
       
        
        
          W 
         
        
          v 
         
        
       
      
        W_v 
       
      
    Wv 得到向量 q,以及矩阵 K、V。然后根据下面公式计算当前 token 跟前面 tokens 的注意力权重(本文为了简化,不考虑多头 MHA)。
 
 自回归生成过程中,K和V矩阵并没有太大变化,比如下图中 cold 这个词对应了 K 的某一列和 V 的某一行,算完就放那里不再变了。
 
 轮到生成 chill 这个词时,其实只需要在原始 K 矩阵追加一列,原始 V 矩阵追加一行,而没必要每生成一个 token 都重新计算一遍 K、V 矩阵,这便是 KV Cache 的意义。
 
因此在推理的时候,不用每次传入前面全部 token 序列的 embedding,而只需传入 KV Cache 以及当前 token x 的 embedding。Transformer 在算完当前 token x 的 Attention 之后,会把新的 K’ 和 V’ 更新到 GPU 显存中。 左图中 Masked Multi Self Attention 这块也是唯一和前面序列有交互的模块,其他模块(比如 Layer Norm、FFN、位置编码等)都不涉及跟已生成 token 的交互。
 
KV Cache 显存占用分析
KV Cache 显存计算方式如下:
  
      
       
        
        
          2 
         
        
          ∗ 
         
        
          p 
         
        
          r 
         
        
          e 
         
        
          c 
         
        
          i 
         
        
          s 
         
        
          i 
         
        
          o 
         
        
          n 
         
        
          ∗ 
         
         
         
           n 
          
          
          
            l 
           
          
            a 
           
          
            y 
           
          
            e 
           
          
            r 
           
          
         
        
          ∗ 
         
         
         
           d 
          
          
          
            m 
           
          
            o 
           
          
            d 
           
          
            e 
           
          
            l 
           
          
         
        
          ∗ 
         
        
          s 
         
        
          e 
         
        
          q 
         
        
          _ 
         
        
          l 
         
        
          e 
         
        
          n 
         
        
          ∗ 
         
        
          b 
         
        
          a 
         
        
          t 
         
        
          c 
         
        
          h 
         
        
          _ 
         
        
          s 
         
        
          i 
         
        
          z 
         
        
          e 
         
        
       
         2 * precision * n_{layer} * d_{model} * seq\_len * batch\_size 
        
       
     2∗precision∗nlayer∗dmodel∗seq_len∗batch_size
- 2 2 2 是指 K 跟 V 俩矩阵。
 - p r e c i s i o n precision precision 是模型每个参数的字节数,比如 fp32 精度下每个参数 4 字节。
 - n l a y e r n_{layer} nlayer 和 n m o d e l n_{model} nmodel 分别是模型 Decoder layer 层数和 embedding 维度大小。
 - s e q _ l e n seq\_len seq_len、 b a t c h _ s i z e batch\_size batch_size 顾名思义分别是最大序列长度和 global batch size。
 
比如以 OPT-30B 模型(bf16,48层,7168维,1024上下文,128 batch size)为例,KV Cache 占的显存是:
  
      
       
        
        
          2 
         
        
          ∗ 
         
        
          2 
         
        
          ∗ 
         
        
          48 
         
        
          ∗ 
         
        
          7168 
         
        
          ∗ 
         
        
          1024 
         
        
          ∗ 
         
        
          128 
         
         
        
          = 
         
        
          180 
         
        
          , 
         
        
          388 
         
        
          , 
         
        
          626 
         
        
          , 
         
        
          432 
         
        
          b 
         
        
          y 
         
        
          t 
         
        
          e 
         
        
          s 
         
         
        
          ≈ 
         
        
          180 
         
        
          G 
         
        
          B 
         
        
       
         2*2*48*7168*1024*128 \\=180,388,626,432 bytes \\≈ 180GB 
        
       
     2∗2∗48∗7168∗1024∗128=180,388,626,432bytes≈180GB
模型本身仅占显存: 
      
       
        
        
          2 
         
        
          ∗ 
         
        
          30 
         
        
          B 
         
        
          = 
         
        
          60 
         
        
          B 
         
        
          b 
         
        
          y 
         
        
          t 
         
        
          e 
         
        
          s 
         
        
          = 
         
        
          60 
         
        
          G 
         
        
          B 
         
        
       
         2*30B=60Bbytes=60GB 
        
       
     2∗30B=60Bbytes=60GB
 光 KV Cache 就顶模型本身占显存的3倍。(当然一般推理时 batch size是1,这时候KV Cache显存占用就砍到 1/128 了,不过 batch 模式能够最大化利用显存,所以这也是为啥各个大模型厂商 batch 模型都比较便宜了)
参考资料:油管《The KV Cache: Memory Usage in Transformers》



















