用pytorch进行BERT文本分类

news2025/5/17 18:39:56

BERT 是一个强大的语言模型,至少有两个原因:

  1. 它使用从 BooksCorpus (有 8 亿字)和 Wikipedia(有 25 亿字)中提取的未标记数据进行预训练。
  2. 顾名思义,它是通过利用编码器堆栈的双向特性进行预训练的。这意味着 BERT 不仅从左到右,而且从右到左从单词序列中学习信息。

BERT 模型需要一系列 tokens (words) 作为输入。在每个token序列中,BERT 期望输入有两个特殊标记:

  • [CLS] :这是每个sequence的第一个token,代表分类token。
  • [SEP] :这是让BERT知道哪个token属于哪个序列的token。这一特殊表征法主要用于下一个句子预测任务或问答任务。如果我们只有一个sequence,那么这个token将被附加到序列的末尾。

就像Transformer的普通编码器一样,BERT 将一系列单词作为输入,这些单词不断向上流动。每一层都应用自我注意,并将其结果通过前馈网络传递,然后将其传递给下一个编码器。

 

BERT 输出

每个位置输出一个大小为 hidden_ size的向量(BERT Base 中为 768)。对于我们在上面看到的句子分类示例,我们只关注第一个位置的输出(将特殊的 [CLS] token 传递到该位置)。

该向量现在可以用作我们选择的分类器的输入。该论文仅使用单层神经网络作为分类器就取得了很好的效果。

使用 BERT 进行文本分类

本文的主题是用 BERT 对文本进行分类。

在这篇文章中,我们将使用kaggle上的BBC 新闻分类数据集。

数据集已经是 CSV 格式,它有 2126 个不同的文本,每个文本都标记在 5 个类别中的一个之下:

sport(体育),business(商业),politics(政治),tech(科技),entertainment(娱乐)。


模型下载 https://huggingface.co/bert-base-cased/tree/main

数据集下载 bbc-news https://huggingface.co/datasets/SetFit/bbc-news/tree/main

 有4个400多MB的文件,pytorch的模型对应的是436MB的那个文件。

需要安装transforms库

pip install transforms 

全部的流程代码:

# 全部流程代码
import numpy as np
import torch
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
labels = {'business':0,
          'entertainment':1,
          'sport':2,
          'tech':3,
          'politics':4
          }

class Dataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.labels = [labels[label] for label in df['category']]
        self.texts = [tokenizer(text, 
                                padding='max_length', 
                                max_length = 512, 
                                truncation=True,
                                return_tensors="pt") 
                      for text in df['text']]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_texts, batch_y


# 数据集准备
# 拆分训练集、验证集和测试集 8:1:1
import pandas as pd
bbc_text_df = pd.read_csv('./bbc-news/bbc-text.csv')
# bbc_text_df.head()
df = pd.DataFrame(bbc_text_df)
np.random.seed(112)
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), [int(.8*len(df)), int(.9*len(df))])
print(len(df_train),len(df_val), len(df_test))
# 1780 222 223



# 构建模型
from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 5)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer

#从上面的代码可以看出,BERT Classifier 模型输出了两个变量:
#1. 在上面的代码中命名的第一个变量_包含sequence中所有 token 的 Embedding 向量层。
#2. 命名的第二个变量pooled_output包含 [CLS] token 的 Embedding 向量。对于文本分类任务,使用这个 Embedding 作为分类器的输入就足够了。
# 然后将pooled_output变量传递到具有ReLU激活函数的线性层。在线性层中输出一个维度大小为 5 的向量,每个向量对应于标签类别(运动、商业、政治、 娱乐和科技)。



from torch.optim import Adam
from tqdm import tqdm

def train(model, train_data, val_data, learning_rate, epochs):
    # 通过Dataset类获取训练和验证集
    train, val = Dataset(train_data), Dataset(val_data)
    # DataLoader根据batch_size获取数据,训练时选择打乱样本
    train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)
    # 判断是否使用GPU
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    if use_cuda:
            model = model.cuda()
            criterion = criterion.cuda()
    # 开始进入训练循环
    for epoch_num in range(epochs):
            # 定义两个变量,用于存储训练集的准确率和损失
            total_acc_train = 0
            total_loss_train = 0
            # 进度条函数tqdm
            for train_input, train_label in tqdm(train_dataloader):
                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)
                # 通过模型得到输出
                output = model(input_id, mask)
                # 计算损失
                batch_loss = criterion(output, train_label)
                total_loss_train += batch_loss.item()
                # 计算精度
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc
                # 模型更新
                model.zero_grad()
                batch_loss.backward()
                optimizer.step()
            # ------ 验证模型 -----------
            # 定义两个变量,用于存储验证集的准确率和损失
            total_acc_val = 0
            total_loss_val = 0
            # 不需要计算梯度
            with torch.no_grad():
                # 循环获取数据集,并用训练好的模型进行验证
                for val_input, val_label in val_dataloader:
                    # 如果有GPU,则使用GPU,接下来的操作同训练
                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)
  
                    output = model(input_id, mask)

                    batch_loss = criterion(output, val_label)
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
            
            print(
                f'''Epochs: {epoch_num + 1} 
              | Train Loss: {total_loss_train / len(train_data): .3f} 
              | Train Accuracy: {total_acc_train / len(train_data): .3f} 
              | Val Loss: {total_loss_val / len(val_data): .3f} 
              | Val Accuracy: {total_acc_val / len(val_data): .3f}''') 

#我们对模型进行了 5 个 epoch 的训练,我们使用 Adam 作为优化器,而学习率设置为1e-6。
#因为本案例中是处理多类分类问题,则使用分类交叉熵作为我们的损失函数。
EPOCHS = 5
model = BertClassifier()
LR = 1e-6
train(model, df_train, df_val, LR, EPOCHS)



# 测试模型
def evaluate(model, test_data):
    test = Dataset(test_data)
    test_dataloader = torch.utils.data.DataLoader(test, batch_size=2)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        model = model.cuda()

    total_acc_test = 0
    with torch.no_grad():
        for test_input, test_label in test_dataloader:
              test_label = test_label.to(device)
              mask = test_input['attention_mask'].to(device)
              input_id = test_input['input_ids'].squeeze(1).to(device)
              output = model(input_id, mask)
              acc = (output.argmax(dim=1) == test_label).sum().item()
              total_acc_test += acc
    print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')

# 用测试数据集进行测试    
evaluate(model, df_test)

 

EPOCHS = 5
model = BertClassifier()
LR = 1e-6
train(model, df_train, df_val, LR, EPOCHS)

在训练的这一步会非常耗时间,用GPU加速了,也需要大概39分钟.

因为BERT模型本身就是一个比较大的模型,参数非常多。

最后一步测试的时候,测试的准确率还是比较高的。达到 99.6%

 

模型的保存。这个在原文里面是没有提到的。

我们花了很多时间训练的模型如果不保存一下,下次还要重新训练岂不是费时费力?

# 保存模型
torch.save(model.state_dict(),"bertMy.pth")
# load 加载模型
model = BertClassifier()
model.load_state_dict(torch.load("bertMy.pth"))

 可以用下面的代码查看model里面的模型。

for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

 也可以将保存的模型文件 bertMy.pth上传到netron网站进行模型可视化。Netronhttps://netron.app/

 

原文位于这个地址,但是原文中的代码缺了读csv那一段 :

保姆级教程,用PyTorch和BERT进行文本分类 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/693104.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

集合专题----set篇

1、Set 接口和常用方法 (1)Set 接口基本介绍 ① 无序(添加和去除的顺序不一致),没有索引; ② 不允许重复元素,所以最多包含一个null; (2)Set 接口的常用方…

自动驾驶开源数据集(附下载链接)

自动驾驶是带动新兴产业的一个突破点,也是中国结合新能源汽车,实现汽车产业弯道超车的不二手段,是打破国外燃油车技术壁垒的关键一步!它不会停止,只是在蓄势待发! 数据集介绍:点击 自动驾驶场…

java进阶—通俗易懂线程池的概念(底层原理)及使用

前言 首先,我们知道创建一个线程 可以直接 使用 new Thread(() ->{}).start();这种形式来创建,当线程的run 方法执行结束,线程就终止了,线程对象就会被垃圾回收机制(GC)释放 然而在我们 开发工作中&…

智安网络|攻防演练对抗:网络边界自动化防御的关键

在当今高度互联的数字世界中,网络安全的重要性日益凸显。为了应对不断增长的网络威胁,组织和企业需要采取主动的防御策略,其中攻防演练对抗和自动化防御在保护网络边界方面扮演着重要的角色。本文将探讨攻防演练对抗的意义,并介绍…

四、用户管理

云尚办公系统:用户管理 B站直达【为尚硅谷点赞】: https://www.bilibili.com/video/BV1Ya411S7aT 本博文以课程相关为主发布,并且融入了自己的一些看法以及对学习过程中遇见的问题给出相关的解决方法。一起学习一起进步!!&#x…

ImportError: numpy.core.multiarray failed to import

遇到的问题: 解决方法: 根据你的opencv版本,去百度搜索对应的 numpy 版本,卸载掉现有的numpy ,安装其他版本: sudo pip install numpy1.19.0或者直接升级到numpy的最新版本: sudo pip install --upgrade…

小程序-真机上接口无法调通,开发者工具上可以

近期在对接小程序,在这里记录一下,我们在对接小程序的时候碰到的一些奇奇怪怪的问题。 其中一个问题如下: 真实效果如下图 开发者工具上可以,访问没有人任何问题。 真机上接口无法调通,也没有报错,也没有…

idea中git的使用详细说明

一.克隆项目 1. 打开File>New>Project from Version Control... 2. 打开gitLab,复制项目地址 3. 粘贴到第1步中的URL中,点击“Clone” 二.代码提交 1. 修改文件后需要提交时,可以在git-Local Changes中看到你修改的文件及修改内容 2. 选…

netty学习(1):多个客户端与服务器通信

1. 基于前面一节netty学习(1):1个客户端与服务器通信 只需要把服务器的handler改造一下即可,通过ChannelGroup 找到所有的客户端channel,发送消息即可。 package server;import io.netty.channel.*; import io.netty.channel.gr…

MS1826 HDMI 多功能视频处理器 4*4矩阵切换器

基本介绍 MS1826 是一款多功能视频处理器,包含 4 路独立 HDMI 音视频输入通道、4 路独立 HDMI 音视频输出通道以及四路独立可配置为输入或者输出的 SPDIF、I2S 音频信号。支持 4 个独立的字 库定制型 OSD;可处理隔行和逐行视频或者图形输入信号&#xff…

Spring Boot 中的 @ComponentScan 注解是什么,原理,如何使用

Spring Boot 中的 ComponentScan 注解是什么,原理,如何使用 在 Spring Boot 中,ComponentScan 是一种注解,它可以让 Spring 自动扫描指定的包及其子包中的组件,并将这些组件自动装配到 Spring 容器中。本文将介绍 Com…

UML14种图

UML14种图 UML是Unified Modeling Language的缩写,译为统一建模语言。 UML是软件行业的建模规范,可以对软件项目建立需求模型、设计模型、实现模型、测试模型。 UML2.0包含的14种图: UML各种图例(常用图形) 1. 类图&…

状态机编程实例-状态表法

上篇文章,使用嵌套switch-case法的状态机编程,实现了一个炸弹拆除小游戏。 本篇,继续介绍状态机编程的第二种方法:状态表法,来实现炸弹拆除小游戏的状态机编程。 1 状态表法 状态表法,顾名思义&#xff0…

YOLOv8的目标对象的分类,分割,跟踪和姿态估计的多任务检测实践(Netron模型可视化)

YOLOv8是目前最新版本,在以前YOLO版本基础上建立并加入了一些新的功能,以进一步提高性能和灵活性,是目前最先进的模型。YOLOv8旨在快速,准确,易于使用,使其成为广泛的目标检测和跟踪,实例分割&a…

MATLAB 之 Simulink 操作基础和系统仿真模型的建立

这里写目录标题 一、Simulink 操作基础1. Simulink 的启动与退出1.1 Simulink 的启动1.2 模型文件的打开1.3 Simulink 的退出 2. Simulink 仿真初步2.1 模型元素2.2 仿真步骤2.3 简单实例 二、系统仿真模型的建立1. Simulink 的基本模块2. 模块操作2.1 添加与删除模块2.2 选取模…

快速训练自己的大语言模型:基于LLAMA-7B的lora指令微调

目录 1. 选用工程:lit-llama2. 下载工程3. 安装环境4. 下载LLAMA-7B模型5. 做模型转换6. 初步测试7. 为什么要进行指令微调?8. 开始进行指令微调8.1. 数据准备8.2 开始模型训练8.3 模型测试 前言: 系统:ubuntu 18.04显卡&#xff…

大数据ETL工具对比(Sqoop, DataX, Kettle)

前言 在实习过程中,遇到了数据库迁移项目,对于数据仓库,大数据集成类应用,通常会采用ETL工具辅助完成,公司和客户使用的比较多的是Sqoop, DataX和Kettle这三种工具。简单的对这三种ETL工具进行一次梳理。 ETL工具&…

无法更新iPhone,提示“无法检查更新”怎么办?

当我们需要 iPhone更新系统时,可以前往iPhone设置-通用-软件更新中获取更新推送。不过一些用户可能会遇到无法更新的问题,例如会提示“无法检查更新,检查软件更新时出错”。 以上情况可能是网络问题,可以尝试重新打开设置&#xf…

vue2实现公式规则编辑校验弹窗功能

文章目录 需求描述技术栈最终效果演示功能实现逻辑拆分代码目录结构实现思路光标实现底部单个符号或字段结构设计监听键盘事件&处理光标公式规则校验 总结 需求描述 需要一个弹窗,弹窗内部需要能够进行公式规则的配置并进行公式规则合法性校验。 技术栈 vue2e…

Thymeleaf的常用语法

🌟 Thymeleaf的常用语法 Thymeleaf是一个Java模板引擎,用于处理HTML、XML、JavaScript、CSS等文件。它可以与Spring框架无缝集成,为Web应用程序提供优雅的模板解决方案。本文将介绍Thymeleaf的常用语法,包括th属性、表达式、内置…