用 PyTorch 从零实现简易GPT(Transformer 模型)

news2025/5/19 7:30:31

用 PyTorch 从零实现简易GPT(Transformer 模型)

本文将结合示例代码,通俗易懂地拆解大模型(Transformer)从数据预处理到推理预测的核心组件与流程,并通过 Mermaid 流程图直观展示整体架构。文章结构分为四层,层次清晰,帮助读者系统掌握大模型原理与实战。


1 引言

大模型(如 GPT、BERT 等)之所以强大,得益于其背后多层自注意力和前馈网络的有机结合。本文以极简版中文 Transformer 为例,从最基础的数据准备到完整训练与推理过程,逐步剖析每个核心环节,让零基础读者也能轻松理解大模型的工作原理,并动手复现。


2 大模型核心组件概览

2.1 整体架构流程图

请添加图片描述

流程图中,数据层负责将中文句子转换为 ID 序列;模型层依次执行词嵌入、位置编码、编码器堆叠及线性投影;任务层完成训练(损失计算与优化)与推理(生成预测)。


3 模块详解

3.1 数据准备

3.1.1 原始文本与分词
  1. 原始中文句子
    示例中我们定义了三句极简中文:

    sentences = [
        "我今天去公园",
        "公园里有很多树",
        "树上有小鸟"
    ]
    

    每个字视为一个 token,无需额外分词工具。

3.1.2 构建词表
  1. 按首次出现顺序构建字表

    chars = []; seen = set()
    for s in sentences:
        for c in s:
            if c not in seen:
                seen.add(c); chars.append(c)
    char2idx = {char: idx+2 for idx,char in enumerate(chars)}
    char2idx["<pad>"] = 0; char2idx["<unk>"] = 1
    idx2char = {v:k for k,v in char2idx.items()}
    vocab_size = len(char2idx)
    
    • <pad>:填充符,用于对齐;
    • <unk>:未知符,用于未登录字符。
3.1.3 生成输入-目标对
  1. 滑动窗口生成序列

    def create_sequences(data, seq_length=3):
        inputs, targets = [], []
        for seq in data:
            for i in range(len(seq)-seq_length):
                inputs.append(seq[i:i+seq_length])
                targets.append(seq[i+1:i+1+seq_length])
        return torch.tensor(inputs), torch.tensor(targets)
    inputs, targets = create_sequences(data, seq_length=3)
    
    • 输入:连续 n 个字的 ID;
    • 目标:右移一位后的 n 个字,用于语言模型预测。

3.2 位置编码(Positional Encoding)

  1. 为什么需要位置编码?
    自注意力机制本身对序列顺序不敏感,必须加入显式位置编码以保留顺序信息。

  2. 实现思路

    class PositionalEncoding(nn.Module):
        def __init__(self, d_model, max_len=5000):
            super().__init__()
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len).unsqueeze(1).float()
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
            pe[:,0::2] = torch.sin(position * div_term)
            pe[:,1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe.unsqueeze(0))
        def forward(self, x):
            return x + self.pe[:, :x.size(1), :]
    
    • 正弦/余弦:不同频率编码,不同维度交替使用 sin/cos;
    • 注册 buffer:在模型保存/加载时自动携带,不参与梯度更新。

3.3 Transformer 编码器

3.3.1 模型整体定义
class ChineseTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=32, nhead=4, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=20)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=128, dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embed(x)                     # (batch, seq_len) -> (batch, seq_len, d_model)
        x = self.pos_enc(x)                   # 加入位置编码
        x = x.permute(1,0,2)                  # 转换为 (seq_len, batch, d_model)
        x = self.transformer(x)               # 多层编码器堆叠
        x = x.permute(1,0,2)                  # 恢复 (batch, seq_len, d_model)
        return self.fc(x)                     # 线性投影到 vocab_size
3.3.2 嵌入层(Embedding)
  • 将离散的 token ID 映射到连续空间中的向量;
  • nn.Embedding(vocab_size, d_model):可训练的查表操作。
3.3.3 多头自注意力(Multi-Head Attention)
  • 每个注意力头关注不同子空间;
  • nn.TransformerEncoderLayer 内部集成了多头注意力与残差连接。
3.3.4 前馈网络(Feed-Forward Network)
  • 两层线性变换 + 激活 + Dropout;
  • 扩展后维度 dim_feedforward,再投回 d_model

3.4 训练与推理

3.4.1 损失函数与优化器
model = ChineseTransformer(vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  • 忽略填充位置 <pad>
  • Adam 优化器自适应学习率。
3.4.2 训练循环
num_epochs = 100
for epoch in range(num_epochs):
    model.train(); optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output.view(-1, vocab_size), targets.view(-1))
    loss.backward(); optimizer.step()
    if (epoch+1) % 20 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
  • 展开:将 (batch, seq_len, vocab_size) 转为二维,匹配交叉熵接口;
  • 定期打印:监控训练动态。
3.4.3 推理预测
def predict_next(text, model, temperature=1.0):
    model.eval()
    with torch.no_grad():
        input_ids = torch.tensor([char2idx.get(c,1) for c in text[-3:]]).unsqueeze(0)
        output = model(input_ids)
        logits = output[0,-1,:] / temperature
        probs = torch.softmax(logits, dim=-1)
        return idx2char[torch.argmax(probs).item()]
  • Temperature:温度系数调整采样分布;
  • 贪心选择:直接取最大概率。

测试示例:

test_cases = ["我今天", "公园里", "树上有"]
for case in test_cases:
    print(f"输入 '{case}' → 预测下一个字: '{predict_next(case, model)}'")

4 完整示例代码

import torch
import torch.nn as nn
import math

###############################
# 1. 准备中文训练数据(极简示例)
###############################
# 定义3个简单中文句子(每个字为一个token)
sentences = [
    "我今天去公园",  # 拆分为 ['我', '今', '天', '去', '公', '园']
    "公园里有很多树", # 拆分为 ['公', '园', '里', '有', '很', '多', '树']
    "树上有小鸟"     # 拆分为 ['树', '上', '有', '小', '鸟']
]

# 构建有序词表(按首次出现顺序)
chars = []  # 初始化空列表
seen = set()  # 用于去重
for s in sentences:
    for c in s:
        if c not in seen:
            seen.add(c)
            chars.append(c)  # 保留首次出现顺序


char2idx = {char: idx+2 for idx, char in enumerate(chars)}  # id从2开始,0和1留给特殊符号
char2idx["<pad>"] = 0   # 填充符
char2idx["<unk>"] = 1   # 未知符
vocab_size = len(char2idx)  # 词表大小
idx2char = {v: k for k, v in char2idx.items()}  # 反向映射

# 将句子转为索引序列
def text_to_ids(text):
    return [char2idx.get(c, 1) for c in text]  # 未知字用<unk>代替

data = [text_to_ids(s) for s in sentences]

###############################
# 2. 位置编码(同上)
###############################
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return x

###############################
# 3. 定义Transformer模型(微调参数)
###############################
class ChineseTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 32, nhead: int = 4, num_layers: int = 2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=20)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=128,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embed(x)  # (batch, seq_len) -> (batch, seq_len, d_model)
        x = self.pos_enc(x)
        x = x.permute(1, 0, 2)  # 调整为(seq_len, batch, d_model)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # 恢复为(batch, seq_len, d_model)
        x = self.fc(x)
        return x

###############################
# 4. 数据预处理(生成输入-目标对)
###############################
def create_sequences(data, seq_length=3):
    inputs, targets = [], []
    for seq in data:
        for i in range(len(seq) - seq_length):
            inputs.append(seq[i:i+seq_length])    # 输入:前n个字
            targets.append(seq[i+1:i+1+seq_length]) # 目标:后n个字(右移一位)
    return torch.tensor(inputs), torch.tensor(targets)

# 生成训练数据(序列长度设为3)
inputs, targets = create_sequences(data, seq_length=3)
print("示例输入-目标对:")
print("输入:", [idx2char[i.item()] for i in inputs[0]], "→ 目标:", [idx2char[t.item()] for t in targets[0]])

###############################
# 5. 训练模型
###############################
model = ChineseTransformer(vocab_size=vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充位置
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 简单训练循环(仅演示1个epoch)
# model.train()
# optimizer.zero_grad()
#
# output = model(inputs)  # (batch_size, seq_len=3, vocab_size)
# loss = criterion(output.view(-1, vocab_size), targets.view(-1))
# loss.backward()
# optimizer.step()
#
# print(f"训练损失: {loss.item():.4f}")


num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # 前向传播
    output = model(inputs)
    # 计算损失
    loss = criterion(output.view(-1, vocab_size), targets.view(-1))

    # 反向传播
    loss.backward()
    optimizer.step()

    # 每100次打印损失
    if (epoch + 1) % 20 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")


###############################
# 6. 推理测试:预测下一个字
###############################
def predict_next(text, model, temperature=1.0):
    model.eval()
    with torch.no_grad():
        # 将输入文本转为索引
        input_ids = torch.tensor([char2idx.get(c, 1) for c in text[-3:]], dtype=torch.long).unsqueeze(0)
        # 预测
        output = model(input_ids)  # (1, seq_len, vocab_size)
        next_token_logits = output[0, -1, :] / temperature
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token_id = torch.argmax(probs).item()
        return idx2char[next_token_id]

# 测试预测
test_cases = ["我今天", "公园里", "树上有"]
for case in test_cases:
    predicted = predict_next(case, model)
    print(f"输入 '{case}' → 预测下一个字: '{predicted}'")

实测结果验证正确,如下图:

请添加图片描述


5 总结

  1. 通俗易懂:每个模块拆解为小步骤,结合代码示例加深理解;
  2. 实战演练:完整代码可直接运行,快速上手 Transformer 中文建模。

至此,读者已掌握大模型从数据到预测的全流程原理。欢迎在此基础上拓展更多高级功能,共同学习不断进阶!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2379119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【通用大模型】Serper API 详解:搜索引擎数据获取的核心工具

Serper API 详解&#xff1a;搜索引擎数据获取的核心工具 一、Serper API 的定义与核心功能二、技术架构与核心优势2.1 技术实现原理2.2 对比传统方案的突破性优势 三、典型应用场景与代码示例3.1 SEO 监控系统3.2 竞品广告分析 四、使用成本与配额策略五、开发者注意事项六、替…

Spring3+Vue3项目中的知识点——JWT

全称&#xff1a;JOSN Web Token 定义了一种简洁的、自包含的格式&#xff0c;用于通信双方以json数据格式的安全传输信息 组成&#xff1a; 第一部分&#xff1a;Header&#xff08;头&#xff09;&#xff0c;记录令牌类型、签名算法等。 第二部分&#xff1a;Payload&am…

python3GUI--智慧交通分析平台:By:PyQt5+YOLOv8(详细介绍)

文章目录 一&#xff0e;前言二&#xff0e;效果预览1.目标识别与检测2.可视化展示1.车流量统计2. 目标类别占比3. 拥堵情况展示4.目标数量可视化 3.控制台4.核心内容区1.目标检测参数2.帧转QPixmap3.数据管理 5.项目结构 三&#xff0e;总结 平台规定gif最大5M&#xff0c;所以…

Linux任务管理与守护进程

一、任务管理 &#xff08;一&#xff09;进程组、作业、会话概念 &#xff08;1&#xff09;进程组概念&#xff1a;进程组是由一个或多个进程组成的集合&#xff0c;这些进程在某些方面具有关联性。在操作系统中&#xff0c;进程组是用于对进程进行分组管理的一种机制。每个…

C#里与嵌入式系统W5500网络通讯(2)

在嵌入式代码里,需要从嵌入式的MCU访问W5500芯片。 这个是通过SPI通讯来实现的,所以要先连接SPI的硬件通讯线路。 接着下来,就是怎么样访问这个芯片了。 要访问这个芯片,需要通过SPI来发送数据,而发送数据又要有一定的约定格式, 于是芯片厂商就定义下面的通讯格式: …

EMQX开源版安装指南:Linux/Windows全攻略

EMQX开源版安装教程-linux/windows 因最近自己需要使用MQTT&#xff0c;需要搭建一个MQTT服务器&#xff0c;所以想到了很久以前用到的EMQX。但是当时的EMQX使用的是开源版的&#xff0c;在官网可以直接下载。而现在再次打开官网时发现怎么也找不大开源版本了&#xff0c;所以…

【计算机视觉】OpenCV实战项目:GraspPicture 项目深度解析:基于图像分割的抓取点检测系统

GraspPicture 项目深度解析&#xff1a;基于图像分割的抓取点检测系统 一、项目概述项目特点 二、项目运行方式与执行步骤&#xff08;一&#xff09;环境准备&#xff08;二&#xff09;项目结构&#xff08;三&#xff09;执行步骤 三、重要逻辑代码解析&#xff08;一&#…

MySQL 数据库备份与还原

作者&#xff1a;IvanCodes 日期&#xff1a;2025年5月18日 专栏&#xff1a;MySQL教程 思维导图 备份 (Backup) 与 冗余 (Redundancy) 的核心区别: &#x1f3af; 备份是指创建数据的副本并将其存储在不同位置或介质&#xff0c;主要目的是在发生数据丢失、损坏或逻辑错误时进…

Kubernetes控制平面组件:Kubelet详解(四):gRPC 与 CRI gRPC实现

云原生学习路线导航页&#xff08;持续更新中&#xff09; kubernetes学习系列快捷链接 Kubernetes架构原则和对象设计&#xff08;一&#xff09;Kubernetes架构原则和对象设计&#xff08;二&#xff09;Kubernetes架构原则和对象设计&#xff08;三&#xff09;Kubernetes控…

javax.servlet.Filter 介绍-笔记

1.javax.servlet.Filter 简介 javax.servlet.Filter 是 Java Servlet API 中的一个核心接口&#xff0c;用于在请求到达目标资源&#xff08;如 Servlet 或 JSP&#xff09;之前或响应返回给客户端之前执行预处理或后处理操作。它常用于实现与业务逻辑无关的通用功能&#xff…

Win 11开始菜单图标变成白色怎么办?

在使用windows 11的过程中&#xff0c;有时候开始菜单的某些程序图标变成白色的文件形式&#xff0c;但是程序可以正常打开&#xff0c;这个如何解决呢&#xff1f; 这通常是由于快捷方式出了问题&#xff0c;下面跟着操作步骤来解决吧。 1、右键有问题的软件&#xff0c;打开…

入门OpenTelemetry——应用自动埋点

埋点 什么是埋点 埋点&#xff0c;本质就是在你的应用程序里&#xff0c;在重要位置插入采集代码&#xff0c;比如&#xff1a; 收集请求开始和结束的时间收集数据库查询时间收集函数调用链路信息收集异常信息 这些埋点数据&#xff08;Trace、Metrics、Logs&#xff09;被…

C语言链表的操作

初学 初学C语言时&#xff0c;对于链表节点的定义一般是这样的&#xff1a; typedef struct node {int data;struct node *next; } Node; 向链表中添加节点&#xff1a; void addNode(Node **head, int data) {Node *newNode (Node*)malloc(sizeof(Node));newNode->dat…

芯片生态链深度解析(二):基础设备篇——人类精密制造的“巅峰对决”

【开篇&#xff1a;设备——芯片工业的“剑与盾”】 当ASML的EUV光刻机以每秒5万次激光脉冲在硅片上雕刻出0.13nm精度的电路&#xff08;相当于在月球表面精准定位一枚二维码&#xff09;&#xff0c;当国产28nm光刻机在华虹产线实现“从0到1”的突破&#xff0c;这场精密制造…

C语言指针深入详解(二):const修饰指针、野指针、assert断言、指针的使用和传址调用

目录 一、const修饰指针 &#xff08;一&#xff09;const修饰变量 &#xff08;二&#xff09;const 修饰指针变量 二、野指针 &#xff08;一&#xff09;野指针成因 1、指针未初始化 2、指针越界访问 3、指针指向的空间释放 &#xff08;二&#xff09;如何规避野指…

【unity游戏开发——编辑器扩展】使用EditorGUI的EditorGUILayout绘制工具类在自定义编辑器窗口绘制各种UI控件

注意&#xff1a;考虑到编辑器扩展的内容比较多&#xff0c;我将编辑器扩展的内容分开&#xff0c;并全部整合放在【unity游戏开发——编辑器扩展】专栏里&#xff0c;感兴趣的小伙伴可以前往逐一查看学习。 文章目录 前言常用的EditorGUILayout控件专栏推荐完结 前言 EditorG…

Linux基础第三天

系统时间 date命令&#xff0c;date中文具有日期的含义&#xff0c;利用该命令可以查看或者修改Linux系统日期和时间。 基本格式如下&#xff1a; gecubuntu:~$ date gecubuntu:~$ date -s 日期时间 // -s选项可以设置日期和时间 文件权限 chmod命令&#xff0c;是英文…

MoodDrop:打造一款温柔的心情打卡单页应用

我正在参加CodeBuddy「首席试玩官」内容创作大赛&#xff0c;本文所使用的 CodeBuddy 免费下载链接&#xff1a;腾讯云代码助手 CodeBuddy - AI 时代的智能编程伙伴 起心动念&#xff1a;我想做一款温柔的情绪应用 「今天的你&#xff0c;心情如何&#xff1f;」 有时候&#x…

接口——类比摄像

最近迷上了买相机&#xff0c;大疆Pocket、Insta Go3、大疆Mini3、佳能50D、vivo徕卡人像大师&#xff08;狗头&#xff09;&#xff0c;在买配件的时候&#xff0c;发现1/4螺口简直是神中之神&#xff0c;这个万能接口让我想到计算机设计中的接口&#xff0c;遂有此篇—— 接…

二十、案例特训专题3【系统设计篇】web架构设计

一、前言 二、内容提要 三、单机到应用与数据分离 四、集群与负载均衡 五、集群与有状态无状态服务 六、ORM 七、数据库读写分离 八、数据库缓存Memcache与Redis 九、Redis数据分片 哈希分片如果新增分片会很麻烦&#xff0c;需要把之前数据取出来再哈希除模 一致性哈希分片是…