Contents
- 1 Implementing the Beam Search Decoding Algorithm
- 2 Implementing Beam Search Decoding with KV Cache
- 3 On the use_cache Parameter When Using KV Cache
1 Implementing the Beam Search Decoding Algorithm
Below is a beam search decoding algorithm implemented in PyTorch.
A few small details:
- Beam search can take a `length_penalty`; `model.generate` also has this parameter. The penalty term is applied by dividing the generation score by it.
- Whenever probabilities would have to be multiplied together, avoid the multiplication and sum log probabilities (log p) instead.
- The implementation should handle candidate sequences that stop early because of the EOS token; they need to be moved out of the active beam and stored separately.
- After `log_softmax`, the log probability is negative, and the longer the sequence, the smaller (more negative) its cumulative log prob, i.e. the larger its negative log prob. The penalty should therefore be applied as `log_prob / len(seq) ** penalty`, so a long sequence gets divided by a larger factor. This is reasonable because the negative log prob of a short sequence is naturally smaller than that of a long one, and the division puts the two on a comparable scale (a small numeric sketch follows right after this list).
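A tiny numeric sketch of this length normalization (the per-token probabilities below are made up purely for illustration):
import math

# Hypothetical per-token probabilities of two candidate sequences
short_probs = [0.1, 0.2]                    # 2 generated tokens
long_probs = [0.2, 0.3, 0.25, 0.3, 0.35]    # 5 generated tokens

short_logprob = sum(math.log(p) for p in short_probs)   # ~ -3.91
long_logprob = sum(math.log(p) for p in long_probs)     # ~ -6.45

length_penalty = 1.0
# Without normalization the short sequence always wins (-3.91 > -6.45);
# dividing by len ** length_penalty puts both on a per-token scale.
print(short_logprob / len(short_probs) ** length_penalty)  # ~ -1.96
print(long_logprob / len(long_probs) ** length_penalty)    # ~ -1.29, so the long sequence now ranks higher

With that in mind, here is the full implementation: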
import torch
import torch.nn.functional as F
from typing import List, Tuple
def beam_search(
model: torch.nn.Module,
initial_input: torch.Tensor,
beam_width: int,
max_length: int,
vocab_size: int,
device: torch.device,
length_penalty: float = 1.0,
early_stopping: bool = True
) -> Tuple[List[List[int]], List[float]]:
"""
Beam search 解码算法实现
参数:
model: 用于预测下一个token的模型
initial_input: 初始输入张量 (shape: [1, seq_len])
beam_width: beam大小
max_length: 生成序列的最大长度
vocab_size: 词汇表大小
device: 使用的设备 (cpu/cuda)
length_penalty: 长度惩罚系数 (α), 用于调整对长序列的偏好
early_stopping: 是否在所有beam序列达到EOS时提前停止
返回:
Tuple[List[List[int]], List[float]]: (生成的序列列表, 对应的分数列表)
"""
    # Initialize the beam
    sequences = [initial_input.tolist()[0]]  # initial sequence (a flat list of token ids)
    scores = [0.0]  # initial score (log probability)
    # Store completed beams (sequences that have already produced EOS)
completed_sequences = []
completed_scores = []
    for step in range(max_length):
        # If every beam is already finished, stop early
        if early_stopping and len(sequences) == 0:
            break
        # Prepare candidates for the current step
        candidates = []
        for i, seq in enumerate(sequences):
            # Skip sequences that have already finished
            if len(seq) > 0 and seq[-1] == 2:  # assume 2 is the EOS token id
completed_sequences.append(seq)
completed_scores.append(scores[i])
continue
            # Convert the sequence into a tensor
            input_tensor = torch.tensor([seq], dtype=torch.long).to(device)
            # Get the model prediction
            with torch.no_grad():
                outputs = model(input_tensor)
                next_token_logits = outputs[:, -1, :]  # logits of the last token
            next_token_probs = F.log_softmax(next_token_logits, dim=-1)
            # Get the top-k tokens and their log probabilities
topk_probs, topk_tokens = torch.topk(next_token_probs, beam_width, dim=-1)
topk_probs = topk_probs.squeeze(0)
topk_tokens = topk_tokens.squeeze(0)
            # Build candidate sequences
for j in range(beam_width):
new_seq = seq.copy()
new_seq.append(topk_tokens[j].item())
new_score = scores[i] + topk_probs[j].item()
candidates.append((new_seq, new_score))
        # If there are no candidates, stop early
        if not candidates:
            break
        # Keep the top-k candidates (prune), ranked by length-normalized score
candidates.sort(key=lambda x: x[1] / (len(x[0]) ** length_penalty), reverse=True)
sequences, scores = zip(*candidates[:beam_width])
sequences = list(sequences)
scores = list(scores)
    # Add any remaining unfinished sequences to the completed list
    completed_sequences.extend(sequences)
    completed_scores.extend(scores)
    # Sort the completed sequences by length-normalized score
    sorted_pairs = sorted(
        zip(completed_sequences, completed_scores),
        key=lambda x: x[1] / (len(x[0]) ** length_penalty),
        reverse=True
    )
    sorted_sequences = [pair[0] for pair in sorted_pairs]
    sorted_scores = [pair[1] for pair in sorted_pairs]
    return sorted_sequences, sorted_scores
- Model requirements:
  - The model should accept input of shape `[batch_size, seq_len]`
  - The model should output logits of shape `[batch_size, seq_len, vocab_size]`
- Parameter notes:
  - `initial_input`: the initial input sequence (e.g. the start token)
  - `beam_width`: controls the search width; larger values may give better results at a higher compute cost
  - `length_penalty`: controls the preference for generation length (alpha > 1 favors longer sequences, alpha < 1 favors shorter ones)
  - `early_stopping`: when True, stop once every beam has produced the EOS token
- Return value:
  - A list of sequences sorted by score, together with the corresponding list of scores
How do you call it?
model = ...  # your PyTorch model
initial_input = torch.tensor([[1]])  # assume 1 is the start token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
sequences, scores = beam_search(
model=model,
initial_input=initial_input,
beam_width=5,
max_length=50,
    vocab_size=10000,  # your vocabulary size
device=device,
length_penalty=1.2
)
print("Top sequence:", sequences[0])
print("Score:", scores[0])
2 Implementing Beam Search Decoding with KV Cache
This is the key part. I previously wrote a k-step greedy decoder in a recursive style, but I could not figure out how to add a KV cache inside the recursion; in fact no way of writing it felt right.
Seen this way, a plain loop is all that is needed.
In practice, the `model.generate` method in transformers already applies KV-cache optimization to greedy, beam, and the other sampling strategies:
# First call (process the initial input)
outputs = model(input_ids, use_cache=True)
logits = outputs.logits
past_key_values = outputs.past_key_values  # cache the KV
# Subsequent steps (during generation)
next_token_input = torch.tensor([[new_token]])
outputs = model(next_token_input, past_key_values=past_key_values, use_cache=True)
new_past_key_values = outputs.past_key_values  # updated cache
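As a side note (this inspection snippet is my own addition), you can peek at what actually gets cached. In older transformers versions `past_key_values` is a tuple of per-layer `(key, value)` pairs; newer versions return a `Cache` object such as `DynamicCache`, which still supports the same indexing:
# Rough shape check of the cached keys/values
print(len(past_key_values))        # one entry per transformer layer
k, v = past_key_values[0]          # (key, value) of the first layer
print(k.shape)                     # roughly [batch, num_heads, seq_len_so_far, head_dim]
print(v.shape)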
To keep an independent KV cache for each candidate sequence in beam search, the standard implementation needs a few modifications. The KV cache can significantly speed up autoregressive inference because it avoids recomputing the keys and values of previous tokens.
Below is the beam search implementation with a KV cache:
- Because we already have the cache, at each generation step we only need to feed the last token of the current sequence into the model, which is a really neat trick.
import torch
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict, Any
def beam_search_with_kv_cache(
model: torch.nn.Module,
initial_input: torch.Tensor,
beam_width: int,
max_length: int,
vocab_size: int,
device: torch.device,
length_penalty: float = 1.0,
early_stopping: bool = True,
use_kv_cache: bool = True
) -> Tuple[List[List[int]], List[float]]:
"""
带KV Cache的Beam Search解码算法
参数:
model: 用于预测下一个token的模型
initial_input: 初始输入张量 (shape: [1, seq_len])
beam_width: beam大小
max_length: 生成序列的最大长度
vocab_size: 词汇表大小
device: 使用的设备 (cpu/cuda)
length_penalty: 长度惩罚系数
early_stopping: 是否在所有beam序列达到EOS时提前停止
use_kv_cache: 是否使用KV Cache加速
返回:
Tuple[List[List[int]], List[float]]: (生成的序列列表, 对应的分数列表)
"""
    # Initialize the beam
sequences = [[initial_input.tolist()[0]]]
scores = [0.0]
    # KV caches (one per candidate sequence)
    kv_caches = [None]  # the initial cache is None
    # Store completed beams
completed_sequences = []
completed_scores = []
for step in range(max_length):
if early_stopping and len(sequences) == 0:
break
candidates = []
new_kv_caches = []
for i, (seq, score, kv_cache) in enumerate(zip(sequences, scores, kv_caches)):
            # Skip sequences that are already finished
            if len(seq) > 0 and seq[-1] == 2:  # assume 2 is the EOS token id
                completed_sequences.append(seq)
                completed_scores.append(score)
                continue
            # Prepare the input (only the last token, since everything before it is already cached)
            input_tensor = torch.tensor([[seq[-1]]], dtype=torch.long).to(device)
            # Forward pass, using or updating the KV cache
with torch.no_grad():
if use_kv_cache:
if kv_cache is None:
                        # First call: process the whole initial sequence
full_input = torch.tensor([seq], dtype=torch.long).to(device)
outputs = model(full_input, use_cache=True)
next_token_logits = outputs.logits[:, -1, :]
new_kv_cache = outputs.past_key_values
else:
                        # Subsequent calls: reuse the KV cache
outputs = model(input_tensor, past_key_values=kv_cache, use_cache=True)
next_token_logits = outputs.logits[:, -1, :]
new_kv_cache = outputs.past_key_values
else:
                    # Path without the KV cache
full_input = torch.tensor([seq], dtype=torch.long).to(device)
outputs = model(full_input, use_cache=False)
next_token_logits = outputs.logits[:, -1, :]
new_kv_cache = None
next_token_probs = F.log_softmax(next_token_logits, dim=-1)
            # Get the top-k tokens
topk_probs, topk_tokens = torch.topk(next_token_probs, beam_width, dim=-1)
topk_probs = topk_probs.squeeze(0)
topk_tokens = topk_tokens.squeeze(0)
            # Build candidate sequences
for j in range(beam_width):
new_seq = seq.copy()
new_seq.append(topk_tokens[j].item())
new_score = score + topk_probs[j].item()
candidates.append((new_seq, new_score, new_kv_cache))
        # If there are no candidates, stop early
        if not candidates:
            break
        # Keep the top-k candidates
        candidates.sort(key=lambda x: x[1] / (len(x[0]) ** length_penalty), reverse=True)
        top_candidates = candidates[:beam_width]
        # Unpack the candidates
sequences = [cand[0] for cand in top_candidates]
scores = [cand[1] for cand in top_candidates]
kv_caches = [cand[2] for cand in top_candidates]
    # Add any remaining unfinished sequences
    completed_sequences.extend(sequences)
    completed_scores.extend(scores)
    # Sort the completed sequences by length-normalized score
sorted_pairs = sorted(
zip(completed_sequences, completed_scores),
key=lambda x: x[1] / (len(x[0]) ** length_penalty),
reverse=True
)
sorted_sequences = [pair[0] for pair in sorted_pairs]
sorted_scores = [pair[1] for pair in sorted_pairs]
return sorted_sequences, sorted_scores
Notes on the key changes
- KV cache bookkeeping:
  - Each candidate sequence now maintains its own KV cache
  - The cache starts out as None
  - The first time a sequence is processed, a full KV cache is built
  - Subsequent steps only process the last token and update the cache
- Model interface requirements:
  - The model must accept a `past_key_values` argument carrying the previous KV cache
  - The model must return `past_key_values` as part of its output
  - Typical usage (e.g. HuggingFace transformers):
    outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
    next_token_logits = outputs.logits
    past_key_values = outputs.past_key_values
- Performance:
  - With the KV cache, each forward pass only processes the last token (neat, but it requires setting `use_cache=True`)
  - Recomputing the keys and values of previous tokens is avoided
  - For long sequences this gives a substantial speedup
A simple call example:
# Assume we have a model that supports a KV cache
model = ...  # e.g. a HuggingFace GPT-2 model
initial_input = torch.tensor([[model.config.bos_token_id]])  # start token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Run beam search with the KV cache
sequences, scores = beam_search_with_kv_cache(
model=model,
initial_input=initial_input,
beam_width=5,
max_length=50,
vocab_size=model.config.vocab_size,
device=device,
length_penalty=1.2,
    use_kv_cache=True  # enable the KV cache
)
print("Top sequence:", sequences[0])
print("Score:", scores[0])
A supplementary note:
In this part:
# Forward pass, using or updating the KV cache
with torch.no_grad():
    if use_kv_cache:
        if kv_cache is None:
            # First call: process the whole initial sequence
            full_input = torch.tensor([seq], dtype=torch.long).to(device)
            outputs = model(full_input, use_cache=True)
the `full_input` built here has size `[1, 1, seqlen]`, when it should really be `[1, seqlen]`, so the fix is either
# Forward pass, using or updating the KV cache
with torch.no_grad():
    if use_kv_cache:
        if kv_cache is None:
            # First call: process the whole initial sequence
            full_input = torch.tensor(seq, dtype=torch.long).to(device)
            outputs = model(full_input, use_cache=True)
or:
# Forward pass, using or updating the KV cache
with torch.no_grad():
    if use_kv_cache:
        if kv_cache is None:
            # First call: process the whole initial sequence
            full_input = torch.tensor([seq], dtype=torch.long).to(device)
            outputs = model(full_input.squeeze(0), use_cache=True)
With either change, the test should run through without problems.
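An extra observation of my own (not in the original note): the spurious dimension actually comes from the beam initialization, where `sequences = [[initial_input.tolist()[0]]]` nests the token list one level too deep. Initializing each beam entry as a flat list of token ids removes the problem at the source, roughly:
# Hypothetical root-cause fix: keep each beam entry as a flat list of token ids,
# so that torch.tensor([seq]) is already [1, seq_len] on the first call
sequences = [initial_input[0].tolist()]  # e.g. [[1]] when initial_input = torch.tensor([[1]])
scores = [0.0]
kv_caches = [None]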
3 On the use_cache Parameter When Using KV Cache
Take, for example, a greedy decoding routine I hand-wrote earlier:
# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cn
import torch
import logging
from copy import deepcopy
from functools import wraps
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
# Standard greedy decode
# @param model: Huggingface model object
# @param tokenizer: Huggingface tokenizer Object
# @param prompt: Str
# @param max_length: Int, the number of tokens to be generated
# @param device: Str, e.g. "cuda" or "cpu"
# @param kv_cache: Boolean, whether to use KV-cache to accelerate, if True then large memory will be consumed
# @return generated_text: Str
# @return generated_token_prob: List[Tuple(Int, Str, Float)], `len(generated_id_prob)` is `max_length`, indicating the generated probability of each token
# @return generated_logits: Tuple[FloatTensor(1, n_vocab)], `len(generated_logits)` is `max_length`, indicating the logits when each token is generated
def greedy_decode(model,
tokenizer,
prompt,
max_length,
device = "cuda",
kv_cache = True,
):
inputs = tokenizer.encode(prompt, return_tensors="pt").to(device) # Str => Long(1, n_tokens)
past_key_values = None
generated_token_probs = list()
generated_logits = list()
model.gradient_checkpointing_enable()
for i in range(max_length):
logging.info(f"Round {i}: {past_key_values.key_cache[0].size() if past_key_values is not None else None}")
outputs = model(inputs, past_key_values=past_key_values)
logits = outputs.logits # Float(1, n_tokens + i + 1, n_vocab), where `n_vocab` is 151936 in DeepSeek-R1-Distill-Qwen
if kv_cache:
past_key_values = outputs.past_key_values # Dictlike[key_cache: Float(1, 2, X, hidden_size), value_cache: Float(1, 2, X, hidden_size)], where X = (i + 1) * (n_tokens + i / 2)
next_token_probs = F.softmax(logits[:, -1, :], dim=-1) # Float(1, n_tokens + i + 1, n_vocab) => Float(1, n_vocab)
next_token_id = torch.argmax(next_token_probs, dim=-1) # Float(1, n_vocab) => Long(1, )
next_token_prob = next_token_probs[0, next_token_id].item() # Float(1, n_vocab) => Float()
next_token = tokenizer.decode(next_token_id[0].item(), skip_special_tokens=False) # Long(1, ) => Str
inputs = torch.cat([inputs, next_token_id.unsqueeze(-1)], dim=-1) # Long(1, n_tokens + i) => Long(1, n_tokens + i + 1)
generated_token_probs.append((next_token_id.item(), next_token, next_token_prob))
generated_logits.append(logits[:, -1, :])
generated_text = tokenizer.decode(
token_ids = inputs[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
) # Long(1, n_tokens + max_length) => Str
return generated_text, generated_token_probs, tuple(generated_logits)
In fact, apart from the very first call, every subsequent step only needs the last token as input; there is no need to feed the whole ever-growing `input` back into the `model`. (With HuggingFace models, any ids passed alongside `past_key_values` are treated as newly appended tokens, which is also why the cache in the version above grows far faster than one token per step.)
# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cn
import torch
import logging
from copy import deepcopy
from functools import wraps
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
# Standard greedy decode
# @param model: Huggingface model object
# @param tokenizer: Huggingface tokenizer Object
# @param prompt: Str
# @param max_length: Int, the number of tokens to be generated
# @param device: Str, e.g. "cuda" or "cpu"
# @param kv_cache: Boolean, whether to use KV-cache to accelerate, if True then large memory will be consumed
# @return generated_text: Str
# @return generated_token_prob: List[Tuple(Int, Str, Float)], `len(generated_id_prob)` is `max_length`, indicating the generated probability of each token
# @return generated_logits: Tuple[FloatTensor(1, n_vocab)], `len(generated_logits)` is `max_length`, indicating the logits when each token is generated
def greedy_decode(model,
tokenizer,
prompt,
max_length,
device = "cuda",
kv_cache = True,
):
inputs = tokenizer.encode(prompt, return_tensors="pt").to(device) # Str => Long(1, n_tokens)
past_key_values = None
generated_token_probs = list()
generated_logits = list()
model.gradient_checkpointing_enable()
for i in range(max_length):
logging.info(f"Round {i}: {past_key_values.key_cache[0].size() if past_key_values is not None else None}")
if kv_cache:
if i == 0:
outputs = model(inputs, past_key_values=past_key_values)
else:
outputs = model(inputs[:, -1].unsqueeze(0), past_key_values=past_key_values, use_cache=True)
else:
outputs = model(inputs, past_key_values=None)
        logits = outputs.logits # Float(1, n_new_tokens, n_vocab): the full prompt length on the first round, 1 afterwards when kv_cache is enabled (`n_vocab` is 151936 in DeepSeek-R1-Distill-Qwen)
if kv_cache:
            past_key_values = outputs.past_key_values # Dictlike[key_cache: Float(1, 2, X, hidden_size), value_cache: Float(1, 2, X, hidden_size)], where X = n_tokens + i since only one new token is appended per round
next_token_probs = F.softmax(logits[:, -1, :], dim=-1) # Float(1, n_tokens + i + 1, n_vocab) => Float(1, n_vocab)
next_token_id = torch.argmax(next_token_probs, dim=-1) # Float(1, n_vocab) => Long(1, )
next_token_prob = next_token_probs[0, next_token_id].item() # Float(1, n_vocab) => Float()
next_token = tokenizer.decode(next_token_id[0].item(), skip_special_tokens=False) # Long(1, ) => Str
inputs = torch.cat([inputs, next_token_id.unsqueeze(-1)], dim=-1) # Long(1, n_tokens + i) => Long(1, n_tokens + i + 1)
generated_token_probs.append((next_token_id.item(), next_token, next_token_prob))
generated_logits.append(logits[:, -1, :])
generated_text = tokenizer.decode(
token_ids = inputs[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
) # Long(1, n_tokens + max_length) => Str
return generated_text, generated_token_probs, tuple(generated_logits)
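For completeness, a usage sketch of my own (the model id below is just an example; any HuggingFace causal LM with KV-cache support should behave the same way):
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B").to("cuda")
model.eval()
generated_text, generated_token_probs, generated_logits = greedy_decode(
    model,
    tokenizer,
    prompt="The quick brown fox",
    max_length=32,
    device="cuda",
    kv_cache=True,
)
print(generated_text)
print(generated_token_probs[:5])  # (token_id, token_str, prob) of the first generated tokens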
This really does help and speeds up inference a lot. The principle is simple: the KV cache plus the last token is all that is needed to compute the attention for the next step (which is essentially the KV cache generated for the next round). Incidentally, I then noticed DeepSeek made a mistake when generating an image link, a rare chance to catch DeepSeek slipping up: