LLM各层参数详细分析(以LLaMA为例)

news2025/6/9 18:22:50

网上大多分析LLM参数的文章都比较粗粒度,对于LLM的精确部署不太友好,在这里记录一下分析LLM参数的过程。

首先看QKV。先上transformer原文
在这里插入图片描述
也就是说,当h(heads) = 1时,在默认情况下, W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV都是2维方阵,方阵维度是 d m o d e l × d m o d e l d_{model} \times d_{model} dmodel×dmodel.

结合llama源码 (https://github.com/facebookresearch/llama/blob/main/llama/model.py)

class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048
# ...

class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.

        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

计算出
self.n_kv_heads = h = 32
self.head_dim = 4096/32=128
所以 W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV 大小都为(4096, 128). Q × K T Q×K^T Q×KT后,大小为(4096, 4096),除法scale+softmax后不变,然后 × V ×V ×V,大小恢复变为(4096, 128)。Attention不改变大小(在默认 d k = d v d_k=d_v dk=dv情况下)。
在这里插入图片描述

经过Cancat,分开的头又合并,大小变为(4096, 4096)方阵,经过 W O W^O WO全连接,还是(4096, 4096)方阵。

然后看Feed forward.根据源码,

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): Layer normalization for attention output.
            ffn_norm (RMSNorm): Layer normalization for feedforward output.

        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.

        """
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

multiattention layer过后,经过加法和norm(RMS norm),进入feed_forward全连接。全连接层第一个维度是args.dim=4096, 第二个维度(hidden_dim)是4 * args.dim = 4*4096=16384 (目前还有问题)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1026849.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RabbitMQ - 死信、TTL原理、延迟队列安装和配置

目录 一、死信交换机 1.1、什么是死信交换机 1.2、TTL 1.2.1、什么是 TTL 1.2.2、通过 TTL 模拟触发死信 二、延迟队列 2.1、什么是延迟队列 2.2、配置延迟队列插件 2.2.1、延迟队列配置 a)下载镜像 b)运行容器 c)刚刚设定的Rabb…

jmeter下载安装教程

一、下载安装jdk(jmeter需要) 1、首页下载jdk,地址:Java Downloads | Oracle 2、下载se,注意需要oracle账号,注册即可 这里的8u384代表JDK 8版本,384代表子版本,u是update(更新)的…

flink集群与资源@k8s源码分析-运行时

1 运行时 运行时提供了Flink作业运行过程依赖的基础执行环境,包含Dispatcher、ResourceManager、JobManager和TaskManager等核心组件,本节分析资源相关运行时组件构建和启动。 flink没有使用spring,缺少ioc的构建过程相当复杂,所有依赖手动关联和置入,为了共享组件,fli…

jenkins容器内配置python项目运行环境(Python3.7.3)

目录 1.查看启动的容器2.进入jenkins容器内部3.使用wget:提示没有wget命令4.查看jenkins容器系统版本5.换成国内源(阿里)5.更新apt-get6.安装wget7.创建python存放目录8.下载python9.解压10.安装依赖11.运行脚本configure12.make编译make ins…

汽车三高试验离不开的远程试验管理平台-TFM

随着信息技术的高速发展,企业对远程试验实时监控与数据管理的需求日益增强。而利用远程试验信息协同技术,可突破部门与地域的限制,并把试验现场的车辆状态信息、试验数据和分析结果实时传输给数据分析部门和设计部门等,从而缩短时…

什么是HTTP/2?它与HTTP/1.1相比有什么改进?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ HTTP/2 简介⭐ 主要的改进和特点1. 多路复用(Multiplexing)2. 头部压缩(Header Compression)3. 服务器推送(Server Push)4. 二进制传输(Binary Protocol&…

12基于MATLAB的短时傅里叶变换( STFT),连续小波变换( CWT),程序已调通,可以直接运行。

基于MATLAB的短时傅里叶变换( STFT),连续小波变换( CWT),程序已调通,可以直接运行

jdk exe安装包如何自制zip解压版

jdk8 oracle官方下载页面 https://www.oracle.com/java/technologies/downloads/#java8-windows 可以看到,只有exe安装包 下载最新的exe安装包 解压 用7Zip解压 里面有好几个JAVA_CAB*文件夹,我们只需要关注两个:9和10,JAVA_CA…

【操作系统笔记】内存分配

内存对齐 问题:为什么需要内存对齐呢? 主要原因是为了兼容,为了让程序可以运行在不同的处理器中,有很多处理器在访问内存的时候,只能从特定的内存地址读取数据。换个说法就是处理器每次只能从内存取出特定个数字节的数…

卡尔曼滤波(Kalman Filter)C#测试

一、操作过程 刚学了一下卡尔曼滤波,具体原理还没细看,大致过程如下 分为两步,第一步Predict,以下两个公式 第二步Correct,以下三个公式 公式看起来很复杂,其中是我们要处理的数据, 是滤…

HTTP 协商缓存 ETag、If-None-Match

(1)浏览器第一次跟服务器请求一个资源,服务器在返回这个资源的同时,在respone header加上ETag。 ETag是服务器根据当前请求的资源生成的一个唯一标识。 这个唯一标识是一个字符串,只要资源有变化这个串就不同&#xff…

CSS的学习

1.认识CSS CSS 叫做"层叠样式表" “层叠样式表” 样式 --> 大小,位置,间距,颜色,字体,表框背景… 统称为"样式",描述了一个网页长什么样子~ 层叠 --> 针对一个html的元素/标签,可以同时应用多组CSS样式~~ 多组样式会叠加在一起~~ CSS描述的是页…

cocosCreator 之 Graphics绘制基础图形,五角星,线型图,柱形图

版本: 3.4.0 环境: Mac Graphics组件 Graphics组件主要用于绘画使用,属于渲染组件。继承结构: #mermaid-svg-WHveKVDzMTXmCbpg {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mer…

Android Studio 创建项目不自动生成BuildConfig文件

今天在AS上新建项目发现找不到BuildConfig文件,怎么clear都不行。通过多方面查找发现原来gradle版本不同造成的,Gradle 8.0默认不生成 BuildConfig 文件。 如上图,8.0版本是没有source文件夹 上图是低于8.0版本有source文件夹 针对这个问题&…

Jenkins学习笔记1

CI 服务器: 认识Jenkins: Jenkins是一个可扩展的持续集成(CI)引擎,是一个开源项目,旨在提供一个开放易用的软件平台,使得软件持续集成变成可能。Jenkins非常易于安装和配置,简单易…

算法leetcode|83. 删除排序链表中的重复元素(rust重拳出击)

文章目录 83. 删除排序链表中的重复元素:样例 1:样例 2:提示: 分析:题解:rust:go:c:python:java: 83. 删除排序链表中的重复元素: 给…

Docker从认识到实践再到底层原理(六-1)|Docker容器基本介绍+命令详解

前言 那么这里博主先安利一些干货满满的专栏了! 首先是博主的高质量博客的汇总,这个专栏里面的博客,都是博主最最用心写的一部分,干货满满,希望对大家有帮助。 高质量博客汇总 然后就是博主最近最花时间的一个专栏…

vue+element plus 使用table组件,清空用户的选择项

<el-table ref"tableRef"> .... </el-table> <script lang"ts" setup> import { onMounted, reactive, ref, nextTick } from vue const clearBtn () > {console.log(清空用户的选择项)tableRef.value.clearSelection() } </scr…

【论文阅读 09】融合门控自注意力机制的生成对抗网络视频异常检测

2021年 中国图象图形学报 摘 要 背景&#xff1a; 视频异常行为检测是智能监控技术的研究重点&#xff0c;广泛应用于社会安防领域。当前的挑战之一是如何提高异常检测的准确性&#xff0c;这需要有效地建模视频数据的空间维度和时间维度信息。生成对抗网络&#xff08;GANs&…

[学习记录] 设计模式 3. 观察者模式

观察者模式 参考&#xff1a; bugstack 虫洞栈Refactoringhttps://www.cnblogs.com/myseries/p/8735490.htmlhttps://www.jianshu.com/p/4f1cd513a72d 当一个行为发生时传递信息给另外一个用户接收做出相应的处理&#xff0c;两者之间没有直接的耦合关联。 在我们编程开发中也…