中文情感分类

news2025/5/15 6:59:04

本文通过ChnSentiCorp数据集介绍了文本分类任务过程,主要使用预训练语言模型bert-base-chinese直接在测试集上进行测试,也简要介绍了模型训练流程,不过最后没有保存训练好的模型。

一.任务和数据集介绍
1.任务
中文情感分类本质还是一个文本分类问题。
2.数据集
本文使用ChnSentiCorp情感分类数据集,每条数据中包括一句购物评价,以及一个标识,表明这条评价是一条好评还是一条差评。被评价的商品主要是书籍、酒店、计算机配件等。一些例子如下所示:

二.模型架构
基本思路是先特征抽取,然后进行下游任务。前者主要是RNN、LSTM、GRU、BERT、GPT、Transformers等模型,后者本质就是分类模型,比如全连接神经网络等。

三.实现代码
1.准备数据集
(1)使用编码工具

def load_encode_tool(pretrained_model_name_or_path):
    # 加载编码工具bert-base-chinese
    token = BertTokenizer.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
    # print(token)
    return token
if __name__ == '__main__':
    # 测试编码工具
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
    token = load_encode_tool(pretrained_model_name_or_path)
    print(token)

输出结果如下所示:

BertTokenizer(name_or_path='L:\20230713_HuggingFaceModel\bert-base-chinese', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

其中,vocab_size=21128表示bert-base-chinese模型的字典中有21128个词,特殊token主要是UNK、SEP、PAD、CLS、MASK。需要说明的是model_max_length,定义为self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER。因为从本地加载模型,不满足如下条件,所以给model_max_length赋了一个很大的数值,如下所示:

返回值token总的数据结构如下所示:

接下来测试编码工具如下所示:

if __name__ == '__main__':
    # 测试编码工具
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
    token = load_encode_tool(pretrained_model_name_or_path)
    out = token.batch_encode_plus(
        batch_text_or_text_pairs=['从明天起,做一个幸福的人。', '喂马,劈柴,周游世界。'],
        truncation=True,  # 是否截断
        padding='max_length',  # 是否填充
        max_length=17,  # 最大长度,如果不足,那么填充,如果超过,那么截断
        return_tensors='pt',  # 返回的类型
        return_length=True  # 返回长度
    )
    # 查看编码输出
    for key, value in out.items():
        print(key, value.shape)
        # 把编码还原成文本
        print(token.decode(out['input_ids'][0]))

输出结果如下所示:

input_ids torch.Size([2, 17])
[CLS] 从 明 天 起 , 做 一 个 幸 福 的 人 。 [SEP] [PAD] [PAD]
token_type_ids torch.Size([2, 17])
[CLS] 从 明 天 起 , 做 一 个 幸 福 的 人 。 [SEP] [PAD] [PAD]
length torch.Size([2])
[CLS] 从 明 天 起 , 做 一 个 幸 福 的 人 。 [SEP] [PAD] [PAD]
attention_mask torch.Size([2, 17])
[CLS] 从 明 天 起 , 做 一 个 幸 福 的 人 。 [SEP] [PAD] [PAD]

其中,out数据结构如下所示:

因此,out['input_ids'][0]表示第一个句子,token.decode(out['input_ids'][0])表示对第一个句子进行解码。input_idstoken_type_idsattention_mask编码结果示意图如下所示:

说明:bert-base-chinese编码工具是以字为词,即把每个字都作为一个词进行处理。如果对input_idstoken_type_idsattention_mask物理意义不清楚的,那么参考使用编码工具。

(2)定义数据集

class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        mode_name_or_path = r'L:\20230713_HuggingFaceModel\ChnSentiCorp'
        self.dataset = load_from_disk(mode_name_or_path)[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']
        return text, label
if __name__ == '__main__':
    # 加载训练数据集
    dataset = Dataset('train')
    print(len(dataset), dataset[20])

输出结果如下所示:

9600 ('非常不错,服务很好,位于市中心区,交通方便,不过价格也高!', 1)

(3)定义计算设备
通常做深度学习都会有个N卡,设置CUDA如下所示:

device = 'cpu'
if torch.cuda.is_available():
   device = 'cuda'

(4)定义数据整理函数
主要对输入data进行batch_encode_plus(),然后返回input_ids、attention_mask、token_type_ids和labels:

# 数据整理函数
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    # 编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents, truncation=True, padding='max_length', max_length=500, return_tensors='pt', return_length=True)
    # input_ids:编码之后的数字
    # attention_mask:补零的位置是0, 其他位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)
    # 把数据移动到计算设备上
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)
    return input_ids, attention_mask, token_type_ids, labels
if __name__ == '__main__':
    # 测试编码工具
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
    token = load_encode_tool(pretrained_model_name_or_path)
    
    # 定义计算设备
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    
    # 测试数据整理函数
    data = [
        ('你站在桥上看风景', 1),
        ('看风景的人在楼上看你', 0),
        ('明月装饰了你的窗子', 1),
        ('你装饰了别人的梦', 0),
    ]
    input_ids, attention_mask, token_type_ids, labels = collate_fn(data)
    print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)

结果输出如下所示:

torch.Size([4, 500]) torch.Size([4, 500]) torch.Size([4, 500]) tensor([1, 0, 1, 0], device='cuda:0')

(5)定义数据集加载器
定义数据集加载器使用数据整理函数批量处理数据集中的数据如下所示:

loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=16, collate_fn=collate_fn, shuffle=True, drop_last=True)
  • dataset:如果数据集是train dataset,那么就是训练集数据加载器;如果数据集是test dataset,那么就是测试集数据加载器,上述定义的dataset是训练集数据加载器。
  • batch_size=16:每个batch包括16条数据。
  • collate_fn=collate_fn:使用的数据整理函数
  • shuffle=True:打乱各个batch间的顺序,让数据随机
  • drop_last=True:当剩余数据不足16条时,丢弃这些尾数

2.定义模型
(1)加载预训练模型

# 查看数据样例
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)
    
pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
pretrained = BertModel.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
# 不训练预训练模型,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)
pretrained.to(device)

out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
print(out.last_hidden_state.shape)

输出结果如下所示:

torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) tensor([1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0], device='cuda:0') #数据整理函数计算结果
torch.Size([16, 500, 768]) #16表示batch size,500表示句子包含词数,768表示向量维度

其中,out是BaseModelOutputWithPoolingAndCrossAttentions对象,包括last_hidden_state和pooler_output两个字段。数据结构如下所示:

(2)定义下游任务模型
该模型是权重为768×2的全连接神经网络,本质就是把768维向量转换为2维。计算过程就是通过模型提取特征矩阵(16×500×768),然后取第1个字[CLS]代表整个文本语义特征,用于下游分类任务等。如下所示:

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # 使用预训练模型抽取数据特征
        with torch.no_grad():
            out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 对抽取的特征只取第1个字的结果做分类即可
        out = self.fc(out.last_hidden_state[:, 0])
        out = out.softmax(dim=1)
        return out

3.训练和测试
(1)训练

def train():
    # 定义优化器
    optimizer = AdamW(model.parameters(), lr=5e-4)
    # 定义1oss函数
    criterion = torch.nn.CrossEntropyLoss()
    # 定义学习率调节器
    scheduler = get_scheduler(name='linear', num_warmup_steps=0, num_training_steps=len(loader), optimizer=optimizer)
    # 将模型切换到训练模式
    model.train()
    # 按批次遍历训练集中的数据
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        # 模型计算
        out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 计算loss并使用梯度下降法优化模型参数
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        # 输出各项数据的情况,便于观察
        if i % 10 == 0:
            out = out.argmax(dim=1)
            accuracy = (out == labels).sum().item() / len(labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), lr, accuracy)

(2)测试

def test():
    # 定义测试数据集加载器
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'), batch_size=32, collate_fn=collate_fn, shuffle=True, drop_last=True)
    # 将下游任务模型切换到运行模式
    model.eval()
    correct = 0
    total = 0
    # 按批次遍历测试集中的数据
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        # 计算5个批次即可,不需要全部遍历
        if i == 5:
            break
        print(i)
        # 计算
        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 统计正确率
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)
    print(correct / total)

在PyTorch中,model.train()和model.eval()是用于切换模型训练和评估模式的方法:

  • model.train()方法:将模型切换到训练模式,这会启用一些特定的行为,例如梯度下降、权重更新等。在训练模式下,模型会使用训练数据对模型参数进行更新,以最小化损失函数。
  • model.eval()方法:将模型切换到评估模式,这会禁用一些在训练模式下启用的行为,例如梯度下降、权重更新等。在评估模式下,模型通常用于对测试数据进行预测,以评估模型的性能。
  • param.requires_grad_(False)方法:它和model.eval都可以关闭梯度计算,但两者区别在于param.requires_grad_(False)只关闭单个参数的梯度计算,而model.eval关闭整个模型的梯度计算。

参考文献:
[1]HuggingFace自然语言处理详解:基于BERT中文模型的任务实战
[2]https://huggingface.co/bert-base-chinese/tree/main

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/945307.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

会员管理系统实战开发教程05-会员开卡

上一篇我们讲解了如何点击按钮弹出层,已经罗列了会员管理的一些常见功能。本篇我们介绍一下会员开卡的业务。 1 创建变量 我们会员开卡的业务的话,也是要在本页面弹出,弹出其实只是让组件是否显示和隐藏,我们先定义一个布尔值类…

柠檬水找零【贪心算法-】

柠檬水找零 在柠檬水摊上,每一杯柠檬水的售价为 5 美元。顾客排队购买你的产品,(按账单 bills 支付的顺序)一次购买一杯。 每位顾客只买一杯柠檬水,然后向你付 5 美元、10 美元或 20 美元。你必须给每个顾客正确找零&…

读书笔记-《ON JAVA 中文版》-摘要23[第二十章 泛型-2]

文章目录 第二十章 泛型5. 泛型擦除5.1 泛型擦除5.2 迁移兼容性5.3 擦除的问题5.4 边界处的动作 6. 补偿擦除7. 边界8. 通配符8.1 通配符8.2 逆变 9. 问题10. 动态类型安全11. 泛型异常 第二十章 泛型 普通的类和方法只能使用特定的类型:基本数据类型或类类型。如果…

驱动代码验证

要求 代码 #include <stdio.h> #include <unistd.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <stdlib.h> #include <string.h> #include <sys/ioctl.h>int main(int argc,const char * a…

移动端h5项目的兼容和适配问题

解决兼容性问题的关键在于对移动端产品的生存环境进行梳理&#xff0c;在此基础之上制定应对策略。 所谓生存环境主要分为三个维度&#xff1a; 硬件环境&#xff0c;细分为品牌和机型&#xff0c;决定了屏幕大小、性能等硬件限制 操作系统&#xff0c;比如iOS6和iOS7&#xf…

Redis数据类型(list\set\zset)

"maybe its why" List类型 列表类型是⽤来存储多个有序的字符串&#xff0c;列表中的每个字符串称为元素&#xff08;element&#xff09;&#xff0c;⼀个列表最多可以存储个2^32 - 1个元素。在Redis中&#xff0c;可以对列表两端插⼊&#xff08;push&#xff09…

555定时器

一、定义 定时器是一种多用途的数字-模拟混合集成电路&#xff0c;可极方便的构成施密特触发器、单稳态触发器和多谐振荡器&#xff0c;其简化原理图及引脚定义如下所示 3个绿色电阻&#xff0c;电阻值为5K&#xff1b;2个黄色和粉色比较器&#xff1b;1个紫色SR触发器&#x…

WPF实战项目十三(API篇):备忘录功能api接口、优化待办事项api接口

1、新建MenoDto.cs /// <summary>/// 备忘录传输实体/// </summary>public class MenoDto : BaseDto{private string title;/// <summary>/// 标题/// </summary>public string Title{get { return title; }set { title value; }}private string con…

实验三十一、OCL 电路输出功率和效率的研究

一、题目 研究 OCL 功率放大电路的输出功率和效率。 二、仿真电路 OCL 功率放大电路如图1所示。 图 1 OCL 功率放大电路 图1\,\,\,\textrm{OCL}\,功率放大电路 图1OCL功率放大电路图中采用 NPN 型低频功率晶体管 2SC2001&#xff0c;其参数为&#xff1a; I C M 700 mA I_…

5G NR:RACH流程 -- Msg1之选择正确的PRACH时频资源

PRACH的时域资源是如何确定的 PRACH的时域资源主要由参数“prach-ConfigurationIndex”决定。拿着这个参数的取值去协议38211查表6.3.3.2-2/3/4&#xff0c;需要注意根据实际情况在这三张表中进行选择&#xff1a; FR1 FDD/SULFR1 TDDFR2 TDD Random access preambles can onl…

信号和槽的相关操作

目录 信号和槽 connect()函数 自定义信号槽 例子 自定义信号槽需要注意的事项 信号槽的更多用法 Lambda表达式 ① 函数对象参数 ② 操作符重载函数参数 ③ 可修改标示符 ④ 错误抛出标示符 ⑤ 函数返回值 ⑥ 是函数体 所谓信号槽&#xff0c;实际就是观察者模式。当…

AVS3变换:PBT、ST和SBT

前面的文章介绍了AVS3中的变换工具IST和ISTS&#xff0c;本文将介绍AVS3中剩余的几种变换工具&#xff1a;基于位置的变换&#xff08;PBT,Position Based Transform&#xff09;、二次变换&#xff08;ST, Secondary Transform&#xff09;和子块变换&#xff08;SBT, Sub-Blo…

SmartInspect Professional .Net Delphi Crack

SmartInspect Professional .Net & Delphi Crack SmartInspect Professional是一个用于调试和跟踪.NET、Java和Delphi软件的高级日志记录工具。它使您能够识别错误&#xff0c;找到客户问题的解决方案&#xff0c;并让您清楚地了解软件在不同环境和条件下的工作方式。可以轻…

给oracle逻辑导出clob大字段、大数据量表提提速

文章目录 前言一、大表数据附&#xff1a;查询大表 二、解题思路1.导出排除大表的数据2.rowid切片导出大表数据Linux代码如下&#xff08;示例&#xff09;&#xff1a;Windows代码如下&#xff08;示例&#xff09;&#xff1a;手工执行代码如下&#xff08;示例&#xff09;&…

java八股文面试[多线程]——Synchronized的底层实现原理

笔试&#xff1a;画出Synchronized 线程状态流转实现原理图 synchronized关键字解决的是多个线程之间访问资源的同步性&#xff0c;synchronized 翻译为中文的意思是同步&#xff0c;也称之为”同步锁“。 synchronized的作用是保证在同一时刻&#xff0c; 被修饰的代码块或方…

任意文件上传

文章目录 渗透测试漏洞原理任意文件上传1. 任意文件上传概述1.1 漏洞成因1.2 漏洞原理1.3 漏洞危害1.4 漏洞的利用方法1.5 漏洞的验证 2. WebShell解析2.1 Shell2.1.1 命令解释器 2.2 WebShell2.2.1 大马2.2.2 小马2.2.3 GetShell2.2.4 WebShell项目 3. 任意文件上传攻防3.1 毫…

注册字符设备

五、注册字符设备 struct cdev {struct kobject kobj;//表示该类型实体是一种内核对象struct module *owner;//填THIS_MODULE&#xff0c;表示该字符设备从属于哪个内核模块const struct file_operations *ops;//指向空间存放着针对该设备的各种操作函数地址struct list_head …

RAD Installer Crack,集成到RAD Studio IDE支持

RAD & Installer Crack,集成到RAD Studio IDE支持 用于创建NSIS和Inno Setup安装程序的RAD Studio扩展。它将NSIS(Nullsoft Scriptable Install System)和Inno Setup与Embarcadero RAD Studio IDE结合在一起。它允许您在RAD Studio中设计和构建NSIS和Inno Setup项目&#x…

错误的迷宫:探索开发中的异常管理之旅

引言&#xff1a;为什么我们需要谈论错误处理&#xff1f; 在软件开发的世界中&#xff0c;错误是不可避免的。它们是我们编程旅程中的挑战&#xff0c;但也是我们成长的机会。正确地处理错误不仅可以确保软件的稳定性和可靠性&#xff0c;还可以为开发者提供宝贵的反馈。本文…

Icon设计神器!这5个软件一定要试试

在界面设计中&#xff0c;Icon既可以为用户指明用途&#xff0c;又可以提升界面设计的质感&#xff0c;可以说是一种必不可少的设计素材。而市面上可以制作的Icon的设计软件也十分丰富&#xff0c;今天本文将选出了5个好用的与大家分享&#xff0c;它们不仅功能强大&#xff0c…