深度学习:GPT-2的MindSpore实践

news2025/5/15 16:48:02

GPT-2简介

GPT-2是一个由OpenAI于2019年提出的自回归语言模型。与GPT-1相比,仍基于Transformer Decoder架构,但是做出了一定改进。

模型规格上:

GPT-1有117M参数,为下游微调任务提供预训练模型。

GPT-2显著增加了模型规模,提供了多种模型,如:124M、355M、774M和1.5B

数据集大小上:

GPT-2训练于数据量约有45GB的WebText数据集。数据集的数据收集于Reddit中的网络文章。

模型架构上:

GPT-2维持了GPT-1的Decoder-only架构,但是讲Decoder Block增加至48层,采用了更深层的注意力机制和更大的前馈网络维度并改进了正则化。同时,GPT-2加入了可学习的位置编码。将Layer Norm前置于模型获得输入后。一个额外的Layer Norm被添加于最后一个自注意力Block后。

参数初始化上:

参数初始化上,使用了一个Special Scaled Initialization。Special Scaled Initialization是Xavier Normalization的一种变体,使用了额外的缩放。将因子n调整为残差连接的次数,也就是block数量的两倍。

任务设定

一般语言模型的训练目标设置为:P(output | input)

但是,GPT-2通过同样的无监督模型来完成多个既定任务,学习目标变为:P(output | input, task)

这种修改被称为任务设定。对于同样的输入,模型应该根据不同的任务输出不同的结果。

翻译任务

文本总结

其他下游任务:基于zero-shot或few-shot的文本生成、文本总结、文本翻译、QA问答、文本分类。

基于MindSpore的GPT-2实践

复习Masked Multi Self-Attention

#安装mindnlp 0.4.0套件
!pip install mindnlp==0.4.0
!pip uninstall soundfile -y
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.3.1/MindSpore/unified/aarch64/mindspore-2.3.1-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

假设一个批大小为1,序列长度为10,特征维度为768的输入

 # GPT-2 Masked Self-Attention

# assume an input of self-attention x, input_dim is 768
batch_size, seq_len, embed_dim = 1, 10, 768

x = Tensor(np.random.randn(batch_size, seq_len, embed_dim), mindspore.float32)
x.shape

将输入复制三份作为Q、K、V。

import mindspore.ops as ops
from mindnlp.transformers.ms_utils import Conv1D

# an input will be multipled by three matrixs Wq, Wk, Wv
# concat the three matrixs, you will get a matrix of (768, 768*3)
# x matmul matrix, the output would be (batch_size, seq_len, 768*3)
c_attn = Conv1D(3 * embed_dim, embed_dim)
output = c_attn(x)
# split the output into q, k, v
query, key, value = ops.Split(axis=2, output_num=3)(output)
query.shape, key.shape, value.shape

 将注意力分头

# split self-attention into multi_head attention
def split_heads(tensor, num_heads, attn_head_size):
    '''
    Spilit hidden_size dim into attn_head_size and num_heads
    
    Args:
        tensor: tensor to split
        num_heads: how many heads to split
        attn_head_size: hidden_size of each head
    
    Return:
         Multi-Head tensor
    '''
    new_shape = tensor.shape[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(new_shape)
    return ops.transpose(tensor, (0, 2, 1, 3))

num_heads = 12
attn_head_size = embed_dim // num_heads

query = split_heads(query, num_heads, attn_head_size)
key = split_heads(key, num_heads, attn_head_size)
value = split_heads(value, num_heads, attn_head_size)

query.shape, key.shape, value.shape

将Q、K相乘,得到注意力分数

# get self-attention score
attn_weights = ops.matmul(query, key.swapaxes(-1, -2))
attn_weights.shape

将注意力分数加上掩码,防止模型看见“未来”的数据

# get mask attn_weighs
max_positions = seq_len
# create a mask matrix
bias = Tensor(np.tril(np.ones((max_positions, max_positions))).reshape((1, 1, max_positions, max_positions)), mindspore.bool_)

# apply Mask Matrix to get masked scores
# this normalization helps stabilize gradients 
# and is common in scaled dot-product attention mechanisms
attn_weights = attn_weights / ops.sqrt(ops.scalar_to_tensor(value.shape[-1]))

query_length, key_length = query.shape[-2], key.shape[-2]
causal_mask = bias[:, :, key_length - query_length: key_length, :key_length].bool()
mask_value = Tensor(np.finfo(np.float32).min, dtype=attn_weights.dtype)
attn_weights = ops.where(causal_mask, attn_weights, mask_value)

 经过SoftMax层,得到掩码分数

# get attn scores
attn_weights = ops.softmax(attn_weights, axis=-1)

掩码分数与V相乘,得到注意力输出

# get output of Masked Self-Attention
attn_output = ops.matmul(attn_weights, value)
attn_output.shape

 将头合并

# merge multi heads
def merge_heads(tensor, num_heads, attn_head_size):
    '''
    Merge attn_head_size dim and num_attn_heads dim to hidden_size
    '''
    tensor = ops.transpose(tensor, (0, 2, 1, 3))
    new_shape = tensor.shape[:-2] + (num_heads * attn_head_size, )
    return tensor.view(new_shape)

attn_output = merge_heads(attn_output, num_heads, attn_head_size)
attn_output.shape

将输出与Wv相乘,得到最终输出

# project Attnetion results with Wv
projection = Conv1D(embed_dim, embed_dim)
attn_output = projection(attn_output)
attn_output.shape

基于MindSpore的GPT2文本摘要

基于GPT-2实现一个简单的文本摘要。

# 数据加载与预处理
from mindnlp.utils import http_get

# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

from mindspore.dataset import TextFileDataset

# load dataset
dataset = TextFileDataset(str(path), shuffle =False)
dataset.get_dataset_size()

mini_dataset, _ = dataset.split([0.001, 0.999], randomize=False)
train_dataset, test_dataset = mini_dataset.split([0.9, 0.1], randomize=False)

import json
import numpy as np
def process_dataset(dataset, tokenizer, batch_size=4, max_seq_len=1024, shuffle=False):
    '''
    数据预处理:
        原始数据格式:
        article:[CLS] article_context [SEP]
        summary:[CLS] summary_context [SEP]
        预处理后的数据格式:
        [CLS] article_context [SEP] summary_context [SEP]
    '''
    def read_map(text):
        '''
        sub function to change the form of data
        '''
        data = json.loads(text.tobytes())
        print(data)
        return np.array(data['article']), np.array(data['summarization'])
    
    def merge_and_pad(article, summary):
        # tokenization, pad to max_seq_length, only article will be truncated
        tokenized = tokenizer(
            text=article, text_pari=summary, padding='max_length', truncation='only_first', max_length=max_seq_len
        )
        # Returns tokenized input IDs for both the input (input_ids) and the labels.
        return tokenized['input_ids'], tokenized['input_ids']
    # 'text': Input column to process.
    # ['article', 'summary']: Names of the output columns after processing.
    dataset = dataset.map(read_map, output_columns=['article', 'summary'])
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])
    dataset = dataset.batch(batch_size)
    
    if shuffle:
        dataset = dataset.shuffle(batch_size)
    return dataset
    
from mindnlp.transformers import BertTokenizer

# Load BERT-base-Chinese tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# load train dataset
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=1)

# model architecture of GPT2ForSummarization
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return (loss, )

num_epochs = 1
warmup_steps = 100
lr = 1.5e-4
max_grad_norm = 1.0
num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

from mindnlp.engine import TrainingArguments

training_args = TrainingArguments(
    output_dir="gpt2_summarization",
    save_steps=train_dataset.get_dataset_size(),
    save_total_limit=3,
    logging_steps=1000,
    max_steps=num_training_steps,
    learning_rate=lr,
    max_grad_norm=max_grad_norm,
    warmup_steps=warmup_steps
    
)

from mindnlp.engine import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

trainer.train()

def process_test_dataset(test_dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    test_dataset = test_dataset.map(read_map, output_columns=['article', 'summary'])
    test_dataset = test_dataset.map(pad, 'article', ['input_ids'])
    
    test_dataset = test_dataset.batch(batch_size)

    return test_dataset

tokenizer_test = BertTokenizer.from_pretrained('bert-base-chinese')

batched_test_dataset = process_test_dataset(test_dataset, tokenizer_test, batch_size=1)

model = GPT2LMHeadModel.from_pretrained('./gpt2_summarization/checkpoint-45', config=config)

model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in batched_test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print('input', tokenizer.decode(input_ids[0].tolist()))
    print()
    print(output_text)
    i += 1
    if i == 1:
        break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249069.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++设计模式-策略模式-StrategyMethod

动机(Motivation) 在软件构建过程中,某些对象使用的算法可能多种多样,经常改变,如果将这些算法都编码到对象中,将会使对象变得异常复杂;而且有时候支持不使用的算法也是一个性能负担。 如何在运…

Harbor安装、HTTPS配置、修改端口后不可访问?

Harbor安装、HTTPS配置、修改端口后不可访问? 大家好,我是秋意零。今天分享Harbor相关内容,安装部分可完全参考官方文档,写的也比较详细。 安装Harbor 官方文档:https://goharbor.io/docs/2.12.0/install-config/ …

DMS2024|思腾合力受邀参加第二届CCF数字医学大会

随着人工智能技术的不断进步,其在医学领域的应用日益广泛。从医学影像分析、疾病诊断到个性化治疗方案设计,人工智能正在逐步改变传统的医疗模式。未来,数字医学将更加注重数据的整合与挖掘,推动医学研究的深入与创新。 2024年11…

Python 绘制 向量减法

Python 绘制 向量减法 flyfish import matplotlib.pyplot as plt# 向量数据 a [1, 2] b [3, 2]# 计算-a 和 a-b minus_b [-x for x in b] # 反转向量b得到-b a_minus_b [a[i] minus_b[i] for i in range(2)] # 计算a - b# 绘制原点 plt.plot([0], [0], ko) # 黑色圆点…

工作坊报名|使用 TEN 与 Azure,探索你的多模态交互新场景

GPT-4o Realtime API 发布,语音 AI 技术正在进入一场新的爆发。语音AI技术的实时语音和视觉互动能力将为我们带来更多全新创意和应用场景。 实时音频交互: 允许应用程序实时接收并响应语音和文本输入。自然语音生成: 减少 AI 技术生成的语音…

node.js基础学习-http模块-创建HTTP服务器、客户端(一)

http模块式Node.js内置的模块,用于创建和管理HTTP服务器。Node.js使用JavaScript实现,因此性能更好。 使用http模块创建服务器,我们建议使用commonjs模块规范,因为很多第三方的组件都使用了这种规范。当然es6写法也支持。 下面就是…

黑马程序员Java项目实战《苍穹外卖》Day01

苍穹外卖-day01 课程内容 软件开发整体介绍苍穹外卖项目介绍开发环境搭建导入接口文档Swagger 项目整体效果展示: ​ 管理端-外卖商家使用 ​ 用户端-点餐用户使用 当我们完成该项目的学习,可以培养以下能力: 1. 软件开发整体介绍 作为一…

【NOIP普及组】表达式求值

【NOIP普及组】表达式求值 C语言代码C 代码Java代码Python代码 💐The Begin💐点点关注,收藏不迷路💐 给定一个只包含加法和乘法的算术表达式,请你编程计算表达式的值。 输入 输入仅有一行,为需要你计算的…

“蜀道山”高校联合公益赛 Web (部分)

文章目录 奶龙牌WAF海关警察训练平台恶意代码检测器 奶龙牌WAF <?php if ($_SERVER[REQUEST_METHOD] POST && isset($_FILES[upload_file])) {$file $_FILES[upload_file];if ($file[error] UPLOAD_ERR_OK) {$name isset($_GET[name]) ? $_GET[name] : basen…

鸿蒙中的Image组件如何引用网络图片

1.引用网络图片资源 引入网络图片需要申请权限ohos.permission.INTERNET&#xff0c;此时&#xff0c;Image组件的src参数为网络图片的链接&#xff0c;为了成功加载网络图片&#xff0c;您需要在module.json5文件中申请网络访问权限 注意&#xff1a;实际可用的时候&#xff0…

问题记录-Java后端

问题记录 目录 问题记录1.多数据源使用事务注意事项&#xff1f;2.mybatis执行MySQL的存储过程&#xff1f;3.springBoot加载不到nacos配置中心的配置问题4.服务器产生大量close_wait情况 1.多数据源使用事务注意事项&#xff1f; 问题&#xff1a;在springBoot项目中多表处理数…

PySide6 QSS(Qt Style Sheets) Reference: PySide6 QSS参考指南

Qt官网参考资料&#xff1a; QSS介绍&#xff1a; Styling the Widgets Application - Qt for Pythonhttps://doc.qt.io/qtforpython-6/tutorials/basictutorial/widgetstyling.html#tutorial-widgetstyling QSS 参考手册&#xff1a; Qt Style Sheets Reference | Qt Widge…

vue3 开发利器——unplugin-auto-import

这玩意儿是干啥的&#xff1f; 还记得 Vue 3 的组合式 API 语法吗&#xff1f;如果有印象&#xff0c;那你肯定对以下代码有着刻入 DNA 般的熟悉&#xff1a; 刚开始写觉得没什么&#xff0c;但是后来渐渐发现&#xff0c;这玩意儿几乎每个页面都有啊&#xff01; 每次都要写…

搭建AI知识库:打造坚实的团队知识堡垒

在信息爆炸的时代&#xff0c;企业面临着知识管理的挑战。团队知识堡垒的构建&#xff0c;即搭建一个高效的AI知识库&#xff0c;对于保护和利用知识资产、提升团队协作效率和创新能力至关重要。本文将探讨搭建AI知识库的重要性、策略以及如何通过这一系统打造坚实的团队知识堡…

前端-Git

一.基本概念 Git版本控制系统时一个分布式系统&#xff0c;是用来保存工程源代码历史状态的命令行工具 简单来说Git的作用就是版本管理工具。 Git的应用场景&#xff1a;多人开发管理代码&#xff1b;异地开发&#xff0c;版本管理&#xff0c;版本回滚。 Git 的三个区域&a…

《Shader入门精要》透明效果

代码以及实例图可以看github &#xff1a;zaizai77/Shader-Learn: 实现一些书里讲到的shader 在实时渲染中要实现透明效果&#xff0c;通常会在渲染模型时控制它的透明通道&#xff08;Alpha Channel&#xff09;​。当开启透明混合后&#xff0c;当一个物体被渲染到屏幕上时&…

Exploring Prompt Engineering: A Systematic Review with SWOT Analysis

文章目录 题目摘要简介方法论背景相关工作评估结论 题目 探索快速工程&#xff1a;基于 SWOT 分析的系统评价 论文地址&#xff1a; https://arxiv.org/abs/2410.12843 摘要 在本文中&#xff0c;我们对大型语言模型 (LLM) 领域的提示工程技术进行了全面的 SWOT 分析。我们强…

LLM-pruner源码解析

1.超参数 模型剪枝的超参数 模型 模型检查点和日志的保存地址 剪枝比例&#xff0c;这里默认0.5 剪枝类型&#xff0c;这里模型L2 模型生成时的超参数 温度 top_p 最大序列长度 逐通道&#xff0c;逐块&#xff0c;逐层&#xff0c;这个逐层我不记得在论文里面提过啊 layer…

Stable Diffusion 3详解

&#x1f33a;系列文章推荐&#x1f33a; 扩散模型系列文章正在持续的更新&#xff0c;更新节奏如下&#xff0c;先更新SD模型讲解&#xff0c;再更新相关的微调方法文章&#xff0c;敬请期待&#xff01;&#xff01;&#xff01;&#xff08;本文及其之前的文章均已更新&…

零基础学安全--shell(8)脚本相互利用

目录 学习连接 脚本相互利用 脚本利用 利用脚本中的变量 重定向 输出重定向 错误输出 输入重定向 学习连接 声明&#xff01; 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下&#xff0c;如涉及侵权马上删除文章&#xff0c;笔记只是方便各位师傅的学习和探…