NLP学习记录十:多头注意力

news2025/5/15 5:08:39

一、单头注意力

单头注意力的大致流程如下:

① 查询编码向量、键编码向量和值编码向量分别经过自己的全连接层(Wq、Wk、Wv)后得到查询Q、键K和值V;

② 查询Q和键K经过注意力评分函数(如:缩放点积运算)得到值权重矩阵;

③ 权重矩阵与值向量相乘,得到输出结果。

 图1 单头注意力模型

 

二、多头注意力 

2.1 使用多头注意力的意义      

        看了一些对多头注意力机制解释的视频,我自己的浅显理解是:在实践中,我们会希望查询Q能够从给定内容中尽可能多地匹配到与自己相关的语义信息,从而得到更准确的预测输出。而多头注意力将查询、键和值分成不同的子空间表示(representation subspaces)(有点类似于子特征?),使得匹配过程更加细化。

2.2 代码实现

        也许直接看代码能更快地理解这个过程:

import torch
from torch import nn
from attentionScore import DotProductAttention
# 多头注意力模型
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    # queries:(batch_size,查询的个数,query_size)
    # keys:(batch_size,“键-值”对的个数,key_size)
    # values:(batch_size,“键-值”对的个数,value_size)
    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
        queries = self.W_q(queries)
        keys = self.W_k(keys)
        values = self.W_v(values)

        # 经过变换后,输出的queries,keys,values的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
        queries = transpose_qkv(queries, self.num_heads)
        keys = transpose_qkv(keys, self.num_heads)
        values = transpose_qkv(values, self.num_heads)

        # valid_lens的形状:(batch_size,)或(batch_size,查询的个数)
        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
# 为了多注意力头的并行计算而变换形状
def transpose_qkv(X, num_heads):
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])
# 逆转transpose_qkv函数的操作
def transpose_output(X, num_heads):
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

        可以发现,前面的处理流程和单头注意力的第①步是一样的,都是使用全连接层计算查询Q、键K、值V。但在进行点积运算之前,模型使用transpose_qkv函数对QKV进行了切割变换,下图可以帮助理解这个过程:

图2 transpose_qkv函数处理Q

图3 transpose_qkv函数处理K 

        这个过程就像是把一个整体划分为了很多小的子空间。一个不知道恰不恰当的比喻,就像是把“父母”这个词拆分成了“长辈”、“养育者”、“监护人”、“爸妈”多重含义。

        对切割变换后的QK进行缩放点积运算,过程如下图所示:

 图4 对切割变换后的Q和K进行缩放点积运算

        transpose_output后的输出结果:

图5 对值加权结果进行transpose_output变换后 

        对比单头注意力的值加权输出,原来的每个查询Q匹配到了更多的value:

图6 多头注意力与单头注意力的值加权结果对比

        整个过程就像是把一个父需求分割成不同的子需求,子需求单独与不同的子特征进行匹配,最后使得每个父需求获得了更多的语义信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2307172.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring基础01

Spring基础01 软件开发原则 OCP开闭原则:七大开发原则当中最基本的原则,其他的六个原则是为这个原则服务的。 对扩展开放,对修改关闭。在扩展系统功能的时候,没有修改之前写好的代码,就符合OCP原则,反之&a…

2025年2月,TVBOX接口最新汇总版

这里写自定义目录标题 1、离线版很必要2、关于在线版好还是离线版更实在,作个总结:★ 离线版的优点:★ 离线版的缺点: 3.1、 针对FM内置的写法;3.2、 如果是用在YSC,那么格式也要有些小小的改变3.2.1、 YSC…

Dubbo RPC 原理

一、Dubbo 简介 Apache Dubbo 是一款高性能、轻量级的开源 RPC 框架,支持服务治理、协议扩展、负载均衡、容错机制等核心功能,广泛应用于微服务架构。其核心目标是解决分布式服务之间的高效通信与服务治理问题。 二、Dubbo 架构设计 1. 核心组件 Prov…

第2章_保护您的第一个应用程序

第2章_保护您的第一个应用程序 在本章中,您将学习如何使用 Keycloak 保护您的第一个应用程序。为了让事情更有趣,您将运行的示例应用程序由两部分组成,前端 Web 应用程序和后端 REST API。这将向您展示用户如何向前端进行身份验证&#xff0…

【Godot4.3】自定义圆角容器

概述 Godot控件想要完全实现现代UI风格,需要进行大量的自定义组件设计。本篇就依托于笔者自己对现代UI设计中的圆角面板元素模仿来制作圆角容器组件。 圆角容器 圆角元素在现代的扁平UI设计中非常常见,在Godot中可以通过改进PanelContainer来或者自定…

Flutter系列教程之(5)——常用控件Widget的使用示例

目录 1.页面跳转 2.某个控件设置点击事件 3.AlertDialog对话框的使用 4.文本输入框 5.按钮 圆角扁平按钮: 圆角悬浮按钮: 6.补充 圆点 7.布局使用 Row控件左右对齐 调整边距 1.页面跳转 首先,先介绍一下页面跳转功能吧 Flutter使用 Navigator 进行页面…

DeepSeek开源周,第三弹再次来袭,DeepGEMM

在大型模型推理中,矩阵乘法(GEMM)是计算的核心瓶颈。DeepGEMM 应运而生——一款专为 FP8精度矩阵乘法 设计的轻量级CUDA库,由深度求索(DeepSeek)团队开源。它凭借极简代码(核心仅300行&#xff…

stm32四种方式精密控制步进电机

在搭建完clion的开发环境后,我决定重写之前的项目并优化完善,争取做出完全可落地的东西,也结合要写的论文内容一同学习下去。 因此,首当其冲的就是回到步进电机控制领域,把之前使用中断溢出进行步进电机控制的方案进行…

git merge -s ours ...的使用方法

当我们在自己的feature branch上开发时,并且已经commit,push了好几次 同时develop分支也commit , push了好几次, 如下图所示 这个时候就不能直接将feature branch上的改动 pull request到develop上面,因为develop基线已经不一样了…

数字可调控开关电源设计(论文+源码)

1 设计要求 在本次数字可调控开关电源设计过程中,对关键参数设定如下: (1)输入电压:DC24-26V,输出电压:12-24(可调); (2)输出电压误差&#xf…

【DeepSeek】【GPT-Academic】:DeepSeek集成到GPT-Academic(官方+第三方)

目录 1 官方deepseek 1.1 拉取学术GPT项目 1.2 安装依赖 1.3 修改配置文件中的DEEPSEEK_API_KEY 2 第三方API 2.1 修改DEEPSEEK_API_KEY 2.2 修改CUSTOM_API_KEY_PATTERM 2.3 地址重定向 2.4 修改模型参数 2.5 成功调用 2.6 尝试添加一个deepseek-r1参数 3 使用千帆…

DeepSeek R1 + 飞书机器人实现AI智能助手

效果 TFChat项目地址 https://github.com/fish2018/TFChat 腾讯大模型知识引擎用的是DeepSeek R1,项目为sanic和redis实现,利用httpx异步处理流式响应,同时使用buffer来避免频繁调用飞书接口更新卡片的网络耗时。为了进一步减少网络IO消耗&…

Android移动应用开发实践-1-下载安装和简单使用Android Studio 3.5.2版本(频频出错)

一、下载安装 1.Android Studio3.5.2下载地址:Android Studio3.5.2下载地址 其他版本下载地址:其他版本下载地址 2.安装教程(可以多找几个看看) 安装 | 手把手教你Android studio 3.5.2安装(安装教程)_a…

Rk3568驱动开发_驱动编写和挂载_2

1.字符驱动介绍: 字符驱动:按照字节流镜像读写操作的设备,读写数据分先后顺序,例如:点灯、按键、IIC、SPI、等等都是字符设备,这些设备的驱动叫字符驱动设备 Linux应用层如何调用驱动: 字符设…

【苍穹外卖】问题笔记

【DAY1 】 1.VCS找不到 好吧,发现没安git 接着发现安全模式有问题,点开代码信任此项目 2.导入初始文件,全员爆红 好像没maven,配一个 并在设置里设置好maven 3.启用注解,见新手苍穹 pom.xml改lombok版本为1.1…

1.1部署es:9200

安装es:root用户: 1.布署java环境 - 所有节点 wget https://d6.injdk.cn/oraclejdk/8/jdk-8u341-linux-x64.rpm yum localinstall jdk-8u341-linux-x64.rpm -y java -version 2.下载安装elasticsearch - 所有节点 wget ftp://10.3.148.254/Note/Elk/…

上传securecmd失败

上传securecmd失败 问题描述:KES V8R6部署工具中,节点管理里新建节点下一步提示上传securecmd失败,如下: 解决办法: [rootlocalhost ~]# yum install -y unzip 上传的过程中会解压,如果未安装unzip依赖包…

C++:dfs,bfs各两则

1.木棒 167. 木棒 - AcWing题库 乔治拿来一组等长的木棒,将它们随机地砍断,使得每一节木棍的长度都不超过 5050 个长度单位。 然后他又想把这些木棍恢复到为裁截前的状态,但忘记了初始时有多少木棒以及木棒的初始长度。 请你设计一个程序…

P9420 [蓝桥杯 2023 国 B] 子 2023

P9420 [蓝桥杯 2023 国 B] 子 2023 题目 分析代码 题目 分析 刚拿到这道题,我大脑简单算了一下,这个值太大了,直观感觉就很难!! 但是,你仔仔细细的一看,先从最简单的第一步入手,再…

2025-02-26 学习记录--C/C++-C语言 判断字符串S2是否在字符串S1中

合抱之木&#xff0c;生于毫末&#xff1b;九层之台&#xff0c;起于累土&#xff1b;千里之行&#xff0c;始于足下。&#x1f4aa;&#x1f3fb; C语言 判断字符串S2是否在字符串S1中 #include <stdio.h> // 引入标准输入输出库&#xff0c;用于使用 printf 等函数 #…