使用PyTorch从头实现Transformer

news2025/5/15 8:39:12

前言

  • 本文使用Pytorch从头实现Transformer,原论文Attention is all you need paper,最佳解读博客,学习视频
  • GitHub项目地址Some-Paper-CN。本项目是译者在学习长时间序列预测、CV、NLP和机器学习过程中精读的一些论文,并对其进行了中文翻译。还有部分最佳示例教程
  • 如果有帮助到大家,请帮忙点亮Star,也是对译者莫大的鼓励,谢谢啦~

SelfAttention

  • 整篇论文中,最核心的部分就是SelfAttention部分,SelfAttention模块架构图如下。
    请添加图片描述

规范化公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        # batch_size
        N = query.shape[0]
        value_len, keys_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, keys_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        # values, keys, queries shape:(N, seq_len, heads, head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape:(N, query_len, heads, heads_dim)
        # keys shape:(N, key_len, heads, heads_dim)
        # energy shape:(N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0 ,float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        # attention shape:(N, heads, seq_len, seq_len)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads*self.head_dim
        )
        # attention shape:(N, heads, query_len, key_len)
        # values shape:(N, values_len, heads, head_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten lash two dimensions

        out = self.fc_out(out)
        return out
  • 请注意values keysqueryLinear层是不带偏置的!
  • 上述代码中,较难理解的是torch.einsum(),爱因斯坦求和约定,nqhd,nkhd->nhqk可以理解为维度是 ( n , q , h , d ) (n,q,h,d) (n,q,h,d)的张量与 ( n , k , h , d ) (n,k,h,d) (n,k,h,d)的张量沿着维度 d d d相乘,得到维度 ( n , q , d , h ) (n,q,d,h) (n,q,d,h)重新排列后变成 ( n , h , q , k ) (n,h,q,k) (n,h,q,k)
  • 传入mask矩阵是因为每个句子的长度不一样,为了保证维度相同,在长度不足的句子后面使用padding补齐,而padding是不用计算损失的,所以需要mask告诉模型哪些位置需要计算损失。被mask遮掩的地方,被赋予无限小的值,这样在softmax以后概率就几乎为0了。
  • mask这个地方后面也会写到,如果不理解的话,先有这个概念,后面看完代码就会理解了。

TransformerBlock

  • 实现完SelfAttention,开始实现基本模块TransformerBlock,架构图如下。

请添加图片描述

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        # attention shape:(N, seq_len, emb_dim)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
  • Feed Forward部分由两层带偏置的Linear层和ReLU激活函数

  • 注意,Norm层是LayerNorm层,不是BatchNorm层。原因主要有:

    • 在进行BatchNorm操作时,同一个batch中的所有样本都会被考虑在内。这意味着一个样本的输出可能会受到同一批次其他样本的影响。然而,我们在处理文本数据时,通常希望每个样本(在此例中,是一个句子或一个句子段落)都是独立的。因此,LayerNorm是一个更好的选择,因为它只对单个样本进行操作。
    • 文本通常不是固定长度的,这就意味着每个batch的大小可能会有所不同。BatchNorm需要固定大小的batch才能正常工作,LayerNorm在这点上更为灵活。
  • forward_expansion是为了扩展embedding的维度,使Feed Forward包含更多参数量。

Encoder

  • 实现完TransformerBlock,就可以实现模型的Encoder部分,模块架构图如下。

请添加图片描述

class Encoder(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                ) for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        # positions shape:(N, seq_len)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        # out shape:(N, seq_len, emb_dim)
        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out
  • 为了更好的理解positions并没有按照论文中使用sincos构造,但这一部分并不困难,后面大家有兴趣可以进行替换。

  • positions是为句子的每个字从0开始编号,假设有2个句子,第一个句子有3个字,第二个句子有4个字,即positions = [[0,1,2],[0,1,2,3]]

  • positionsx进入embedding层后相加,然后进入dropout

  • 因为TransformerBlock可能有多个串联,所以使用ModuleList包起来

  • 注意残差连接部分的操作。

DecoderBlock

  • 实现完Encoder部分,整个模型就已经完成一半了,接下来实现Decoder基本单元DecoderBlock,模块架构图如下。

请添加图片描述

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out
  • 注意!这里有两个Attention模块,首先输入x要进入Masked Attention得到query,然后与Encoder部分的输出组成新的v,k,q
  • 第二部分是基本单元transformer_block,可以直接调用。
  • 注意残差连接部分即可。

Decoder

  • 实现完Decoder基本单元DecoderBlock后,就可以正式开始实现Decoder部分了,模块架构图如下。

请添加图片描述

class Decoder(nn.Module):
    def __init__(
            self,
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            for _ in range(num_layers)]
        )

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out
  • Decoder部分的embedding部分和Encoder部分差不多,word_embeddingposition_embedding相加进入dropout层。
  • 基本单元DecoderBlock会重复多次,用ModuleList包裹。
  • enc_outEncoder部分的输出,变成了valuekey

Transformer

  • 在实现完EncoderDecoder后,就可以实现整个Transformer结构了,架构图如下。
    请添加图片描述
class Transformer(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            trg_vocab_size,
            src_pad_idx,
            trg_pad_idx,
            embed_size=256,
            num_layers=6,
            forward_expansion=4,
            heads=8,
            dropout=0,
            device='cpu',
            max_length=64
    ):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def mask_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # src_mask shape:(N, 1, 1, src_len)
        return src_mask.to(self.device)

    def mask_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        # trg_mask shape:(N, 1, 1, trg_len)
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.mask_src_mask(src)
        trg_mask = self.mask_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        # enc_src shape:(N, seq_len, emb_dim)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out
  • 需要注意的是输入的mask构造方法mask_src_mask,和输出的mask构造方法mask_trg_maskmask_src_mask是对输入的padding部分进行maskmask_trg_mask根据输出构建下三角矩阵想象一下,当模型预测第一个字的时候,后面的所有内容都是不可见的,当模型预测第二个字的时候,仅第一个字可见,后面的内容都不可见…

检验

  • 构建完成后,使用一个简单的小例子检验一下模型是否可以正常运行。
if __name__ == "__main__":
    device = 'cpu'
    # x shape:(N, seq_len)
    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0],
                      [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0],
                        [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10

    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx).to(device)
    out = model(x, trg[:, :-1])
    print(out.shape)
  • 输出:(2, 7, 10),完整代码放在GitHub项目Some-Paper-CN中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1641662.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

05月04日(周六)30场比赛前瞻

今日数据: 昨日复盘: 欧洲五大联赛指的是欧洲影响力及竞技水平排名前五的足球联赛,通常包括英格兰足球联赛(The Premier League)、西班牙足球甲级联赛(La Liga)、意大利足球甲级联赛&#xff0…

vue2人力资源项目3主页

主页权限验证 前置守卫开启进度条,后置守卫关闭进度条 import router from /router import nProgress from nprogress// 导入进度条(模板自带) import nprogress/nprogress.css// 导入进度条样式(模板自带) // 前置守…

java中对文件的基本操作

文件IO 文件IO。啥叫文件的IO? 他就是指:1.Input(输入)2.Output(输出)。 比如,我们的电脑可以从网络中下载文件,也可以通过网络上传文件等等很多的例子,都体现了输入和…

Xamarin.Android项目使用ConstraintLayout约束布局

Xamarin.AndroidX.ConstraintLayout Xamarin.Android.Support.Constraint.Layout Xamarin.AndroidX.ConstraintLayout.Solver Xamarin.AndroidX.DataBinding.ViewBinding Xamarin.AndroidX.Legacy.Support.Core.UI Xamarin.AndroidX.Lifecycle.LiveData ![在这里插入图片描述]…

android天气实战

页面绘制 问题1、下拉框需要背景为透明 我懒得写全部省份就写了5个所以不需要往下 图标准备 iconfont-阿里巴巴矢量图标库几坤年没来这了好怀念啊,图标库选择下雨的图标等 准备网络请求 0、API接口准备 api免费七日天气接口API 未来一周天气预报api (tianqiap…

SVM直观理解

https://tangshusen.me/2018/10/27/SVM/ https://www.bilibili.com/video/BV16T4y1y7qj/?spm_id_from333.337.search-card.all.click&vd_source8272bd48fee17396a4a1746c256ab0ae SVM是什么? 先来看看维基百科上对SVM的定义: 支持向量机(英语:su…

[BLE] Heart Rate Protocol - Sensor

写在前面 目前我从网上找到的有关BLE心率协议的博文内容良莠不齐,很难让人根据文章内容来全面理解心率服务;此外SIG网站上有关心率服务的文档比较多,内容比较碎,需要读者从多个文档中将需要的内容拼接起来,因此写下这…

【动态规划】路径问题|不同路径I|不同路径II|珠宝的最高价值|下降路径的最小和|最小路径和|

一、不同路径I 62. 不同路径 - 力扣(LeetCode) 💡细节: 1.多开一行和一列(跟一维数组多开一个位置一样),这样方便初始化 2.状态转移方程:注意走一步并不是多一种走的路径&#xff0…

在编程的世界里,我相信每一行代码都是一次对未来的投资

😀前言 突然有感而发也是激励自己互勉 🏠个人主页:尘觉主页 文章目录 在编程的世界里,我相信每一行代码都是一次对未来的投资类似句子编程的本质代码的价值构建可持续的未来结语 在编程的世界里,我相信每一行代码都是一…

数据库基础--MySQL多表查询之外键约束

MySQL多表关系 一对一 顾名思义即一个对应一个的关系,例如身份证号对于每个人来说都是唯一的,即个人信息表与身份证号信息表是一对一的关系。车辆信息表与车牌信息表也是属于一对一的关系。 一对多 即一个表当中的一个字段信息,对应另一张…

【数据库原理及应用】期末复习汇总高校期末真题试卷02

试卷 一、填空题 数据库系统是指计算机系统中引入数据库后的系统,一般由数据库、________、应用系统、数据库管理员和用户构成。当数据库的存储结构发生了改变,由数据库管理员对________映象作相应改变,可以使________保持不变,…

vue快速入门(五十一)历史模式

注释很详细,直接上代码 上一篇 新增内容 历史模式配置方法 默认哈希模式,历史模式与哈希模式在表层的区别是是否有/#/ 其他差异暂不深究 源码 //导入所需模块 import Vue from "vue"; import VueRouter from "vue-router"; import m…

从零开始学AI绘画,万字Stable Diffusion终极教程(一)

【第1期】SD入门 2022年8月,一款叫Stable Diffusion的AI绘画软件开源发布,从此开启了AIGC在图像上的爆火发展时期 率先学会SD的人,已经挖掘出了越来越多AI绘画有趣的玩法 从开始的AI美女、线稿上色、真人漫改、头像壁纸 到后来的AI创意字、AI…

华为eNSP小型园区网络配置(上)

→跟着大佬学习的b站直通车← 目标1:dhcp分配ip地址 目标2:内网用户访问www.yzy.com sw1 # vlan batch 10 # interface Ethernet0/0/1port link-type accessport default vlan 10 # interface Ethernet0/0/2port link-type trunkport trunk allow-pass…

oracle pl/sql 如何让sql windows 显示行号

oracle pl/sql 如何让sql windows 显示行号 下载最新版的pl/sql第一步,在preferences中对sql Windows进行设置,如下所示第二步,在preferences中对User interface进行设置,如下所示结果如下当然,还可以通过右键选择是否…

iptables---防火墙

防火墙介绍 防火墙的作用可以理解为是一堵墙,是一个门,用于保护服务器安全的。 防火墙可以保护服务器的安全,还可以定义各种流量匹配的规则。 防火墙的作用 防火墙具有对服务器很好的保护作用,入侵者必须穿透防火墙的安全防护…

【大模型学习】私有大模型部署(基础知识)

私有大模型 优点 保护内部隐私 缺点 成本昂贵 难以共享 难以更新 大模型底座 基础知识点 知识库 知识库是什么? 知识库的作用是什么? 微调 增强大模型的推理能力 AI Agent 代理,与内部大模型进行交互 开源 and 闭源 是否可以查…

二叉树的实现(详解,数据结构)

目录 一,二叉树需要实现的功能 二,下面是各功能详解 0.思想: 1.创建二叉树结点: 2.通过前序遍历的数组"ABD##E#H##CF##G##"构建二叉树 3.二叉树销毁: 4.前序遍历: 5.中序遍历:…

QT5之布局操作

目录 实验之前的前提 局部布局和整体布局定义 快捷工具 水平和垂直布局 水平布局 在对象区域可以看出三个已经被水平布局在一起 在对象区域选中布局,点击工具取消当前布局 可以将两个小局部进行大局部布局 网格布局 弹簧布局 分割器布局 器件对齐边距 也…

Java Map集合(一)

1. Map接口 1.1 Map接口概述 Map接口是一种双列集合。Map的每个元素都包含一个键对象Key和一个值对象Value ,键对象和值对象之间存在对应关系,这种关系称为映射(Mapping)。 Map接口中的元素,可以通过 key 找到 value&…