第9讲、深入理解Scaled Dot-Product Attention

news2025/5/18 15:27:44

Scaled Dot-Product Attention是Transformer架构的核心组件,也是现代深度学习中最重要的注意力机制之一。本文将从原理、实现和应用三个方面深入剖析这一机制。

1. 基本原理

Scaled Dot-Product Attention的本质是一种加权求和机制,通过计算查询(Query)与键(Key)的相似度来确定对值(Value)的关注程度。其数学表达式为:

这个公式包含几个关键步骤:

  1. 计算相似度:通过点积(dot product)计算Query和Key的相似度,得到注意力分数(attention scores)
  2. 缩放(Scaling):将点积结果除以 d k \sqrt{d_k} dk 进行缩放,其中 d k d_k dk是Key的维度
  3. 应用Mask(可选):在某些情况下(如自回归生成)需要遮盖未来信息
  4. Softmax归一化:将注意力分数通过softmax转换为概率分布
  5. 加权求和:用这些概率对Value进行加权求和

2. 为什么需要缩放(Scaling)?

缩放是Scaled Dot-Product Attention区别于普通Dot-Product Attention的关键。当输入的维度 d k d_k dk较大时,点积的方差也会变大,导致softmax函数梯度变得极小(梯度消失问题)。通过除以 d k \sqrt{d_k} dk ,可以将方差控制在合理范围内。

假设Query和Key的各个分量是均值为0、方差为1的独立随机变量,则它们点积的方差为 d k d_k dk。通过除以 d k \sqrt{d_k} dk ,可以将方差归一化为1。

3. 代码实现解析

让我们看看PyTorch中Scaled Dot-Product Attention的典型实现:

def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    # 获取key的维度
    d_k = query.size(-1)
    
    # 计算注意力分数并缩放
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 应用mask(如果提供)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # 应用softmax得到注意力权重
    attn = F.softmax(scores, dim=-1)
    
    # 应用dropout(如果提供)
    if dropout is not None:
        attn = dropout(attn)
    
    # 加权求和
    return torch.matmul(attn, value), attn

这个函数接受query、key、value三个张量作为输入,可选的mask用于遮盖某些位置,dropout用于正则化。

4. 张量维度分析

假设输入的形状为:

  • Query: [batch_size, seq_len_q, d_k]
  • Key: [batch_size, seq_len_k, d_k]
  • Value: [batch_size, seq_len_k, d_v]

计算过程中各步骤的维度变化:

  1. Key转置后: [batch_size, d_k, seq_len_k]
  2. Query与Key的点积: [batch_size, seq_len_q, seq_len_k]
  3. Softmax后的注意力权重: [batch_size, seq_len_q, seq_len_k]
  4. 最终输出: [batch_size, seq_len_q, d_v]

5. 在Multi-Head Attention中的应用

Scaled Dot-Product Attention是Multi-Head Attention的基础。在Multi-Head Attention中,我们将输入投影到多个子空间,在每个子空间独立计算注意力,然后将结果合并:

class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)

        nbatches = query.size(0)

        # 1) 投影并分割成多头
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]

        # 2) 应用注意力机制
        x, self.attn = scaled_dot_product_attention(query, key, value, mask, self.dropout)

        # 3) 合并多头结果
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

6. 实际应用场景

Scaled Dot-Product Attention在多种场景下表现出色:

  1. 自然语言处理:捕捉句子中词与词之间的依赖关系
  2. 计算机视觉:关注图像中的重要区域
  3. 推荐系统:建模用户与物品之间的交互
  4. 语音处理:捕捉音频信号中的时序依赖

7. 优势与局限性

优势

  • 计算效率高(可以通过矩阵乘法并行计算)
  • 能够捕捉长距离依赖关系
  • 模型可解释性强(可以可视化注意力权重)

局限性

  • 计算复杂度为O(n²),对于长序列计算开销大
  • 没有考虑位置信息(需要额外的位置编码)
  • 对于某些任务,可能需要结合CNN等结构以捕捉局部特征

8. 总结

Scaled Dot-Product Attention是现代深度学习中的关键创新,通过简单而优雅的设计实现了强大的表达能力。它不仅是Transformer架构的核心,也启发了众多后续工作,如Performer、Linformer等对注意力机制的改进。理解这一机制对于掌握现代深度学习模型至关重要。

通过缩放点积、应用softmax和加权求和这三个简单步骤,Scaled Dot-Product Attention成功地让模型"关注"输入中的重要部分,这也是它能在各种任务中取得卓越表现的关键所在。

##9、Scaled Dot-Product Attention应用案例

敬请关注下一篇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2378587.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

双向长短期记忆网络-BiLSTM

5月14日复盘 二、BiLSTM 1. 概述 双向长短期记忆网络(Bi-directional Long Short-Term Memory,BiLSTM)是一种扩展自长短期记忆网络(LSTM)的结构,旨在解决传统 LSTM 模型只能考虑到过去信息的问题。BiLST…

MySQL UPDATE 执行流程全解析

引言 当你在 MySQL 中执行一条 UPDATE 语句时,背后隐藏着一套精密的协作机制。从解析器到存储引擎,从锁管理到 WAL 日志,每个环节都直接影响数据一致性和性能。 本文将通过 Mermaid 流程图 和 时序图,完整还原 UPDATE 语句的执行…

亚马逊云科技:开启数字化转型的无限可能

在数字技术蓬勃发展的今天,云计算早已突破单纯技术工具的范畴,成为驱动企业创新、引领行业变革的核心力量。亚马逊云科技凭借前瞻性的战略布局与持续的技术深耕,在全球云计算领域树立起行业标杆,为企业和个人用户提供全方位、高品…

【实测有效】Edge浏览器打开部分pdf文件显示空白

问题现象 Edge浏览器打开部分pdf文件显示空白或显示异常。 ​​​​​​​ ​​​​​​​ ​​​​​​​ 问题原因 部分pdf文件与edge浏览器存在兼容性问题,打开显示异常。 解决办法 法1:修改edge配置 打开edge浏览器&#x…

RJ连接器的未来:它还会是网络连接的主流标准吗?

RJ连接器作为以太网接口的代表,自20世纪以来在计算机网络、通信设备、安防系统等领域中占据了核心地位。以RJ45为代表的RJ连接器,凭借其结构稳定、信号传输可靠、成本低廉等优势,在有线网络布线领域被广泛采用。然而,在无线网络不…

Redis持久化机制详解:保障数据安全的关键策略

在现代应用开发中,Redis作为高性能的内存数据库被广泛使用。然而,内存的易失性特性使得持久化成为Redis设计中的关键环节。本文将全面剖析Redis的持久化机制,包括RDB、AOF以及混合持久化模式,帮助开发者根据业务需求选择最适合的持…

DeepSeek 大模型部署全指南:常见问题、优化策略与实战解决方案

DeepSeek 作为当前最热门的开源大模型之一,其强大的语义理解和生成能力吸引了大量开发者和企业关注。然而在实际部署过程中,无论是本地运行还是云端服务,用户往往会遇到各种技术挑战。本文将全面剖析 DeepSeek 部署中的常见问题,提…

嵌入式培训之数据结构学习(五)栈与队列

一、栈 (一)栈的基本概念 1、栈的定义: 注:线性表中的栈在堆区(因为是malloc来的);系统中的栈区存储局部变量、函数形参、函数返回值地址。 2、栈顶和栈底: 允许插入和删除的一端…

RabbitMQ--进阶篇

RabbitMQ 客户端整合Spring Boot 添加相关的依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-amqp</artifactId> </dependency> 编写配置文件&#xff0c;配置RabbitMQ的服务信息 spri…

Android Studio报错Cannot parse result path string:

前言 最近在写个小Demo&#xff0c;参考郭霖的《第一行代码》&#xff0c;学习DrawerLayout和NavigationView&#xff0c;不知咋地&#xff0c;突然报错Cannot parse result path string:xxxxxxxxxxxxx 反正百度&#xff0c;问ai都找不到答案&#xff0c;报错信息是完全看不懂…

关于网站提交搜索引擎

发布于Eucalyptus-blog 一、前言 将网站提交给搜索引擎是为了让搜索引擎更早地了解、索引和显示您的网站内容。以下是一些提交网站给搜索引擎的理由&#xff1a; 提高可见性&#xff1a;通过将您的网站提交给搜索引擎&#xff0c;可以提高您的网站在搜索结果中出现的机会。当用…

基于QT(C++)OOP 实现(界面)酒店预订与管理系统

酒店预订与管理系统 1 系统功能设计 酒店预订是旅游出行的重要环节&#xff0c;而酒店预订与管理系统中的管理与信息透明是酒店预订业务的关键问题所在&#xff0c;能够方便地查询酒店信息进行付款退款以及用户之间的交流对于酒店预订行业提高服务质量具有重要的意义。 针对…

机械元件杂散光难以把控?OAS 软件案例深度解析

机械元件的杂散光分析 简介 在光学系统设计与工程实践中&#xff0c;机械元件的杂散光问题对系统性能有着不容忽视的影响。杂散光会降低光学系统的信噪比、图像对比度&#xff0c;甚至导致系统功能失效。因此&#xff0c;准确分析机械元件杂散光并采取有效抑制措施&#xff0c…

游戏引擎学习第289天:将视觉表现与实体类型解耦

回顾并为今天的工作设定基调 我们正在继续昨天对代码所做的改动。我们已经完成了“脑代码&#xff08;brain code&#xff09;”的概念&#xff0c;它本质上是一种为实体构建的自组织控制器结构。现在我们要做的是把旧的控制逻辑迁移到这个新的结构中&#xff0c;并进一步测试…

【Linux网络】ARP协议

ARP协议 虽然我们在这里介绍 ARP 协议&#xff0c;但是需要强调&#xff0c;ARP 不是一个单纯的数据链路层的协议&#xff0c;而是一个介于数据链路层和网络层之间的协议。 ARP数据报的格式 字段长度&#xff08;字节&#xff09;说明硬件类型2网络类型&#xff08;如以太网为…

MUSE Pi Pro 开发板 Imagination GPU 利用 OpenCL 测试

视频讲解&#xff1a; MUSE Pi Pro 开发板 Imagination GPU 利用 OpenCL 测试 继续玩MUSE Pi Pro&#xff0c;今天看下比较关注的gpu这块&#xff0c;从opencl看起&#xff0c;安装clinfo指令 sudo apt install clinfo 可以看到这颗GPU是Imagination的 一般嵌入式中gpu都和hos…

多线程与线程互斥

我们初步学习完线程之后&#xff0c;就要来试着写一写多线程了。在写之前&#xff0c;我们需要继续来学习一个线程接口——叫做线程分离。 默认情况下&#xff0c;新创建的线程是joinable的&#xff0c;线程退出后&#xff0c;需要对其进行pthread_join操作&#xff0c;否则无法…

游戏引擎学习第287天:加入brain逻辑

Blackboard&#xff1a;动态控制类似蛇的多节实体 我们目前正在处理一个关于实体系统如何以组合方式进行管理的问题。具体来说&#xff0c;是在游戏中实现多个实体可以共同或独立行动的机制。例如&#xff0c;我们的主角拥有两个实体组成部分&#xff0c;一个是身体&#xff0…

continue通过我们的开源 IDE 扩展和模型、规则、提示、文档和其他构建块中心,创建、共享和使用自定义 AI 代码助手

​一、软件介绍 文末提供程序和源码下载 Continue 使开发人员能够通过我们的开源 VS Code 和 JetBrains 扩展以及模型、规则、提示、文档和其他构建块的中心创建、共享和使用自定义 AI 代码助手。 二、功能 Chat 聊天 Chat makes it easy to ask for help from an LLM without…

2025年EB SCI2区TOP,多策略改进黑翅鸢算法MBKA+空调系统RC参数辨识与负载聚合分析,深度解析+性能实测

目录 1.摘要2.黑翅鸢优化算法BKA原理3.改进策略4.结果展示5.参考文献6.代码获取7.读者交流 1.摘要 随着空调负载在电力系统中所占比例的不断上升&#xff0c;其作为需求响应资源的潜力日益凸显。然而&#xff0c;由于建筑环境和用户行为的变化&#xff0c;空调负载具有异质性和…