第9.2讲、Tiny Decoder(带 Mask)详解与实战

news2025/5/25 15:21:47

自己搭建一个 Tiny Decoder(带 Mask),参考 Transformer Encoder 的结构,并添加 Masked Multi-Head Self-Attention,它是 Decoder 的核心特征之一。


1. 背景与动机

Transformer 架构已成为自然语言处理(NLP)领域的主流。其 Encoder-Decoder 结构广泛应用于机器翻译、文本生成等任务。Decoder 的核心特征是 Masked Multi-Head Self-Attention,它保证了自回归生成时不会"偷看"未来信息。本文将带你从零实现一个最小可运行的 Tiny Decoder,并深入理解其原理。


2. Tiny Decoder 架构简述

一个标准 Transformer Decoder Layer 包括:

  1. Masked Multi-Head Self-Attention
  2. Encoder-Decoder Attention(跨注意力)
  3. Feed Forward Network (FFN)
  4. LayerNorm + Residual Connection

为了简化,我们暂时不引入 Encoder-Decoder Attention,只聚焦于:

Masked Self-Attention + FFN


3. 什么是 Masked Attention?

Masked Attention 的作用是在 Decoder 生成序列时,禁止看到"未来"的 token,防止信息泄露。

用一个 Mask 矩阵来实现,例如:

Mask for length 4:
[[0, -inf, -inf, -inf],
 [0,   0,  -inf, -inf],
 [0,   0,    0,  -inf],
 [0,   0,    0,    0]]

这个 Mask 会加在 Attention 的 logits 上(即 QKᵗ / sqrt(dk)),将不允许的位置置为 -inf,softmax 之后就是 0。


4. Tiny Decoder 核心代码(简化 PyTorch 实现)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# 带掩码的多头自注意力机制
class MaskedSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0  # 保证可以均分到每个头

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        # 用一个线性层同时生成 Q、K、V
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        # 输出投影
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        B, T, C = x.size()
        # 生成 Q、K、V,并分头
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.d_k).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, heads, T, d_k)

        # 计算注意力分数 (QK^T / sqrt(d_k))
        attn_logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, heads, T, T)

        # 构造下三角 Mask,防止看到未来信息
        mask = torch.tril(torch.ones(T, T)).to(x.device)
        attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))

        # softmax 得到注意力权重
        attn = F.softmax(attn_logits, dim=-1)
        # 加权求和得到输出
        out = attn @ v  # (B, heads, T, d_k)

        # 合并多头
        out = out.transpose(1, 2).contiguous().reshape(B, T, C)
        # 输出投影
        return self.out_proj(out)

# 前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        # 两层全连接+ReLU
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        # 前馈变换
        return self.ff(x)

# Tiny Decoder 层,包含 Masked Self-Attention 和前馈网络
class TinyDecoderLayer(nn.Module):
    def __init__(self, d_model=128, num_heads=4, d_ff=512):
        super().__init__()
        self.self_attn = MaskedSelfAttention(d_model, num_heads)  # 掩码自注意力
        self.ff = FeedForward(d_model, d_ff)  # 前馈网络
        self.norm1 = nn.LayerNorm(d_model)    # 层归一化1
        self.norm2 = nn.LayerNorm(d_model)    # 层归一化2

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        # 先归一化,再做自注意力,并加残差
        x = x + self.self_attn(self.norm1(x))
        # 再归一化,前馈网络,并加残差
        x = x + self.ff(self.norm2(x))
        return x

5. 使用示例

x = torch.randn(2, 10, 128)        # Decoder输入
context = torch.randn(2, 15, 128)  # Encoder输出
decoder = TinyDecoderLayer()
y = decoder(x, context)            # output shape: (2, 10, 128)

6. 进阶扩展

6.1 添加 Encoder-Decoder Attention

Encoder-Decoder Attention 允许 Decoder 在生成时参考 Encoder 的输出(即源语言信息),是机器翻译等任务的关键。其实现方式与 Self-Attention 类似,只是 Q 来自 Decoder,K/V 来自 Encoder。

伪代码:

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        # ...同 MaskedSelfAttention ...
    def forward(self, x, context):
        # x: (B, T_dec, d_model), context: (B, T_enc, d_model)
        # Q from x, K/V from context
        # ...实现...

在 Decoder Layer 中插入:

self.cross_attn = CrossAttention(d_model, num_heads)
# forward:
x = x + self.cross_attn(self.norm_cross(x), context)

6.2 多层 Decoder 堆叠

实际应用中,Decoder 通常由多层堆叠而成:

class TinyDecoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff):
        super().__init__()
        self.layers = nn.ModuleList([
            TinyDecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

6.3 加入 Positional Encoding

Transformer 不具备序列顺序感知能力,需加上 Positional Encoding:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x):
        return x + self.pe[:x.size(1)]

7. 完整训练例子(伪代码)

# 假设有输入数据 input_seq, target_seq
x = embedding(input_seq)
x = pos_encoding(x)
decoder = TinyDecoder(num_layers=2, d_model=128, num_heads=4, d_ff=512)
output = decoder(x)
# 计算 loss, 反向传播

8. 小结

  • Decoder 的关键是 Masked Self-Attention,通过 tril 的下三角掩码防止泄漏未来信息。
  • 可以用 torch.tril 快速构造下三角 Mask。
  • Decoder 层和 Encoder 类似,但注意力机制加了 Mask,而且通常会多出 Encoder-Decoder Attention。
  • 可扩展为多层、加入位置编码、跨注意力等,逐步构建完整的 Transformer Decoder。
    *如果不加 Mask,允许 Decoder 看到未来 token,会导致模型训练"作弊",推理时表现极差,生成文本质量低下,模型失去实际应用价值。因此,Masked Self-Attention 是保证自回归生成和模型泛化能力的关键机制。

9. 参考资料

  • Attention is All You Need
  • The Annotated Transformer
  • PyTorch 官方文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2385429.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于PCRLB的CMIMO雷达网络多目标跟踪资源调度

针对分布式组网CMIMO雷达多目标跟踪(MTT)场景,博客分析了一种目标-雷达匹配方案与功率联合优化算法。在采用分布式组网融合架构的基础上,推导包含波束和功率分配的后验克拉美罗界(PCRLB)。随后,将该效用函数结合CMIMO雷达系统资源&#xff0c…

AtCoder Beginner Contest 407(ABCDE)

A - Approximation 翻译&#xff1a; 给你一个正整数 A 和一个正奇数 B。 请输出与实数 的差最小的整数。 可以证明&#xff0c;在约束条件下&#xff0c;这样的整数是唯一的。 思路&#xff1a; 令。比较来判断答案。 实现&#xff1a; #include<bits/…

VILT模型阅读笔记

代码地址&#xff1a;VILT Abstract Vision-and-Language Pre-training (VLP) has improved performance on various joint vision-andlanguage downstream tasks. Current approaches to VLP heavily rely on image feature extraction processes, most of which involve re…

掌握 npm 核心操作:从安装到管理依赖的完整指南

图为开发者正在终端操作npm命令&#xff0c;图片来源&#xff1a;Unsplash 作为 Node.js 生态的基石&#xff0c;npm&#xff08;Node Package Manager&#xff09;是每位开发者必须精通的工具。每天有超过 1700 万个项目通过 npm 共享代码&#xff0c;其重要性不言而喻。本文…

OpenCV CUDA模块特征检测与描述------一种基于快速特征点检测和旋转不变的二进制描述符类cv::cuda::ORB

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::cuda::ORB 是 OpenCV 库中 CUDA 模块的一部分&#xff0c;它提供了一种基于快速特征点检测和旋转不变的二进制描述符的方法&#xff0c;用于…

Awesome ChatGPT Prompts:释放AI对话潜力的开源利器

项目概览 Awesome ChatGPT Prompts 是由土耳其开发者 Fatih Kadir Akın 发起的开源项目,托管于 GitHub,旨在通过精心设计的提示词模板(Prompts)优化用户与 ChatGPT 的交互体验。项目以 Markdown 和 CSV 格式管理模板,无需复杂编程语言,但需文本处理能力,目前已在 GitH…

PP-YOLOE-SOD学习笔记2

一、解析X-Anylabeling标注后的json格式问题 最近在使用自动标注工具后json格式转化过程中&#xff0c;即标注框的四点坐标转换为两点坐标时&#xff0c;发现json格式的四点顺序是按顺时针方向开始的&#xff0c;那么在转换其实就是删除2、4坐标或者1、3坐标即可。 二、数据集…

算法学习——从零实现循环神经网络

从零实现循环神经网络 一、任务背景二、数据读取与准备1. 词元化2. 构建词表 三、参数初始化与训练1. 参数初始化2. 模型训练 四、预测总结 一、任务背景 对于序列文本来说&#xff0c;如何通过输入的几个词来得到后面的词一直是大家关注的任务之一&#xff0c;即&#xff1a;…

win10使用nginx做简单负载均衡测试

一、首先安装Nginx&#xff1a; 官网链接&#xff1a;https://nginx.org/en/download.html 下载完成后&#xff0c;在本地文件中解压。 解压完成之后&#xff0c;打开conf --> nginx.config 文件 1、在 http 里面加入以下代码 upstream GY{#Nginx是如何实现负载均衡的&a…

2025电工杯数学建模B题思路数模AI提示词工程

我发布的智能体链接&#xff1a;数模AI扣子是新一代 AI 大模型智能体开发平台。整合了插件、长短期记忆、工作流、卡片等丰富能力&#xff0c;扣子能帮你低门槛、快速搭建个性化或具备商业价值的智能体&#xff0c;并发布到豆包、飞书等各个平台。https://www.coze.cn/search/n…

【日志软件】hoo wintail 的替代

hoo wintail 的替代 主要问题是日志大了以后会卡有时候日志覆盖后&#xff0c;改变了&#xff0c;更新了&#xff0c;hoo wintail可能无法识别需要重新打开。 有很多类似的日志监控软件可以替代。以下是一些推荐的选项&#xff1a; 免费软件 BareTail 轻量级的实时日志查看…

Ollama-OCR:基于Ollama多模态大模型的端到端文档解析和处理

基本介绍 Ollama-OCR是一个Python的OCR解析库&#xff0c;结合了Ollama的模型能力&#xff0c;可以直接处理 PDF 文件无需额外转换&#xff0c;轻松从扫描版或原生 PDF 文档中提取文本和数据。根据使用的视觉模型和自定义提示词&#xff0c;Ollama-OCR 可支持多种语言&#xf…

OpenCV CUDA 模块中图像过滤------创建一个拉普拉斯(Laplacian)滤波器函数createLaplacianFilter()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::cuda::createLaplacianFilter 是 OpenCV CUDA 模块中的一个函数&#xff0c;用于创建一个 拉普拉斯&#xff08;Laplacian&#xff09;滤波器…

图论学习笔记 3

自认为写了很多&#xff0c;后面会出 仙人掌、最小树形图 学习笔记。 多图警告。 众所周知王老师有一句话&#xff1a; ⼀篇⽂章不宜过⻓&#xff0c;不然之后再修改使⽤的时候&#xff0c;在其中找想找的东⻄就有点麻烦了。当然⽂章也不宜过多&#xff0c;不然想要的⽂章也不…

【将WPS设置为默认打开方式】--突然无法用WPS打开文件

1. 点击【开始】——【WPS Office】——【配置工具】&#xff1b; 2. 在出现的弹窗中&#xff0c;点击【高级】&#xff1b; 3. 在“兼容设置”中&#xff0c;将复选框勾上&#xff0c;点击【确定】。

电子人的分水岭-FPGA模电和数电

为什么模电这么难学&#xff1f;一文带你透彻理解模电 ——FPGA是“前期数电&#xff0c;后期模电”的典型代表 在电子工程的世界里&#xff0c;有两门基础课程让无数学生“闻之色变”&#xff1a;数字电路&#xff08;数电&#xff09; 和 模拟电路&#xff08;模电&#xff0…

(6)python爬虫--selenium

文章目录 前言一、初识selenium二、安装selenium2.1 查看chrome版本并禁止chrome自动更新2.1.1 查看chrome版本2.1.2 禁止chrome更新自动更新 2.2 安装对应版本的驱动程序2.3安装selenium包 三、selenium关于浏览器的使用3.1 创建浏览器、设置、打开3.2 打开/关闭网页及浏览器3…

Python之两个爬虫案例实战(澎湃新闻+网易每日简报):附源码+解释

目录 一、案例一&#xff1a;澎湃新闻时政爬取 &#xff08;1&#xff09;数据采集网站 &#xff08;2&#xff09;数据介绍 &#xff08;3&#xff09;数据采集方法 &#xff08;4&#xff09;数据采集过程 二、案例二&#xff1a;网易每日新闻简报爬取 &#xff08;1&#x…

✨ PLSQL卡顿优化

✨ PLSQL卡顿优化 1.&#x1f4c2; 打开首选项2.&#x1f527; Oracle连接配置3.⛔ 关闭更新和新闻 1.&#x1f4c2; 打开首选项 2.&#x1f527; Oracle连接配置 3.⛔ 关闭更新和新闻

python+vlisp实现对多段线范围内土方体积的计算

#在工程中&#xff0c;经常用到计算土方回填、土方开挖的体积。就是在一个范围内&#xff0c;计算土被挖走&#xff0c;或者填多少&#xff0c;这个需要测量挖填前后这个范围内的高程点。为此&#xff0c;我开发一个app&#xff0c;可以直接在autocad上提取高程点&#xff0c;然…