Follow me on CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/141462669

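LLaMA3 is Meta's latest large language model. Its network design includes several upgrades that significantly improve the model's performance and efficiency. The key changes:
- The vocabulary grows to 128K tokens.
- RMS Normalization (Root Mean Square Normalization) is used.
- Rotary Position Embedding (RoPE) is used.
- Grouped Query Attention (GQA) is used: 32 query heads and 8 KV heads, so each group of 4 query heads shares one KV head.
- A SwiGLU Feedforward Network is used.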
The detailed inference pipeline follows.
Install the required Python libraries:
pip install tiktoken
pip install matplotlib
pip install blobfile
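- tiktoken: a fast BPE (Byte Pair Encoding) tokenizer, designed for use with OpenAI models.
- blobfile: a Python library that reads local paths and cloud blob storage (such as Google Cloud Storage and Azure Blob Storage) through one file-like interface; tiktoken uses it to read the tokenizer file.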
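1. Load the Tokenizer
A tokenizer splits text into tokens; a common algorithm is BPE (Byte Pair Encoding).
Byte Pair Encoding (BPE), also called digram coding, started as a data compression technique: it repeatedly replaces the most frequent pair of adjacent symbols with a new symbol. For tokenization, the learned merges turn a text string into a sequence of subword token IDs for the downstream model. OpenAI used BPE to tokenize text when pretraining GPT models, and it is widely used by Transformer models.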
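To make the merge idea concrete, here is a minimal sketch of how BPE learns merges on a toy corpus. It assumes a character-level starting alphabet and whitespace-separated words; tiktoken's byte-level BPE and the trained Llama3 ranks are not reproduced, and `learn_bpe_merges` is a hypothetical helper written for illustration.

```python
from collections import Counter

def learn_bpe_merges(words, num_merges):
    """Toy BPE: repeatedly merge the most frequent adjacent symbol pair."""
    # Each word starts as a tuple of single characters, weighted by frequency.
    corpus = Counter(tuple(w) for w in words)
    merges = []
    for _ in range(num_merges):
        # Count adjacent pairs across the corpus.
        pair_counts = Counter()
        for symbols, freq in corpus.items():
            for a, b in zip(symbols, symbols[1:]):
                pair_counts[(a, b)] += freq
        if not pair_counts:
            break
        # Most frequent pair; ties go to the pair seen first.
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Rewrite every word, replacing the best pair with one merged symbol.
        new_corpus = Counter()
        for symbols, freq in corpus.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            new_corpus[tuple(out)] += freq
        corpus = new_corpus
    return merges, corpus

merges, corpus = learn_bpe_merges(["low", "low", "lower", "newest", "newest", "widest"], 3)
print(merges)  # [('l', 'o'), ('lo', 'w'), ('e', 's')]
```

The merge list plays the role of `mergeable_ranks` in tiktoken: earlier merges have lower ranks and are applied first when encoding new text.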
Tokenizer source code:
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json
import os
import matplotlib.pyplot as plt
os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # only GPU 0 is visible
# Meta-Llama-3-8B is downloaded from HuggingFace
model_path = "Meta-Llama-3-8B/original/consolidated.00.pth"
tokenizer_path = "Meta-Llama-3-8B/original/tokenizer.model"
special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
tokenizer = tiktoken.Encoding(
    name=Path(tokenizer_path).name,
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)
print(tokenizer.decode(tokenizer.encode("hello world!")))
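Parameters of tiktoken.Encoding:
- name=Path(tokenizer_path).name: the name of the encoding object.
- pat_str: the pre-tokenization regular expression. Text is first split into chunks (English contractions, runs of letters, runs of 1 to 3 digits, punctuation, whitespace), and BPE merges are then applied within each chunk.
- mergeable_ranks=load_tiktoken_bpe(tokenizer_path): loads the BPE merge ranks, which are part of the model.
- special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}: defines the special-token mapping. Each special token gets an integer ID starting right after the last BPE rank (len(mergeable_ranks)) and increasing by one.
Encoding and then decoding "hello world!" returns the same text. Output: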
hello world!
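The special-token IDs follow directly from the `special_tokens` list. A small sketch of that arithmetic, assuming the Llama3 BPE file holds 128000 mergeable ranks (consistent with `vocab_size = 128256` in the config and with `<|begin_of_text|>` being ID 128000 in section 3):

```python
# Assumption: len(mergeable_ranks) == 128000 for the Llama3 tokenizer.
num_base_tokens = 128000
num_reserved_special_tokens = 256

special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",  # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]

# Special tokens are numbered consecutively after the BPE vocabulary.
special_ids = {tok: num_base_tokens + i for i, tok in enumerate(special_tokens)}

print(len(special_tokens))                    # 256
print(special_ids["<|begin_of_text|>"])       # 128000
print(special_ids["<|eot_id|>"])              # 128009
print(num_base_tokens + len(special_tokens))  # 128256, equal to vocab_size
```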
2. Load the Model
Load the Llama3 weights directly from the .pth checkpoint:
model = torch.load(model_path)
print(json.dumps(list(model.keys())[:20], indent=4))
Checkpoint path:
Meta-Llama-3-8B/original/consolidated.00.pth
Load the Llama3 model configuration (hyperparameters):
with open("Meta-Llama-3-8B/original/params.json", "r") as f:
    config = json.load(f)
The configuration:
{
	'dim': 4096,
	'n_layers': 32,
	'n_heads': 32,
	'n_kv_heads': 8,
	'vocab_size': 128256,
	'multiple_of': 1024,
	'ffn_dim_multiplier': 1.3,
	'norm_eps': 1e-05,
	'rope_theta': 500000.0
}
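- dim: hidden state dimension, 4096; the hidden state is the model's internal representation of each token while it processes the input sequence.
- n_layers: number of layers, 32; depth determines the model's complexity and representational capacity.
- n_heads: number of query heads in self-attention, 32; different heads can capture different relationships in the input sequence.
- n_kv_heads: number of key/value heads, 8; with grouped query attention, each K/V head is shared by n_heads / n_kv_heads = 4 query heads.
- vocab_size: vocabulary size, 128256 (128000 BPE tokens plus 256 special tokens); the vocabulary contains every token the model can process.
- multiple_of: the FFN hidden dimension is rounded up to a multiple of this value, 1024, which helps computational efficiency.
- ffn_dim_multiplier: 1.3, a multiplier applied when sizing the hidden dimension of the feedforward network (FFN), which comes out to 14336; the FFN applies a nonlinear transformation to the hidden states.
- norm_eps: epsilon of the normalization layers, 1e-05, added for numerical stability; normalization layers rescale the model's intermediate representations.
- rope_theta: the base of RoPE (Rotary Position Embedding), 500000.0; RoPE encodes token positions so that attention scores depend on the relative order of tokens.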
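`multiple_of` and `ffn_dim_multiplier` only make sense together. Below is a sketch of the sizing rule, assuming it matches the one in Meta's reference model code: start from 4 × dim, take 2/3 of it, apply the multiplier, then round up to a multiple of `multiple_of`. It reproduces the 14336 that appears later in the feedforward weight shapes.

```python
def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)  # SwiGLU has 3 matrices, so use 2/3 of the usual 4x width
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the next multiple of `multiple_of`.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

dim, n_heads, n_kv_heads = 4096, 32, 8
print(ffn_hidden_dim(dim, multiple_of=1024, ffn_dim_multiplier=1.3))  # 14336
print(dim // n_heads)                 # head_dim = 128
print(n_kv_heads * (dim // n_heads))  # rows of wk / wv = 1024
```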
Store the hyperparameters in variables:
dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])
3. Text → Tokens → Embeddings
Convert text to tokens (128000 is the ID of <|begin_of_text|>):
prompt = "the answer to the ultimate question of life, the universe, and everything is "
tokens = [128000] + tokenizer.encode(prompt)
print(tokens)
tokens = torch.tensor(tokens)
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(prompt_split_as_tokens)
Output:
[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]
['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']
Convert tokens to embeddings:
embedding_layer = torch.nn.Embedding(vocab_size, dim)
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
token_embeddings_unnormalized.shape
Output:
torch.Size([17, 4096])
4. Implement RMS Normalization
The RMS Normalization function:
# Equivalent, more explicit form:
# def rms_norm(tensor, norm_weights):
#     rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps)**0.5
#     return tensor * (norm_weights / rms)
def rms_norm(tensor, norm_weights):
    # x / sqrt(mean(x^2) + eps), scaled by the learned per-channel gain
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
RMS Normalization (Root Mean Square Normalization) normalizes the inputs of a network layer. It aims to give the model re-scaling invariance and an implicit learning-rate adaptation ability. Compared with Layer Normalization it is more efficient, because it is cheaper to compute. The RMS Normalization formula, with ε inside the square root as in rms_norm above:
$$\bar{x}_i = \frac{x_i}{\mathrm{RMS}(x)}\, g_i,\quad \text{where}\ \mathrm{RMS}(x) = \sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^{2} + \epsilon}$$
Comparing the two: Layer Normalization has two parts, re-centering (shift) and re-scaling, while RMS Normalization drops re-centering and keeps only re-scaling. The RMSNorm paper argues that LayerNorm's success comes mainly from its re-scaling invariance, not from its re-centering invariance. Because RMS Normalization skips computing the mean and the shift, it trains faster than Layer Normalization with comparable, sometimes slightly better, results. Layer Normalization, for reference:
$$\mu = \frac{1}{n}\sum_{i=1}^{n} x_i,\qquad \sigma^2 = \frac{1}{n}\sum_{i=1}^{n} (x_i - \mu)^2,\qquad y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\,\gamma_i + \beta_i$$
Apply RMS Normalization with layer 0's attention_norm weights:
token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"])
print(token_embeddings.shape)
Output:
torch.Size([17, 4096])
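A quick numerical check of two claims made above, on random data (sizes and `eps` are arbitrary choices for this sketch): RMSNorm is invariant to rescaling its input, and for an input whose mean is already zero it coincides with LayerNorm without the affine part.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
n = 4096

def rms_norm(x, weight, eps=eps):
    # x / sqrt(mean(x^2) + eps), then the per-channel gain
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

x = torch.randn(17, n, dtype=torch.float64)
g = torch.ones(n, dtype=torch.float64)

# 1) Re-scaling invariance: rms_norm(c * x) ~= rms_norm(x) for c > 0
print(torch.allclose(rms_norm(3.0 * x, g), rms_norm(x, g), atol=1e-4))  # True

# 2) For zero-mean rows, RMSNorm equals LayerNorm with no gamma/beta
x0 = x - x.mean(-1, keepdim=True)
ln = F.layer_norm(x0, (n,), eps=eps)
print(torch.allclose(rms_norm(x0, g), ln, atol=1e-6))  # True
```

The small tolerance in check 1 only absorbs the effect of `eps`, which is not rescaled along with the input.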
5. Compute the Self-Attention Queries
Shapes of layer 0's Q, K, V and output (O) projection weights:
print(
    model["layers.0.attention.wq.weight"].shape,
    model["layers.0.attention.wk.weight"].shape,
    model["layers.0.attention.wv.weight"].shape,
    model["layers.0.attention.wo.weight"].shape
)
Output:
torch.Size([4096, 4096])  # wq
torch.Size([1024, 4096])  # wk
torch.Size([1024, 4096])  # wv
torch.Size([4096, 4096])  # wo
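Multi-head attention: split the query weight into n_heads = 32 heads, each with head_dim = 4096 / 32 = 128 rows. (wk and wv have only 1024 = 8 × 128 rows because there are 8 KV heads.)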
q_layer0 = model["layers.0.attention.wq.weight"]
head_dim = q_layer0.shape[0] // n_heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)
print(q_layer0.shape)
Query weights per head:
torch.Size([32, 128, 4096])
Each query head's weight has shape [128, 4096]:
q_layer0_head0 = q_layer0[0]
print(q_layer0_head0.shape)  # torch.Size([128, 4096])
The token embeddings have shape torch.Size([17, 4096]); multiplying by the transposed head weight gives [17, 128]:
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)
print(q_per_token.shape)	# torch.Size([17, 128])
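Slicing the weight per head is equivalent to projecting with the full `wq` and then splitting the outputs into chunks of head_dim = 128. A sketch with random tensors standing in for the real weights, using the same memory layout as layer 0 but a smaller dim (512, i.e. 4 heads) so it runs quickly:

```python
import torch

torch.manual_seed(0)
n_tokens, dim, n_heads = 17, 512, 4
head_dim = dim // n_heads  # 128, as in Llama3-8B

x = torch.randn(n_tokens, dim, dtype=torch.float64)
wq = torch.randn(dim, dim, dtype=torch.float64) / dim**0.5  # stand-in for attention.wq.weight

# Per-head view, as in the text: [n_heads, head_dim, dim]
wq_heads = wq.view(n_heads, head_dim, dim)
q_head0 = x @ wq_heads[0].T  # [17, 128]

# Full projection, then split the last dim into heads: [17, n_heads, head_dim]
q_all = (x @ wq.T).view(n_tokens, n_heads, head_dim)
print(torch.allclose(q_head0, q_all[:, 0], atol=1e-5))  # True
```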
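6. Implement Rotary Positional Encoding (RoPE)
RoPE treats each 128-dim query vector as 64 pairs of numbers and rotates each pair by an angle that depends on the token's position. First split the query into pairs: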
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
print(q_per_token_split_into_pairs.shape)  # torch.Size([17, 64, 2])
Build the frequency exponents k/64 for k = 0, ..., 63:
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64
print(zero_to_one_split_into_64_parts)
Output:
tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
        0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
        0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
        0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
        0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
        0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
        0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
        0.9844])
Frequencies θ_k = 1 / rope_theta^(k/64), with rope_theta = 500000.0 taken from the model config:
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
print(freqs)
Output:
tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
        2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
        8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
        2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
        7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
        2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
        6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
        1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
        5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
        1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
        4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])
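Convert to polar form: at token position m, pair k is rotated by the angle m·θ_k, and torch.polar turns each angle into a unit-length complex number with that angle: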
freqs_for_each_token = torch.outer(torch.arange(17), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
print(freqs_cis.shape)  # torch.Size([17, 64])
# view row 3 of freqs_cis (token position 3)
value = freqs_cis[3]
plt.figure()
for i, element in enumerate(value[:17]):
    plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {i}")
    plt.annotate(f"{i}", xy=(element.real, element.imag), color='red')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.title('Plot of one row of freqs_cis')
plt.show()
Plot of row 3 of freqs_cis in the complex plane:

Convert the query pairs into complex numbers, i.e., torch.Size([17, 64, 2]) becomes torch.Size([17, 64]), then multiply by freqs_cis; multiplying by a unit-length complex number is a rotation:
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
print(q_per_token_as_complex_numbers.shape) # torch.Size([17, 64])
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis
print(q_per_token_as_complex_numbers_rotated.shape) # torch.Size([17, 64])
Convert the complex numbers back to real numbers, which adds a trailing dimension of size 2: torch.Size([17, 64]) becomes torch.Size([17, 64, 2]):
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)
print(q_per_token_split_into_pairs_rotated.shape) # torch.Size([17, 64, 2])
Reshape back to the query shape, torch.Size([17, 64, 2]) to torch.Size([17, 128]). This is the query vector with the RoPE rotation applied:
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
print(q_per_token_rotated.shape) # torch.Size([17, 128])
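Multiplying by `freqs_cis` is just a 2D rotation of each (even, odd) pair. A self-contained sketch, where the helper names `rope_freqs`, `rope_complex` and `rope_real` are ours: it applies RoPE with complex numbers (as in the text) and with explicit cos/sin, and checks the property that makes RoPE a relative encoding: the rotated query-key dot product depends only on the distance between positions.

```python
import torch

def rope_freqs(n_pos, head_dim, theta=500000.0):
    # theta_k = theta^(-2k/head_dim), k = 0 .. head_dim/2 - 1 (same as range(64)/64 for 128 dims)
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim
    freqs = 1.0 / (theta ** exponents)
    return torch.outer(torch.arange(n_pos, dtype=torch.float64), freqs)  # angles [n_pos, head_dim/2]

def rope_complex(x, angles):
    # Pairs as complex numbers, multiplied by e^{i * angle}
    xc = torch.view_as_complex(x.reshape(x.shape[0], -1, 2).contiguous())
    rotated = xc * torch.polar(torch.ones_like(angles), angles)
    return torch.view_as_real(rotated).reshape(x.shape)

def rope_real(x, angles):
    # Same rotation with cos/sin on the (even, odd) components
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    cos, sin = angles.cos(), angles.sin()
    out = torch.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
n_pos, head_dim = 17, 128
angles = rope_freqs(n_pos, head_dim)
x = torch.randn(n_pos, head_dim, dtype=torch.float64)
print(torch.allclose(rope_complex(x, angles), rope_real(x, angles)))  # True

# Same q and k vectors placed at positions (5, 2) and (8, 5): same distance, same score
q = torch.randn(head_dim, dtype=torch.float64)
k = torch.randn(head_dim, dtype=torch.float64)
q_rot = rope_real(q.expand(n_pos, head_dim).clone(), angles)
k_rot = rope_real(k.expand(n_pos, head_dim).clone(), angles)
print(torch.allclose(q_rot[5] @ k_rot[2], q_rot[8] @ k_rot[5]))  # True
```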
7. Compute the Self-Attention Keys and the QK Mask
Split the key weight into n_kv_heads heads, reshaping [1024, 4096] to [8, 128, 4096]:
k_layer0 = model["layers.0.attention.wk.weight"]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)
print(k_layer0.shape) # torch.Size([8, 128, 4096])
Apply RoPE to the keys, the same way as for the queries:
k_layer0_head0 = k_layer0[0]
print(k_layer0_head0.shape) # torch.Size([128, 4096])
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)  # [17, 4096] x [4096, 128] = [17, 128]
print(k_per_token.shape) # torch.Size([17, 128])
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
print(k_per_token_split_into_pairs.shape) # torch.Size([17, 64, 2])
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
print(k_per_token_as_complex_numbers.shape) # torch.Size([17, 64])
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
print(k_per_token_split_into_pairs_rotated.shape) # torch.Size([17, 64, 2])
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
print(k_per_token_rotated.shape) # torch.Size([17, 128])
Compute the query-key scores, scaled by sqrt(head_dim):
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
print(qk_per_token.shape)
Output:
torch.Size([17, 17])
The self-attention formula in matrix form (d_k = head_dim = 128):
$$A = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$
Visualize the attention score matrix:
def display_qk_heatmap(qk_per_token):
    _, ax = plt.subplots()
    im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')
    ax.set_xticks(range(len(prompt_split_as_tokens)))
    ax.set_yticks(range(len(prompt_split_as_tokens)))
    ax.set_xticklabels(prompt_split_as_tokens)
    ax.set_yticklabels(prompt_split_as_tokens)
    ax.figure.colorbar(im, ax=ax)
    
display_qk_heatmap(qk_per_token)
Heatmap:

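Causal mask for the decoder: entries above the diagonal are -inf, so each token can attend only to itself and earlier tokens:
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)  # fill with -inf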
mask = torch.triu(mask, diagonal=1)
print(mask)
qk_per_token_after_masking = qk_per_token + mask
display_qk_heatmap(qk_per_token_after_masking)
Heatmap after masking:

Attention probabilities after softmax (masked entries become 0 and each row sums to 1):
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
display_qk_heatmap(qk_per_token_after_masking_after_softmax)
Heatmap:

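8. Compute the Self-Attention Values and the QKV Output
The value weight has the same shape as the key weight (n_kv_heads = 8 heads of 128 dims), but values get no positional encoding. Finally, multiply the attention probabilities by the values to get qkv_attention: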
v_layer0 = model["layers.0.attention.wv.weight"]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)
print(v_layer0.shape) # torch.Size([8, 128, 4096])
v_layer0_head0 = v_layer0[0]
print(v_layer0_head0.shape) # torch.Size([128, 4096])
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)
print(v_per_token.shape) # torch.Size([17, 128])
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
print(qkv_attention.shape)  # torch.Size([17, 128])
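The scale, mask, softmax and value product from sections 7 and 8 together form standard causal scaled dot-product attention. A sketch on random single-head tensors (17 tokens, head_dim 128, as in the text) comparing the manual steps with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or later:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_tokens, head_dim = 17, 128
q = torch.randn(n_tokens, head_dim, dtype=torch.float64)
k = torch.randn(n_tokens, head_dim, dtype=torch.float64)
v = torch.randn(n_tokens, head_dim, dtype=torch.float64)

# Manual version, mirroring the text
scores = q @ k.T / head_dim**0.5
mask = torch.triu(torch.full((n_tokens, n_tokens), float("-inf"), dtype=torch.float64), diagonal=1)
probs = F.softmax(scores + mask, dim=-1)  # row i only attends to positions <= i
manual = probs @ v

# Built-in causal attention (add a batch dim, then drop it)
builtin = F.scaled_dot_product_attention(q[None], k[None], v[None], is_causal=True)[0]

print(torch.allclose(manual, builtin))  # True
print(torch.allclose(probs.sum(-1), torch.ones(n_tokens, dtype=torch.float64)))  # True
print(probs[0, 1:].abs().sum().item())  # 0.0: the first token sees only itself
```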
9. Compute Multi-Head Grouped Query Attention
In Llama3, n_heads = 32 and n_kv_heads = 8, so every 4 query heads share one set of K and V: query heads 0 to 3 use KV head 0, heads 4 to 7 use KV head 1, and so on. The loop produces 32 head outputs of shape [17, 128] each; since 32 × 128 = 4096, concatenating them matches the embedding dimension:
qkv_attention_store = []
scale = n_heads // n_kv_heads  # 32 / 8 = 4
for head in range(n_heads):
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head//scale] # key weights are shared across 4 heads
    v_layer0_head = v_layer0[head//scale] # value weights are shared across 4 heads
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)
    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    mask = torch.triu(mask, diagonal=1)
    qk_per_token_after_masking = qk_per_token + mask
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    qkv_attention_store.append(qkv_attention)
len(qkv_attention_store)  # 32 heads
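The differences between GQA (Grouped Query Attention), MHA (Multi-Head Attention) and MQA (Multi-Query Attention):
- MHA is the basic mechanism: the input is projected into several heads that compute attention in parallel, each with its own keys and values, and the results are concatenated so the model captures different aspects of the sequence.
- MQA is an optimization in which all query heads share a single set of keys and values. This reduces parameters, computation and KV-cache size, which speeds up inference, but may cost some accuracy.
- GQA is a compromise between MHA and MQA: query heads are split into groups and each group shares one key/value head, rather than all heads sharing one. It keeps more diversity than MQA while still reducing cost, balancing inference speed and model quality.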
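The per-head loop can also be vectorized by repeating each K/V head once for every query head in its group. A sketch with random projections (RoPE omitted for brevity, small sizes assumed, same 4:1 head ratio as Llama3) that checks the vectorized result against an explicit loop using `h // group_size`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_tokens, n_heads, n_kv_heads, head_dim = 5, 8, 2, 16
group_size = n_heads // n_kv_heads  # query heads per K/V head

q = torch.randn(n_heads, n_tokens, head_dim, dtype=torch.float64)
k = torch.randn(n_kv_heads, n_tokens, head_dim, dtype=torch.float64)
v = torch.randn(n_kv_heads, n_tokens, head_dim, dtype=torch.float64)
mask = torch.triu(torch.full((n_tokens, n_tokens), float("-inf"), dtype=torch.float64), diagonal=1)

def attend(q_h, k_h, v_h):
    # Causal scaled dot-product attention; works for one head or a batch of heads
    probs = F.softmax(q_h @ k_h.transpose(-2, -1) / head_dim**0.5 + mask, dim=-1)
    return probs @ v_h

# Loop version: query head h uses K/V head h // group_size
loop_out = torch.stack([attend(q[h], k[h // group_size], v[h // group_size]) for h in range(n_heads)])

# Vectorized version: heads 0..3 -> kv 0, heads 4..7 -> kv 1
k_rep = k.repeat_interleave(group_size, dim=0)  # [n_heads, n_tokens, head_dim]
v_rep = v.repeat_interleave(group_size, dim=0)
vec_out = attend(q, k_rep, v_rep)
print(torch.allclose(loop_out, vec_out))  # True

# Concatenate heads along the feature dim, like torch.cat(qkv_attention_store, dim=-1)
stacked = vec_out.transpose(0, 1).reshape(n_tokens, n_heads * head_dim)
print(stacked.shape)  # torch.Size([5, 128])
```

Only the 2 K/V heads need to be computed and cached; the repetition is a cheap copy, which is where GQA saves memory compared with MHA.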
Concatenate the 32 head outputs:
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
print(stacked_qkv_attention.shape) # torch.Size([17, 4096])
Apply the output projection wo, shape [4096, 4096], to the concatenated heads to get the output of the attention block:
w_layer0 = model["layers.0.attention.wo.weight"]
print(w_layer0.shape) # torch.Size([4096, 4096])
embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
print(embedding_delta.shape) # torch.Size([17, 4096])
Residual connection: embedding_delta is added to the unnormalized input (token_embeddings_unnormalized), not to the normalized one:
embedding_after_edit = token_embeddings_unnormalized + embedding_delta
print(embedding_after_edit.shape)  # torch.Size([17, 4096])
Normalize embedding_after_edit with RMS Normalization (ffn_norm weights) before the feedforward network:
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])
print(embedding_after_edit_normalized.shape)  # torch.Size([17, 4096])
10. Implement the SwiGLU Feedforward Network
The SwiGLU feedforward computation:
w1 = model["layers.0.feed_forward.w1.weight"]  # torch.Size([14336, 4096])
w2 = model["layers.0.feed_forward.w2.weight"]  # torch.Size([4096, 14336]); 14336 = 3.5 × 4096
w3 = model["layers.0.feed_forward.w3.weight"]  # torch.Size([14336, 4096])
tmp1 = torch.matmul(embedding_after_edit_normalized, w1.T)  # gate branch, up-projected to [17, 14336]
tmp2 = torch.matmul(embedding_after_edit_normalized, w3.T)  # linear branch, up-projected to [17, 14336]
output_after_feedforward = torch.matmul(torch.nn.functional.silu(tmp1) * tmp2, w2.T)  # down-project to [17, 4096]
print(output_after_feedforward.shape) # torch.Size([17, 4096])
# residual connection
layer_0_embedding = embedding_after_edit+output_after_feedforward
print(layer_0_embedding.shape)  # torch.Size([17, 4096])
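The same computation packaged as a module. This is a sketch assuming bias-free linear layers named after the checkpoint keys `w1`, `w2`, `w3` (`SwiGLUFeedForward` is our name); sizes are shrunk here, while Llama3-8B uses dim = 4096 and hidden = 14336.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear (up) projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

    def forward(self, x):
        # (SiLU(x W1^T) * x W3^T) W2^T
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

torch.manual_seed(0)
ffn = SwiGLUFeedForward(dim=64, hidden_dim=224).double()  # 224 = 3.5 x 64
x = torch.randn(17, 64, dtype=torch.float64)
y = ffn(x)

# Same result as the explicit matmuls in the text (nn.Linear stores weights as [out, in])
w1, w2, w3 = ffn.w1.weight, ffn.w2.weight, ffn.w3.weight
manual = torch.matmul(F.silu(torch.matmul(x, w1.T)) * torch.matmul(x, w3.T), w2.T)
print(y.shape)                    # torch.Size([17, 64])
print(torch.allclose(y, manual))  # True
```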
The SwiGLU FFN formula:
$$\mathrm{SwiGLU\ FFN}(x) = \left(\mathrm{SiLU}(xW_1^{\top}) \odot xW_3^{\top}\right)W_2^{\top}$$
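Advantages of SwiGLU over ReLU:
- Swish (SiLU) is smooth and gives small nonzero outputs for negative inputs, whereas ReLU always outputs zero there.
- The GLU gate lets the input decide how much information passes through, which helps the network learn useful representations and can improve generalization.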
F.silu is the Swish activation with β = 1, also called SiLU:
$$\mathrm{silu}(x) = x \cdot \sigma(x),\qquad \sigma(x) = \frac{1}{1 + e^{-x}}$$
Plot the function:
import numpy as np
import torch
import matplotlib.pyplot as plt
def draw(func, label="silu"):
    x = np.arange(-10, 10, 0.1)
    # apply the activation to the whole tensor at once
    y = func(torch.from_numpy(x)).numpy()
    plt.plot(x, y, label=label)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.xlim(-7, 7)
    plt.ylim(-7, 7)
    plt.grid()
    plt.legend()
    plt.show()
Plot SiLU:
draw(torch.nn.functional.silu)
A hand-written equivalent:
def my_silu(x):
    # silu(x) = x * sigmoid(x)
    t = 1 / (1 + torch.exp(-x))
    return x*t

Reference: torch.nn.SiLU
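A short numerical check that the hand-written `my_silu` matches `torch.nn.functional.silu`, and of the point made above that SiLU lets small negative values through: its minimum is about -0.278, near x ≈ -1.278, while ReLU is exactly 0 for every negative input.

```python
import torch
import torch.nn.functional as F

def my_silu(x):
    # silu(x) = x * sigmoid(x)
    return x * (1 / (1 + torch.exp(-x)))

x = torch.linspace(-10, 10, 20001, dtype=torch.float64)  # step 0.001
print(torch.allclose(my_silu(x), F.silu(x)))  # True

y = F.silu(x)
i = torch.argmin(y)
print(round(x[i].item(), 2), round(y[i].item(), 3))  # -1.28 -0.278
print(F.relu(torch.tensor(-1.0)).item(), F.silu(torch.tensor(-1.0)).item())  # 0.0 and about -0.269
```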
11. Loop Over All Layers
Every layer runs the same block; repeat it for all n_layers = 32 layers:
final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):
    qkv_attention_store = []
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head // (n_heads // n_kv_heads)]  # each KV head is shared by 4 query heads
        v_layer_head = v_layer[head // (n_heads // n_kv_heads)]
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit+output_after_feedforward
Apply the final RMS Norm:
final_embedding = rms_norm(final_embedding, model["norm.weight"])
print(final_embedding.shape)  # torch.Size([17, 4096])
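12. Decode the Output Token
Decode from the embedding of the last token, final_embedding[-1]. The output weight has shape [128256, 4096]: it maps the 4096-dim vector to one logit for each of the 128256 tokens in the vocabulary: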
print(model["output.weight"].shape) # torch.Size([128256, 4096])
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
print(logits.shape) # torch.Size([128256])
next_token = torch.argmax(logits, dim=-1)
print(next_token)  # tensor(2983)
output_v = tokenizer.decode([next_token.item()])
print(output_v)  # '42'
Full input and output:
"the answer to the ultimate question of life, the universe, and everything is "
"42"
13. Full Source Code
Adapted from llama3-from-scratch; the full source:
import json
import os
from pathlib import Path
import tiktoken
import torch
from tiktoken.load import load_tiktoken_bpe
from tqdm import tqdm
os.environ['CUDA_VISIBLE_DEVICES'] = "4"  # only GPU 4 is visible
def infer_for_llama3(prompt, model_path, tokenizer_path, config_path):
    # model_path = "Meta-Llama-3-8B/original/consolidated.00.pth"
    # tokenizer_path = "Meta-Llama-3-8B/original/tokenizer.model"
    # config_path = "Meta-Llama-3-8B/original/params.json"
    special_tokens = [
                         "<|begin_of_text|>",
                         "<|end_of_text|>",
                         "<|reserved_special_token_0|>",
                         "<|reserved_special_token_1|>",
                         "<|reserved_special_token_2|>",
                         "<|reserved_special_token_3|>",
                         "<|start_header_id|>",
                         "<|end_header_id|>",
                         "<|reserved_special_token_4|>",
                         "<|eot_id|>",  # end of turn
                     ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
    mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
    tokenizer = tiktoken.Encoding(
        name=Path(tokenizer_path).name,
        pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
        mergeable_ranks=mergeable_ranks,
        special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
    )
    model = torch.load(model_path)
    with open(config_path, "r") as f:
        config = json.load(f)
    dim = config["dim"]
    n_layers = config["n_layers"]
    n_heads = config["n_heads"]
    n_kv_heads = config["n_kv_heads"]
    vocab_size = config["vocab_size"]
    multiple_of = config["multiple_of"]
    ffn_dim_multiplier = config["ffn_dim_multiplier"]
    norm_eps = config["norm_eps"]
    rope_theta = torch.tensor(config["rope_theta"])
    tokens = [128000] + tokenizer.encode(prompt)
    tokens = torch.tensor(tokens)
    prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
    embedding_layer = torch.nn.Embedding(vocab_size, dim)
    embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
    token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
    def rms_norm(tensor, norm_weights):
        return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
    final_embedding = token_embeddings_unnormalized
    n_tokens = len(tokens)
    zero_to_one_split_into_64_parts = torch.tensor(range(64)) / 64
    freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
    freqs_for_each_token = torch.outer(torch.arange(n_tokens), freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
    kv_scale = n_heads // n_kv_heads
    for layer in tqdm(range(n_layers), "layers"):
        qkv_attention_store = []
        layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
        q_layer = model[f"layers.{layer}.attention.wq.weight"]
        q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
        k_layer = model[f"layers.{layer}.attention.wk.weight"]
        k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
        v_layer = model[f"layers.{layer}.attention.wv.weight"]
        v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
        for head in range(n_heads):
            q_layer_head = q_layer[head]
            k_layer_head = k_layer[head // kv_scale]
            v_layer_head = v_layer[head // kv_scale]
            q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
            k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
            v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
            q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
            q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
            q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
            q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
            k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
            k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
            k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
            k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
            qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5
            mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            qk_per_token_after_masking = qk_per_token + mask
            qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking,
                                                                                   dim=1).to(torch.bfloat16)
            qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
            qkv_attention_store.append(qkv_attention)
        stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
        w_layer = model[f"layers.{layer}.attention.wo.weight"]
        embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
        embedding_after_edit = final_embedding + embedding_delta
        embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
        w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
        w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
        w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
        output_after_feedforward = torch.matmul(
            torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(
                embedding_after_edit_normalized, w3.T), w2.T)
        final_embedding = embedding_after_edit + output_after_feedforward
    final_embedding = rms_norm(final_embedding, model["norm.weight"])
    logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
    next_token = torch.argmax(logits, dim=-1)
    print(f"[Info] next_token: {next_token}")
    word = tokenizer.decode([next_token.item()])
    return word
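infer_for_llama3 predicts a single token. A common way to produce more is greedy decoding: append the argmax token and run the model again on the longer sequence. A sketch, where `next_token_logits` is a hypothetical callable standing in for the 32-layer forward pass (token tensor in, logits of the last position out) and 128001 (`<|end_of_text|>`) is assumed as the stop token:

```python
import torch

def greedy_generate(next_token_logits, tokens, max_new_tokens, stop_id=128001):
    """Greedy decoding; next_token_logits(tokens) -> [vocab_size] logits (hypothetical)."""
    tokens = list(tokens)
    for _ in range(max_new_tokens):
        logits = next_token_logits(torch.tensor(tokens))
        next_id = int(torch.argmax(logits, dim=-1))
        tokens.append(next_id)
        if next_id == stop_id:
            break
    return tokens

# Toy stand-in model: always predicts (last token + 1) in a 10-token vocabulary
def toy_logits(tokens):
    logits = torch.zeros(10)
    logits[(tokens[-1].item() + 1) % 10] = 1.0
    return logits

print(greedy_generate(toy_logits, [3], max_new_tokens=4, stop_id=6))  # [3, 4, 5, 6]
```

Without a KV cache, each step recomputes attention over the whole prefix, which is simple but quadratic in the generated length.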
That’s all!