《Adaptive Layer-skipping in Pre-trained LLMs》- 论文笔记

news2025/5/19 18:30:42

作者:Xuan Luo, Weizhi Wang, Xifeng Yan
Department of Computer Science, UC Santa Barbara
xuan_luo@ucsb.edu, weizhiwang@ucsb.edu, xyan@cs.ucsb.edu

1. 引言与动机

1.1 背景

  • LLM 的成功与挑战:
    • 大型语言模型 (LLMs) 在翻译、代码生成、推理等任务上取得巨大成功。
    • 核心问题: 当前LLM在生成每个token时,通常需要通过所有Transformer层进行完整的前向传播。
  • 计算资源浪费:
    • 这种统一的计算分配 (Uniform Allocation) 与直觉相悖:简单的任务/token(如重复词、常见短语)理应需要更少的计算资源,而复杂的任务/token(如推理、生成新信息)需要更多。
    • 导致计算效率低下, 过拟合等。

1.2 研究问题与贡献

  • 现有方法的局限:
    • 已有的层跳过 (Layer-skipping) 或早退 (Early-Exit) 方法虽然能减少计算量,但大多忽略了一个根本问题:
      • “不同 Token 的生成,其计算需求是如何变化的?” (How do computational demands vary across the generation of different tokens?)
  • 本文动机:
    • 深入探究Token生成过程中的计算需求异质性。
    • 提出一种能在预训练LLM上实现自适应层跳过的方法,且不修改原始模型参数。
  • 主要贡献:
    • 提出 FlexiDepth: 一个动态调整Transformer层数的即插即用 (plug-in) 方法。
    • 在 Llama-3-8B 上实现显著层跳过(跳过8/32层)同时保持100%基准性能。
    • 揭示了LLM计算需求与Token类型显著相关(如重复Token vs. 计算密集型Token)。
    • 开源了 FlexiDepth 模型 和 FlexiPatterns 数据集 (记录层分配模式)。

2. 相关工作

  • 层跳过/效率提升方法分类:
    • 基于统计信息跳过层: 利用层输入输出差异等信息判断并跳过不重要层 (如 ShortGPT [26])。
    • 早退 (Early-Exit): 在中间层设置判断点,若置信度足够高则直接输出,跳过后续所有层 (如 [37, 18, 34])。
    • 从头训练动态深度模型: 在训练时就加入路由机制,动态决定每层是否执行 (如 MoD [31], SkipLayer [41], Duo-LLM [2])。缺点:需要大量计算资源重新训练。
    • Encoder中的条件计算: 如 PoWER-BERT [11], CoDA [21], COLT5 [1] 等,在Encoder中根据token重要性/复杂度分配不同计算路径。缺点:非因果性,不直接适用于Decoder-only模型。
    • 预训练模型中的跳过: MindSkip [12] 可以在预训练模型上跳过,但主要探索跳过Attention,且本文作者认为其性能或方式有别。
  • FlexiDepth 的定位:
    • 专注于Decoder-only的预训练LLM。
    • 逐层 (Layer-wise) 动态决策,而非早退。
    • 通过轻量级插件实现,冻结原始模型参数。
    • 不仅提升效率,更旨在理解和利用计算需求的变化规律。

3. FlexiDepth

3.1 整体架构

  • 核心思想: 在预训练LLM的每个(或部分,如下文所述,通常是后半部分)Transformer Decoder层,增加决策和适配机制,动态决定每个Token是完整处理还是跳过该层核心计算。

  • FlexiDepth Block (图 2):
    在这里插入图片描述

    • 输入: Hidden State (X)。
    • 两个并行路径:
      • 完整处理路径 (Full-processing Path, 图2 左):
        • Token 通过标准的 Attention 和 FFN 模块。
        • 输出 = g * Original_Layer(X) (g 为路由得分)。
      • 跳过路径 (Skipping Path, 图2 右):
        • Token 绕过 Attention 和 FFN 模块。
        • 通过一个轻量级的 Adapter 进行处理。
        • 输出 = (1-g) * Adapter(Norm(X))。
    • 核心组件 (可训练):
      • Router: 决定Token走哪条路径 (计算得分 g)。
      • Adapter: 处理走跳过路径的Token,解决表征不匹配问题。
    • 输出: 两条路径的输出加权合并。
  • 关键特性: 原始LLM的Attention和FFN参数保持冻结。只训练Router和Adapter。

3.2 Router 设计

  • 目标: 为每个输入Token x_i 计算一个门控分数 g_i ∈ (0, 1),表示其通过完整路径的倾向。

  • 输入: 经过 RMSNorm 标准化的 Hidden State z = Norm(X)。

  • Router 结构 (Eq 2):
    在这里插入图片描述

    • 为什么不用简单的线性层? (消融实验会证明) 简单的线性层不足以捕捉路由决策所需的复杂模式,尤其是在冻结主干模型时。Bottleneck结构在参数高效的同时提供了足够的表达能力。
  • 输出: Gating Score G = σ(Router(z)) (Eq 1),其中 σ 是 Sigmoid 函数。

  • 路由决策: 使用预定义阈值 τ。若 g_i > τ,走完整路径;若 g_i <= τ,走跳过路径。

3.3 Attention Skipping 与 KV Cache

  • 问题: 如果完全跳过Attention层,那么该Token对应的Key (K) 和 Value (V) 就不会被计算。对于自回归模型,后续的Token将无法Attention到这个被跳过的Token,导致上下文信息丢失,严重影响生成质量 (如图3 中间的 ‘No KV Cache’ 所示)。
    在这里插入图片描述

  • FlexiDepth 的解决方案 (图3 右侧 ‘KV Cache’):

    • 对于决定跳过Attention模块的Token (即 g_i <= τ):
      • 仍然计算其对应的 Key (K) 和 Value (V) 并存入KV Cache。
      • 跳过 Query (Q) 的计算以及后续的点积注意力计算 (Scaled Dot-Product Attention)。
  • 好处:

    • 保留了完整的上下文信息,确保后续Token可以Attention到所有历史Token。
    • 依然节省了Query计算和主要的Attention矩阵计算开销。
    • 这是维护自回归生成完整性的关键设计。

3.4 FFN Skipping 与 Adapter

  • 问题: FFN层包含非线性变换,直接跳过FFN会导致:
    • 表征不匹配 (Representation Mismatch): 经过FFN处理的Token和直接跳过的Token处于不同的表示空间。
    • 性能显著下降: (消融实验会证明) 简单跳过FFN效果很差。
  • FlexiDepth 的解决方案 (图2 右侧):
    • 引入一个轻量级 Adapter。
    • 结构: 与原始FFN类似 (MLP结构),但中间维度显著减小 (例如,论文中提到减少16倍)。
    • 功能: 对跳过FFN的Token进行变换,使其表示与经过完整FFN处理的Token对齐 (align)。
  • 好处:
    • 在计算开销很小的情况下,有效弥合了跳过FFN带来的表征差异。
    • 是保证性能的另一个关键组件。

3.5 损失函数

  • 目标: 平衡 生成质量 和 计算效率 (层跳过率)。
    在这里插入图片描述
    在这里插入图片描述

  • 总损失 (Total Loss, Eq 4):

    • L_lm: 标准的下一个Token预测损失 (Language Modeling Loss)。
    • L_skip: 层跳过损失 (Layer-skipping Loss)。
    • α: 平衡系数,控制层跳过损失的权重。
  • 层跳过损失 (L_skip, Eq 3): L_skip = (1/T) * Σ_t (Σ_l g_tl)2 (原文公式似乎有误,应该是类似惩罚“使用层数”的平方和,更可能是 (1/T) * Σ_t Σ_l (g_tl)2 或者类似含义,需确认。但核心思想是惩罚使用的层数。)

    • 惩罚每个Token使用的门控分数 (g) 的总和的平方 (或者各层g的平方和)。
    • 为什么用平方? 对使用更多层的Token施加更大的惩罚,鼓励模型跳过层;同时避免模型陷入全跳或全不跳的极端。有助于稳定训练。
  • 训练细节 (Section 3.1):

    • 只在模型的后半部分层 (如 Llama-3-8B 的后16层) 应用FlexiDepth。原因:先前研究表明跳过早期层对性能影响更大。
    • Router的Bottleneck维度 (dr = d/16),Adapter的中间层维度缩小16倍。
    • 使用 Tulu-v2 数据集训练,AdamW优化器

4. 实验设置

  • 基础模型: Llama-3-8B-Instruct (32层)。
  • 评估基准 (Benchmarks):
    • 单Token生成: MMLU, HellaSwag, Winogrande (考察知识、常识、推理)。
    • 多Token生成: GSM8K (数学推理), HumanEval (代码生成), CoQA (对话式问答)。区分这两类很重要,因为性能差异在多Token任务上更明显。
  • 评估指标 (Metrics): Accuracy (acc), Normalized Accuracy (acc_norm), Exact Match (EM), Pass@1, F1 score (根据不同任务选择)。
  • 对比基线 (Baselines):
    • Vanilla (原始 Llama-3-8B-Instruct)。
    • LayerSkip [9] (早退最后k层 + 推测解码)。
    • ShortGPT [26] (基于输入输出差异剪枝k层)。
    • LaCo [39] (层合并,减少k层)。
    • MindSkip [12] (探索Attention/FFN/Layer跳过,论文采用其Layer Skipping设置)。
  • 公平比较: 所有基线方法都应用于 Llama-3-8B,并配置为跳过相同数量 (k=4 或 k=8) 的层进行比较 (通过调整FlexiDepth的α实现近似跳过层数)。

5. 主要结果与分析

5.1 基准性能比较

在这里插入图片描述

  • 核心发现: FlexiDepth 在跳过层数(k=4, k=8)的情况下,显著优于所有基线方法,尤其是在多Token生成任务 (GSM8K, HumanEval) 上。
  • Skip 8 Layers:
    • 基线方法在 GSM8K 和 HumanEval 上性能几乎崩溃 (接近0)。
    • FlexiDepth 保持了接近100% (100.7%) 的平均性能。
  • 性能甚至略有提升?
    • 在某些任务上,FlexiDepth 性能甚至略微超过了原始模型 (Retain % > 100%)。
    • 假设: 作者推测这可能源于自适应跳过带来的隐式正则化 (implicit regularization) 效果,跳过了不信息或噪声参数。与完全微调的模型对比 (allenai/llama-3-tulu-2-8b),FlexiDepth在GSM8K/HumanEval上表现更好,说明提升不完全来自训练数据。
  • 结论: FlexiDepth 可以在大幅减少计算(跳过8层)的同时,几乎无损甚至略微提升模型在各种任务上的性能,尤其擅长处理需要复杂推理的长序列生成任务。

5.2 跨模型尺寸表现

在这里插入图片描述

  • 实验: 在不同尺寸的指令微调模型上应用FlexiDepth (Llama-2-13B, Llama-3-8B, Qwen-2.5-3B)。
  • 发现:
    • 模型越大,跳过的层数越多。
      • Llama-2-13B: 平均跳过约 6-7 层。
      • Llama-3-8B: 平均跳过约 6 层 (这里跳过层数比Table 1的8层少,可能是α取值不同)。
      • Qwen-2.5-3B: 平均只跳过 1-2 层。
  • 解释:
    • 这表明更大的模型固有地拥有更高的冗余度 (redundancy)。
    • 因此,自适应层跳过方法在更大规模的LLM上具有更大的潜力。

5.3 层分配模式

在这里插入图片描述

  • 主要发现:
    • 任务依赖性
      在这里插入图片描述

      • Summarization (总结): 平均使用更多层 (e.g., 28.65层)。需要深入理解和抽象。
      • Extractive QA (抽取式问答) / Copying (复制): 平均使用较少层 (e.g., 复制 21.95层)。依赖检索和直接输出。
      • Continuation (续写): 使用最多层 (e.g., 30.27层)。需要创造性和上下文连贯性。
    • Token 类型依赖性
      在这里插入图片描述

      • 重复/简单复制: 如重复数字列表、公式左侧的数字,使用较少层。
      • 计算/推理/高不确定性: 如数学运算的结果、总结或续写中的新信息,需要更多层。
  • 结论: LLM的计算需求确实不是均匀的,而是与任务复杂度和当前Token的功能(是复制、计算还是生成新信息)密切相关。FlexiDepth的自适应机制能够捕捉并利用这种模式。

6. 消融实验

在这里插入图片描述

  • 目的: 验证FlexiDepth中各个设计选择的必要性。基于Llama-3-8B进行。
  • 实验设置:
    • Linear Router: 将 MLP Router 替换为简单的线性层 + Sigmoid。
    • No KV Cache: 跳过Attention时,不计算和存储 K, V。
    • No Adapter: 跳过FFN时,移除Adapter。
  • 结果:
    • Linear Router: 性能显著下降 (Retain 68.7%),尤其在 GSM8K (0.657 -> 0.131)。说明复杂路由机制是必要的。
    • No KV Cache: 性能大幅下降 (Retain 84.3%)。证明为跳过Token保留KV Cache对于维护上下文至关重要。
    • No Adapter: 性能灾难性下降 (Retain 28.1%)。凸显Adapter在对齐跳过FFN的Token表征方面的关键作用。
  • 结论: FlexiDepth 中的 Router、KV Cache 保留策略、以及 FFN Adapter 都是不可或缺的设计,共同保证了模型在层跳过时的性能。

7. 局限性与未来工作

  • 主要局限性 (Limitation):
    • 理论FLOPs减少 vs. 实际吞吐量提升: 当前实现未能在现有GPU硬件上带来显著的推理速度提升。
    • 原因:
      • 控制流开销 (Control-flow overhead): 同一个batch内的样本可能走不同的计算路径 (一些Token跳过,一些不跳过),需要复杂的管理。
      • 不规则内存访问 (Irregular memory access): 不同的执行路径导致访存模式不规则,降低GPU并行效率。
  • 未来工作 (Future Work):
    • 硬件感知优化: 需要研究专门的优化技术来克服上述瓶颈,例如:
      • Token Grouping [30]: 将计算需求相似的Token分组处理。
      • Expert Sharding / Load Balancing [30, 15]: 在多GPU或专用硬件上更有效地分配计算负载。
    • 深入研究正则化效应: 探索自适应跳过是否真的能作为一种有效的正则化手段。
    • 将FlexiDepth应用于更广泛的模型和任务。

8. 结论

  • 核心贡献: 提出 FlexiDepth,一种在预训练LLM上实现动态自适应层跳过的方法,无需修改原始模型参数。
  • 关键成果:
    • 在保持SOTA性能(甚至略有超越)的同时,实现了显著的层跳过(如Llama-3-8B跳过8/32层)。
    • 显著优于现有兼容预训练模型的层跳过方法,尤其在复杂生成任务上。
  • 重要洞见:
    • 首次系统地揭示并量化了LLM中Token生成的计算需求异质性,发现其与任务类型和Token功能强相关。
    • 验证了更大模型具有更高冗余度,为自适应方法提供了更大空间。
  • 价值: 提供了一种有效的方法来提升LLM效率(潜力巨大,待硬件优化),并为理解LLM内部计算动态提供了新的视角和工具 (FlexiPatterns数据集)。

9. 代码

https://huggingface.co/xuan-luo/FlexiDepth-Llama-3-8B-Instruct/blob/main/modeling_ddllama.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2337525.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序实现table样式,自带合并行合并列

微信小程序在代码编写过程好像不支持原生table的使用&#xff0c;在开发过程中偶尔又得需要拿table来展示。 1.table效果展示 1.wxml <view class"table-container"><view class"table"><view class"table-row"><view cla…

电脑的品牌和配置

我的笔记本是2020年买的&#xff0c;之前的订单找不到了&#xff0c;就知道是联想&#xff0c;不清楚具体的配置。 本文来源&#xff1a;腾讯元宝 检查系统信息&#xff08;Windows&#xff09; 这通常是 ​​联想&#xff08;Lenovo&#xff09;​​ 的型号代码。 81XV 是联想…

Redis面试——常用命令

一、String &#xff08;1&#xff09;设置值相关命令 1.1.1 SET 功能&#xff1a;设置一个键值对&#xff0c;如果键已存在则覆盖旧值语法&#xff1a; SET key value [EX seconds] [PX milliseconds] [NX|XX]EX seconds&#xff1a;设置键的过期时间为 seconds 秒 PX milli…

Swin-Transformer-UNet改进:融合Global-Local Spatial Attention (GLSA) 模块详解

目录 1.模块概述 2.swinUNet网络 3. 完整代码 1.模块概述 Global-Local Spatial Attention (GLSA) 是一种先进的注意力机制模块,专为计算机视觉任务设计,能够同时捕捉全局上下文信息和局部细节特征。 该模块通过创新的双分支结构和自适应融合机制,显著提升了特征表示能…

ubuntu 向右拖动窗口后消失了、找不到了

这是目前单显示器的设置&#xff0c;因为实际只有1个显示器&#xff0c;之前的设置如下图所示&#xff0c;有2个显示器&#xff0c;一个主显示器&#xff0c;一个23寸的显示器 ubuntu 22.04 系统 今天在操作窗口时&#xff0c;向右一滑&#xff0c;发现这个窗口再也不显示了、找…

2025最新版微软GraphRAG 2.0.0本地部署教程:基于Ollama快速构建知识图谱

一、前言 微软近期发布了知识图谱工具 GraphRAG 2.0.0&#xff0c;支持基于本地大模型&#xff08;Ollama&#xff09;快速构建知识图谱&#xff0c;显著提升了RAG&#xff08;检索增强生成&#xff09;的效果。本文手把手教你如何从零部署&#xff0c;并附踩坑记录和性能实测…

libevent服务器附带qt界面开发(附带源码)

本章是入门章节&#xff0c;讲解如何实现一个附带界面的服务器&#xff0c;后续会完善与优化 使用qt编译libevent源码演示视频qt的一些知识 1.主要功能有登录界面 2.基于libevent实现的服务器的业务功能 使用qt编译libevent 下载这个&#xff0c;其他版本也可以 主要是github上…

智能体数据分析

数据概览&#xff1a; 展示智能体的累计对话次数、累计对话用户数、对话满意度、累计曝光次数。数据分析&#xff1a; 统计对话分析、流量分析、用户分析、行为分析数据指标&#xff0c;帮助开发者完成精准的全面分析。 ps&#xff1a;数据T1更新&#xff0c;当日12点更新前一天…

STM32(M4)入门: 概述、keil5安装与模板建立(价值 3w + 的嵌入式开发指南)

前言&#xff1a;本教程内容源自信盈达教培资料&#xff0c;价值3w&#xff0c;使用的是信盈达的405开发版&#xff0c;涵盖面很广&#xff0c;流程清晰&#xff0c;学完保证能从新手入门到小高手&#xff0c;软件方面可以无基础学习&#xff0c;硬件学习支持两种模式&#xff…

采用若依vue 快速开发系统功能模块

文章目录 运行若依项目 科室管理科室查询-后端代码实现科室查询-前端代码实现科室名称状态搜索科室删除-后端代码实现科室删除-前端代码实现科室新增-后端代码实现科室新增-前端代码实现科室修改-后端代码实现前端代码实现角色权限实现 运行若依项目 运行redis 创建数据库 修改…

HTML:表格数据展示区

<!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>人员信息表</title><link rel"styl…

Oracle测试题目及笔记(单选)

所有题目来自于互联网搜索 当 Oracle 服务器启动时&#xff0c;下列哪种文件不是必须的&#xff08;D&#xff09;。 A&#xff0e;数据文件 B&#xff0e;控制文件 C&#xff0e;日志文件 D&#xff0e;归档日志文件 数据文件、日志文件-在数据库的打开阶段使用 控制文件-在数…

Jmeter创建使用变量——能够递增递减的计数器

Jmeter创建使用变量——能够递增递减的计数器 如下图所示&#xff0c;创建一个 取值需限定为0 2 4这三个值内的变量。 Increment&#xff1a;每次迭代后 递增的值&#xff0c;给计数器增加的值 Maximum value&#xff1a;计数器的最大值&#xff0c;如果超过最大值&#xff0…

数据结构之BFS广度优先算法(腐烂的苹果)

队列这个数据结构在很多场景下都有使用&#xff0c;比如在实现二叉树的层序遍历&#xff0c;floodfill问题(等等未完成)中&#xff0c;都需要借助队列的先进先出特性&#xff0c;下面给出这几个问题的解法 经典的二叉树的层序遍历 算法图示&#xff0c;以下图所示的二叉树为例…

火车头采集动态加载Ajax数据(无分页瀑布流网站)

为了先填充好数据在上线&#xff0c;在本地搭建了一个网站&#xff0c;并用火车头采集数据填充到里面。 开始很上手&#xff0c;因为找的网站的分类中是有分页的。很快捷的找到页面标识。 但是问题来了&#xff0c;如今很多网站都是采用的Ajax加载数据&#xff0c;根本没有分…

Node.js模块化与npm

目录 一、模块化简介 二、CommonJS 规范 1. 基本语法 2. 导出模块 3. 导入模块 三、ECMAScript 标准&#xff08;ESM&#xff09; 1. 启用 ESM 一、默认导出与导入 1. 基本语法 2. 默认导出&#xff08;每个模块仅一个&#xff09; 3. 默认导入 二、命名导出与导入…

nginx中的代理缓存

1.缓存存放路径 对key取哈希值之后&#xff0c;设置cache内容&#xff0c;然后得到的哈希值的倒数第一位作为第一个子目录&#xff0c;倒数第三位和倒数第二位组成的字符串作为第二个子目录&#xff0c;如图。 proxy_cache_path /xxxx/ levels1:2 2.文件名哈希值

【前端vue生成二维码和条形码——MQ】

前端vue生成二维码和条形码——MQ 前端vue生成二维码和条形码——MQ一、安装所需要的库1、安装qrcode2、安装jsbarcode 二、使用步骤1、二维码生成2、条形码生成 至此&#xff0c;大功告成&#xff01; 前端vue生成二维码和条形码——MQ 一、安装所需要的库 1、安装qrcode 1…

flutter 桌面应用之窗口自定义

在开发桌面软件的时候我们经常需要配置软件的窗口的大小以及位置 我们有两个框架选择:window_manager和bitsdojo_window 对比bitsdojo_window 特性bitsdojo_windowwindow_manager自定义标题栏✅ 支持❌ 不支持控制窗口行为&#xff08;大小/位置&#xff09;✅&#xff08;基本…

华为OD机试真题——MELON的难题(2025A卷:200分)Java/python/JavaScript/C++/C语言/GO六种最佳实现

2025 A卷 200分 题型 本文涵盖详细的问题分析、解题思路、代码实现、代码详解、测试用例以及综合分析&#xff1b; 并提供Java、python、JavaScript、C、C语言、GO六种语言的最佳实现方式&#xff01; 2025华为OD真题目录全流程解析/备考攻略/经验分享 华为OD机试真题《MELON的…