【搭建 Transformer】

news2025/6/6 0:39:42

搭建 Transformer 的基本步骤

Transformer 是一种基于自注意力机制的深度学习模型,广泛应用于自然语言处理任务。以下为搭建 Transformer 的关键步骤和代码示例。

自注意力机制

自注意力机制是 Transformer 的核心,计算输入序列中每个元素与其他元素的关联度。公式如下:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
其中,$Q$ 为查询矩阵,$K$ 为键矩阵,$V$ 为值矩阵,$d_k$ 为键的维度。

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, queries, mask):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (0.5)), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.embed_size
        )
        return self.fc_out(out)

多头注意力

多头注意力通过并行计算多个自注意力头,增强模型的表达能力。

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, mask):
        attention = self.attention(x, x, x, mask)
        x = self.dropout(self.norm(attention + x))
        return x

前馈神经网络

前馈神经网络用于进一步处理自注意力层的输出。

class FeedForward(nn.Module):
    def __init__(self, embed_size, ff_dim):
        super(FeedForward, self).__init__()
        self.ff = nn.Sequential(
            nn.Linear(embed_size, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_size),
        )
        self.norm = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = self.ff(x)
        x = self.dropout(self.norm(out + x))
        return x

编码器层

编码器层由多头注意力和前馈神经网络组成。

class EncoderLayer(nn.Module):
    def __init__(self, embed_size, heads, ff_dim):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.ff = FeedForward(embed_size, ff_dim)

    def forward(self, x, mask):
        x = self.attention(x, mask)
        x = self.ff(x)
        return x

解码器层

解码器层包含掩码多头注意力、编码器-解码器注意力和前馈神经网络。

class DecoderLayer(nn.Module):
    def __init__(self, embed_size, heads, ff_dim):
        super(DecoderLayer, self).__init__()
        self.masked_attention = MultiHeadAttention(embed_size, heads)
        self.attention = MultiHeadAttention(embed_size, heads)
        self.ff = FeedForward(embed_size, ff_dim)

    def forward(self, x, enc_out, src_mask, trg_mask):
        x = self.masked_attention(x, trg_mask)
        x = self.attention(enc_out, src_mask)
        x = self.ff(x)
        return x

完整 Transformer

整合编码器和解码器,构建完整的 Transformer 模型。

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        embed_size=512,
        num_layers=6,
        heads=8,
        ff_dim=2048,
        max_len=100,
    ):
        super(Transformer, self).__init__()
        self.encoder_embed = nn.Embedding(src_vocab_size, embed_size)
        self.decoder_embed = nn.Embedding(trg_vocab_size, embed_size)
        self.pos_embed = PositionalEncoding(embed_size, max_len)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(embed_size, heads, ff_dim) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(embed_size, heads, ff_dim) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)

    def forward(self, src, trg, src_mask, trg_mask):
        src_embed = self.pos_embed(self.encoder_embed(src))
        trg_embed = self.pos_embed(self.decoder_embed(trg))

        for layer in self.encoder_layers:
            src_embed = layer(src_embed, src_mask)

        for layer in self.decoder_layers:
            trg_embed = layer(trg_embed, src_embed, src_mask, trg_mask)

        return self.fc_out(trg_embed)

位置编码

位置编码用于注入序列的位置信息。

class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.shape[1], :]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2398145.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HCIP(BGP综合实验)

一、实验拓扑 AS 划分: AS1:R1(环回 L0:172.16.0.1/32,L1:192.168.1.0/24)AS2:R2、R3、R4、R5、R6、R7(内部运行 OSPF,AS 号为 64512 和 64513 的联盟)AS3:R…

Attention Is All You Need (Transformer) 以及Transformer pytorch实现

参考https://zhuanlan.zhihu.com/p/569527564 Attention Is All You Need (Transformer) 是当今深度学习初学者必读的一篇论文。 一. Attention Is All You Need (Transformer) 论文精读 1. 知识准备 机器翻译,就是将某种语言的一段文字翻译成另一段文字。 由…

uniapp+vue2+uView项目学习知识点记录

持续更新中... 1、发送给朋友,分享到朋友圈功能开启 利用onShareAppMessage和onShareTimeline生命周期函数,在script中与data同级去写 // 发送给朋友 onShareAppMessage() {return {title: 清清前端, // 分享标题path: /pages/index/index, // 分享路…

精美的软件下载页面HTML源码:现代UI与动画效果的完美结合

精美的软件下载页面HTML源码:现代UI与动画效果的完美结合 在数字化产品推广中,一个设计精良的下载页面不仅能提升品牌专业度,还能显著提高用户转化率。本文介绍的精美软件下载页面HTML源码,通过现代化UI设计与丰富的动画效果&…

车载诊断架构 --- DTC消抖参数(Trip Counter DTCConfirmLimit )

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 做到欲望极简,了解自己的真实欲望,不受外在潮流的影响,不盲从,不跟风。把自己的精力全部用在自己。一是去掉多余,凡事找规律,基础是诚信;二是…

javaEE->IO:

文件: 操作系统中会把很多 硬件设备 和 软件资源 抽象成“文件”,统一进行管理。 大部分谈到的文件,都是指 硬盘的文件,文件就相当于是针对“硬盘”数据的一种抽象 硬盘: 1.机械硬盘:便宜 2.固态硬盘&…

Oracle 用户/权限/角色管理

1. 用户 1.1. 用户的创建和删除 1.1.1. 创建用户 create user user identified {by password | externally} [ default tablespace tablespace ] [ temporary tablespace tablespace ] [ quota {integer [k | m ] | unlimited } on tablespace [ quota {integer [k | m ] | …

工厂方法模式深度解析:从原理到应用实战

作者简介 我是摘星,一名全栈开发者,专注 Java后端开发、AI工程化 与 云计算架构 领域,擅长Python技术栈。热衷于探索前沿技术,包括大模型应用、云原生解决方案及自动化工具开发。日常深耕技术实践,乐于分享实战经验与…

python可视化:端午假期旅游火爆原因分析

python可视化:端午假期旅游火爆原因分析 2025年的旅游市场表现强劲: 2025年端午假期全社会跨区域人员流动量累计6.57亿人次,日均2.19亿人次,同比增长3.0%。入境游订单同比大涨近90%,门票交易额(GMV&#…

SOC-ESP32S3部分:28-BLE低功耗蓝牙

飞书文档https://x509p6c8to.feishu.cn/wiki/CHcowZMLtiinuBkRhExcZN7Ynmc 蓝牙是一种短距的无线通讯技术,可实现固定设备、移动设备之间的数据交换,下图是一个蓝牙应用的分层架构,Application部分则是我们需要实现的内容,Protoc…

Git-flow流

Git git是版本控制软件,一般用来做代码版本控制 github是一个免费版本控制仓库是国内外很多开源项目的集中地,其本体是一个git服务器 Git初始化操作 git init 初始化仓库 git status 查看当前仓库的状态 git add . 将改动的文件加到暂存区 gi…

VirtualBox给Rock Linux9.x配置网络

写这篇文章之前,先说明一下,我参考的是我之前写的《VirtualBox Linux网络配置》 我从CentOS7转到了Rock9,和配置Centos7一样,主流程没有变化,变化的是Rock9.x中的配置文件和使用的命令。 我再说一次,因为主…

知识图谱增强的大型语言模型编辑

https://arxiv.org/pdf/2402.13593 摘要 大型语言模型(LLM)是推进自然语言处理(NLP)任务的关键,但其效率受到不准确和过时知识的阻碍。模型编辑是解决这些挑战的一个有前途的解决方案。然而,现有的编辑方法…

.NET 原生驾驭 AI 新基建实战系列(一):向量数据库的应用与畅想

在当今数据驱动的时代,向量数据库(Vector Database)作为一种新兴的数据库技术,正逐渐成为软件开发领域的重要组成部分。特别是在 .NET 生态系统中,向量数据库的应用为开发者提供了构建智能、高效应用程序的新途径。 一…

【claude+deepseek+gemini】基于李群李代数和螺旋理论工业机器人控制系统软件UI设计

claude的首次设计html是最佳的。之后让deepseek和gemini根据claude的UI设计进行改进设计。。。当然可以尝试很多次,也可以让他们之间来回不断改进…… claude deepseek-r1 0528 上图为deepseek首次设计,下面为改进设计 …… Gemini 2.5 Pro 0506 &#x…

阿里云国际站,如何通过代理商邀请的链接注册账号

阿里云国际站:如何通过代理商邀请链接注册,解锁“云端超能力”与专属福利? 渴望在全球化浪潮中抢占先机?想获得阿里云国际站的海量云资源、遍布全球的加速节点与前沿AI服务,同时又能享受专属折扣、VIP级增值服务支持或…

乾坤qiankun的使用

vue2 为主应用 react 为子应用 在项目中安装乾坤 yarn add qiankun # 或者 npm i qiankun -Svue主应用 在main.js中新增 (需要注意的是路由模型为history模式) registerMicroApps([{name: reactApp,entry: //localhost:3011,container: #container,/…

PH热榜 | 2025-06-03

1. Knowledge 标语:像认识朋友一样去销售给潜在客户,因为你其实了解他们! 介绍:Knowledge 是一个针对个人的销售智能平台,它利用行为数据和心理测评来识别市场上的潜在买家,并指导销售团队以最真实、最有…

论文略读: STREAMLINING REDUNDANT LAYERS TO COMPRESS LARGE LANGUAGE MODELS

2025 ICLR 判断模型层的重要性->剪去不重要的层(用轻量网络代替) 这种方法只减少了层数量,所以可以用常用的方法加载模型 层剪枝阶段 通过输入与输出的余弦相似度来判断各个层的重要性 具有高余弦相似度的层倾向于聚集在一起&#xff0c…

mapbox高阶,生成并加载等时图

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️Fill面图层样式1.4 ☘️symbol符号图层…