人工智能(Pytorch)搭建T5模型,真正跑通T5模型,用T5模型生成数字加减结果

news2025/6/21 1:09:43

大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建T5模型,真正跑通T5模型,用T5模型生成数字加减结果。T5(Text-to-Text Transfer Transformer)是一种由Google Brain团队在2019年提出的自然语言处理模型。T5模型基于Transformer结构,可以执行多种自然语言任务,如翻译、摘要、问答、文本生成等。它不同于其他模型的地方在于,T5模型采用了统一的输入输出格式,并通过微调来适应不同的任务。

一、T5模型优势

T5模型基于Transformer结构,其训练方式是无监督的。首先将大量的文本数据输入到模型中进行预训练,使得模型学习到了输入和输出之间的对应关系。而后,再利用有标注的数据对模型进行微调,以适应具体的任务需求。与其他自然语言处理模型相比,T5具备以下优势:

多任务学习能力强:同一个模型可以执行多种自然语言任务,只需要使用不同的微调方法即可。

零样本学习能力强:T5模型可以利用已有知识完成类似但未曾见过的任务。

表示能力强:T5模型可以捕获多种语义信息,并且可以用较少的参数来达到很好的性能。

T5模型与传统语言模型不同,T5并不仅仅只是单纯地预测下一个词或者下一个句子,它可以直接生成完整的文本、回答问题或者进行翻译等多种任务。这归功于T5所使用的“Text-to-Text”框架,即将所有的自然语言处理任务转化为输入和输出都是文本的问题。

T5模型的训练过程中,使用了一种被称为“预训练+微调”的方法。首先,对大规模的数据集进行预训练,学习通用的语言表示;然后,使用少量的目标任务数据对模型进行“微调”,针对具体任务进行优化。T5模型还使用了一种名为“Span Extraction”的技术,使得模型可以在输出序列中找到最相关的片段,从而更好地完成各种任务。同时,T5模型使用了“Token Dropout”机制来避免过拟合,并采用了许多其他技巧,包括多层解码器、混合精度训练等。

二、代码实现

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, AdamW
import random

# 自己定义输入的数据
class MyDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_text, target_text = self.data[idx]
        inputs = self.tokenizer(input_text, return_tensors="pt", padding="max_length", max_length=self.max_length, truncation=True)
        targets = self.tokenizer(target_text, return_tensors="pt", padding="max_length", max_length=self.max_length, truncation=True)
        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": targets["input_ids"].squeeze()
        }


# 指定使用的设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 使用 T5-small 模型。 如果想用更大的模型,可以将 "t5-small" 替换为例如 "t5-base" 或 "t5-large"
model_name = "t5-small"
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = T5Tokenizer.from_pretrained(model_name)

# 构造训练数据
data = []
for i in range(800):
    a = random.randint(0, 9)
    b = random.randint(0, 9)
    op = random.choice(["+", "-"])
    if op == "+":
        result = a + b
    else:
        result = a - b
    data.append((f"{a} {op} {b}", str(result)))


max_length = 32
train_dataset = MyDataset(data, tokenizer, max_length)

#train_dataset = MyDataset(data, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# 训练模型
num_epochs = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        print(f"Epoch: {epoch}, Loss: {loss.item():.6f}")

# 测试训练后的模型
model.eval()
input_text = "1 + 1"
inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(device)
outputs = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"])
result = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Model prediction: {input_text} = {result}")

# 保存模型
model.save_pretrained("trained_t5")
tokenizer.save_pretrained("trained_t5")

以上是T5ForConditionalGeneration模型,利用简单的训练数据进行训练。后面需要添加更多的数据以实现更高的准确率。

训练完的模型和tokenizer会被保存到名为trained_t5的文件夹中。
使用较有限的数据进行较少的训练轮次可能导致较低的准确率,

我这里给大家提供了模型训练和验证的基本方法。如果需要提高模型的性能,可以尝试使用更大的数据集、增加训练轮次以及在设备性能允许的情况下试验更大的T5模型。后续文本生成就是利用这个模型训练的,也可以训练做类似ChatGPT类似的聊天机器人。

大家继续期待吧,后面还有更多的模型使用技巧贡献出来!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/411447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

商办楼宇租赁过程中的风险与管控

在商办地产租赁市场持续高量供应、越来越多楼盘趋向同质化的背景下,商办地产经营需更懂得审时度势,在租赁经营过程中合理运用数字化管理识别、规避风险,针对有风险的经营及时调整管控,提升识别及防范风险的意识和能力,…

数据结构:链表oj下

21. 合并两个有序链表 CM11 链表分割 不加36行代码会造成死循环(形成环) OR36 链表的回文结构 找到中间节点,再把后面的逆置 走完一个链表while(tail) 找链表的最后一个节点while(tail->next) 160. 相交链表 找到AB链表的尾节点&#x…

Python 小型项目大全 6~10

六、凯撒密码 原文:http://inventwithpython.com/bigbookpython/project6.html 凯撒密码是朱利叶斯凯撒使用的一种古老的加密算法。它通过将字母在字母表中移动一定的位置来加密字母。我们称移位的长度为密钥。比如,如果密钥是 3,那么A变成D&…

Linux-初学者系列——篇幅2_系统命令界面

系统命令界面-目录一、命令行提示符1、提示符2、提示符组成3、提示符修改二、系统命令语法规范三、系统命令行常用快捷键1、常用快捷键2、移动光标快捷键3、剪切、粘贴、清楚快捷键4、系统管理控制快捷键5、重复执行命令快捷键上篇: Linux-初学者系列——篇幅1_文件管理命令 一…

Python 小型项目大全 36~40

三十六、沙漏 原文:http://inventwithpython.com/bigbookpython/project36.html 这个可视化程序有一个粗糙的物理引擎,模拟沙子通过沙漏的小孔落下。沙子堆积在沙漏的下半部分;然后把沙漏翻过来,重复这个过程。 运行示例 图 36-…

手写vuex4源码(八)插件机制实现

一、插件的使用 Vuex 不仅提供了全局状态管理能力,还进一步提供了插件机制,便于开发者对 Vuex 插件进行增强; Vuex 插件的使用方式:通过 Store 类提供的 plugin 数组进行 Vuex 插件注册: export default createStor…

ModuleNotFoundError:No module named “te_fusion“

Asecend Tensor Compiler简称ATC,昇腾张量编译器,主要是将开源框架的网络模型或Ascend IR定义的单算子描述文件(json格式)转换为昇腾AI处理器支持的om格式 场景描述:ONNXRuntime调用CANN后端执行时,报了没有找到te_fusion的error&#xff0c…

多项式特征应用案例

多项式特征应用案例 描述 对于线性模型而言,扩充数据的特征(即对原特征进行计算,增加新的特征列)通常是提升模型表现的可选方法,Scikit-learn提供了PolynomialFeatures类来增加多项式特征(polynomial fea…

JavaScript【一】JavaScript变量与数据类型

文章目录🌟前言🌟变量🌟 变量是什么?🌟 变量提升🌟 声明变量🌟 JavaScript有三种声明方式🌟 命名规范🌟 注意🌟数据类型以及运算🌟 检测变量数据类…

【Linux】System V IPC-命名管道共享内存消息队列

System V IPC-命名管道&共享内存&消息队列命名管道共享内存创建共享内存附加和分离共享内存消息队列消息队列的接口命名管道 使用mkfifo命令,创建一个命名管道,通过ll可以查看当前命名管道的类型 p类型,也就是pipe管道类型。 之前我…

Docker镜像之Docker Compose讲解

文章目录1 docker-compose1.1 compose编排工具简介1.2 安装docker-compose1.3 编排启动镜像1.4 haproxy代理后端docker容器1.5 安装socat 直接操作socket控制haproxy1.6 compose中yml 配置指令参考1.6.1 简单命令1.6.2 build1.6.3 depends_on1.6.4 deploy1.6.5 logging1.6.6 ne…

【C++进阶】01:概述

概述 OVERVIEW概述C11新特性:C14新特性:C17新特性:C20新特性:C程序编译过程C内存模型CSTL1.Queue&Stack2.String3.MapC语言C语言面向过程编程✅✅面向对象编程(类和对象)❌✅泛型编程、元编程&#xff…

基于PaddlePaddle的图片分类实战 | 深度学习基础任务教程系列

图像相比文字能够提供更加生动、容易理解及更具艺术感的信息,图像分类是根据图像的语义信息将不同类别图像区分开来,是图像检测、图像分割、物体跟踪、行为分析等其他高层视觉任务的基础。图像分类在安防、交通、互联网、医学等领域有着广泛的应用。 一般…

LeetCode:77. 组合——回溯法,是暴力法?

🍎道阻且长,行则将至。🍓 🌻算法,不如说它是一种思考方式🍀算法专栏: 👉🏻123 一、🌱77. 组合 题目描述:给定两个整数 n 和 k,返回范…

风场数据抓取程序实现(java+python实现)

一、数据源参数定义 关键参数代码: package com.grab.catchWindData.pram;/*** ClassName: DevPrams* Description: TODO**/ public class DevPrams {public static String lev_0to0p1_m_below_ground "lev_0-0.1_m_below_ground";public static Stri…

【微服务笔记08】微服务组件之Hystrix实现请求合并功能

这篇文章,主要介绍微服务组件之Hystrix实现请求合并功能。 目录 一、Hystrix请求合并 1.1、什么是请求合并 1.2、请求合并的实现 (1)引入依赖 (2)编写服务提供者 (3)消费者(Se…

React | 认识React开发

✨ 个人主页:CoderHing 🖥️ Node.js专栏:Node.js 初级知识 🙋‍♂️ 个人简介:一个不甘平庸的平凡人🍬 💫 系列专栏:吊打面试官系列 16天学会Vue 11天学会React Node专栏 &#x…

【分享】免梯子的GPT,玩 ChatGPT 的正确姿势

火了一周的 ChatGPT,HG 不允许还有小伙伴不知道这个东西是什么?简单来说就是,你可以让它扮演任何事物,据说已经有人用它开始了颜色文学创作。因为它太火了,所以,本周特推在几十个带有“chatgpt”的项目中选…

双交叉注意学习用于细粒度视觉分类和目标重新识别

目录Dual Cross-Attention Learning for Fine-Grained Visual Categorization and Object Re-Identification摘要本文方法消融实验Dual Cross-Attention Learning for Fine-Grained Visual Categorization and Object Re-Identification 摘要 目的: 探索了如何扩展…

JDK8——新增时间类、有关时间数据的交互问题

目录 一、实体类 二、数据库 三、数据交换 四、关于LocalDateTime类型 (java 8) 4.1 旧版本日期时间问题 4.2 新版日期时间API介绍 4.2.1 LocalDate、LocalTime、LocalDateTime 4.2.2 日期时间的修改与比较 4.2.3 格式化和解析操作 4.2.4 Instant: 时间戳 4.2.5 Duration 与…