LLaMA-Adapter源码解析

news2025/7/16 19:17:58

LLaMA-Adapter源码解析

伪代码

def transformer_block_with_llama_adapter(x, gating_factor, soft_prompt):
	residual =x
	y= zero_init_attention(soft_prompt, x) # llama-adapter: prepend prefix
	x= self_attention(x)
	x = x+ gating_factor * y  # llama-adapter: apply zero_init_attention
	x = LayerNorm(x+residual)
	residual = x
	x = FullyConnectedLayers(x)
	x = AdapterLayers(x)
	x = LayerNorm(x + residual)
	return x

源码

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.gate = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        if adapter is not None:
           adapter_len = adapter.shape[1]
           adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
           adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
           adapter_k = adapter_k.transpose(1, 2)
           adapter_v = adapter_v.transpose(1, 2)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        if adapter is not None:
            adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
            adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
            output = output + torch.matmul(adapter_scores, adapter_v)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1157721.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全网公开电商数据的采集重点

数据的采集是根据需求而定的,品牌会做数据采集的原因,一般与内部营销、渠道管控有关,如需要做价格管控时,需要先采集价格,这就需要对数据进行采集,包括价格、促销信息,又或者是需要做行业分析、…

总感觉戴助听器耳朵又闷又堵怎么办?

随着助听器技术的进步发展,这些问题都有了一定程度的改善。例如,现在的助听器变得越来越小巧,外形更加美观和隐蔽;各种降噪技术和验配技巧也提升了助听器的音质和清晰度。 但是,还有一个问题困扰着很多助听器用户&…

缓存击穿只会逻辑过期 OR 互斥锁?深入思考 == 鹤立鸡群

网上但凡看得见的文章,大部分在说缓存穿透时都是无脑分布式锁 / 逻辑过期,分布式锁一点问题都没有么?逻辑过期一点问题都没有么?还能不能再进一步优化? 在聊聊缓存击穿的双重判定锁之前,我们将按照循循渐进…

java基础之IO操作

用户进程发起请求,内核接收到请求后,从I/O设备中获取数据到buffer中,再将buffer中的数据copy到用户进程的地址空间,该用户进程获取到数据后再响应客户端。 数据输入到buffer需要时间,从buffer复制数据至进程也需要时间…

社区参展招募| 与新兴初创企业一起!

一般来说,创业公司的生存率较低,失败率较高。根据不同的数据来源,创业公司的失败率高达 80%-90%。据统计,在中国每年新注册的企业数量超过 100 万家,但能够存活到 5 年以上的企业不足 7%,10 年以上不足 2%。…

路径复杂度(环形回路的复杂度计算)

路径复杂度 1、通用公式: (EF) - N12、非环形回路的复杂度计算公式为什么1?公式为什么(EF)-N? 3、类推到环形回路的复杂度演示区分下纯环形回路 和 不是纯粹的环形回路 3、特殊情况:自旋公式化理解:此时将B自旋回路看成一个环形回…

混合编程 ATPCS规范及案例(汇编调用C、C调用汇编、内联汇编)

1.混合编程的规范 2.汇编调用C 2.C调用汇编 3.内联汇编 例子:

高速视觉筛选机PCIe实时运动控制卡XPCIE1028亮相深圳慕尼黑华南电子展

本次的深圳慕尼黑华南电子展正运动技术将携高速视觉筛选机PCIe实时运动控制卡XPCIE1028亮相。此外,我们还为您准备了的新互动模式,您将有机会赢得超值礼品! 产品导读 正运动技术的PCI Express总线运动控制卡XPCIE1028,具备位置…

二叉树初次的整体过程

1.计算二叉树节点的范围 2.计算用父亲下标找出儿子的下标&#xff0c;用儿子的下标找出父亲的下标方法 2.向下调整算法 void AdjustUp(HPDataType*a,int child) {int parent (child - 1) / 2;while (0 < child){if (a[child] < a[parent]){HPDataType* tem a[child];a[…

Linux常用命令——chgrp命令

在线Linux命令查询工具 chgrp 用来变更文件或目录的所属群组 补充说明 chgrp命令用来改变文件或目录所属的用户组。该命令用来改变指定文件所属的用户组。其中&#xff0c;组名可以是用户组的id&#xff0c;也可以是用户组的组名。文件名可以 是由空格分开的要改变属组的文…

mpp解码详解

解码器数据流接口 一. decode_put_packet 输入码流的形式&#xff1a;分帧与不分帧 MPP 的输入都是没有封装信息的裸码流&#xff0c;裸码流输入有两种形式&#xff1a; 不分帧 这种方式是已经按帧分段的数据&#xff0c;即每一包输入给 decode_put_packet 函数的 MppPacket 数…

PCIe 的 MSI 中断详解,寄存器级别的详细流程分析,完全搞懂硬件的工作流程

PCIe 的 MSI 中断 前言 什么是 MSI 中断 (Message Signaled Interrupts) 概念与内容介绍待补充 正文 对 EP 的初始化 需要对 EP 的配置空间 MSI 相关功能的寄存器进行初始化&#xff0c;主要有两个寄存器 Message Address 和 Message Data。它们分别的含义是 EP 产生 MSI …

关于JSON 字符串或对象互转

目录 1、 识别不了空数组&#xff0c;只能打印出来{__ob__: Observer} 2、遇到把字符串解析为对象 1、 识别不了空数组&#xff0c;只能打印出来{__ob__: Observer} 遇到了这种&#xff0c;可以使用 "JSON.stringify(JSON.parse(JSON.stringify(obj))" 解析成空对象…

B站数据质量保障体系建设与实践

本文将分享 B 站数据质量保障体系的建设和实践。文章将关注数仓和建模的相关方法论&#xff0c;讲解 B 站数仓平台团队在数仓建设和建模过程中所做的工作&#xff0c;并分享质量保障方面取得的成果。 一、背景目标 首先&#xff0c;分享一下 B 站数据质量保障的背景和目标。 …

onnx 模型加载部署运行方式

1.onnx文件加载方式在onnxruntime下是: env Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "YOLOV8");sessionOptions Ort::SessionOptions();std::wstring w_modelPath charToWstring(mnnModelPath.c_str());session Ort::Session(env, w_modelPath.c_…

输入日期是当年的第n天

从键盘输入正确日期&#xff0c;程序输出是当年的第n天。 (本笔记适合熟悉循环和列表的 coder 翻阅) 【学习的细节是欢悦的历程】 Python 官网&#xff1a;https://www.python.org/ Free&#xff1a;大咖免费“圣经”教程《 python 完全自学教程》&#xff0c;不仅仅是基础那么…

Springboot整合Minio实现文件上传和下载

目录 1. Minio 1.1 Minio下载 2. Springboot和Minio实现文件存储 1. Minio Minio是一个灵活、高性能、开源的对象存储解决方案&#xff0c;适用于各种存储需求&#xff0c;并可以与云计算、容器化、大数据和应用程序集成。它为用户提供了自主控制和可扩展性&#xff0c;使其…

Redis系统学习(高级篇)-Redis持久化-RDB方式

目录 一、RDB是什么&#xff1f; 二、RDB的执行时机 三、RDB的其他命令 四、RDB的执行原理 五、RDB的优缺点 一、RDB是什么&#xff1f; RDB全称叫Redis Database Backup file&#xff08;Redis数据备份文件&#xff09; &#xff0c;也叫Redis数据快照&#xff0c;就是将…

针对zkVM中Memory Consistency Checks的Polynomial IOPs

1. 引言 主要参考Yuncong Zhang等人2023年论文《Polynomial IOPs for Memory Consistency Checks in Zero-Knowledge Virtual Machines》。 在设计zkvm时&#xff0c;需检查其所有组件的功能一致性&#xff0c;包括&#xff1a; instruction fetcher寄存器文件算术化逻辑单元…

模型实际训练笔记1—AlexNet

1、AlexNet网络模型介绍&#xff1a; AlexNet 是一种深度卷积神经网络&#xff0c;由Alex Krizhevsky、Ilya Sutskever 和 Geoffrey Hinton 在2012年开发。它是深度学习领域的重要里程碑&#xff0c;因为它在当年的 ImageNet 大规模图像分类竞赛&#xff08;ILSVRC&#xff09…