基于BERT预训练模型(bert_base_chinese)训练中文文本分类任务(AI老师协助编程)

news2025/5/25 1:38:28

新建项目 创建一个新的虚拟环境

  1. 创建新的虚拟环境(大多数时候都需要指定python的版本号才能顺利创建):
conda create -n bert_classification python=3.9
  1. 激活虚拟环境:
conda activate myenv

PS:虚拟环境可以避免权限问题,并隔离项目依赖
权限问题的报错:

ERROR: Could not install packages due to an OSError: [WinError 5] 拒绝访问。: 'd:\\anaconda3\\lib\\site-package
s\\__pycache__\\typing_extensions.cpython-39.pyc'
Consider using the `--user` option or check the permissions.

WARNING: Ignoring invalid distribution -ip (d:\anaconda3\lib\site-packages)
WARNING: Ignoring invalid distribution -ip (d:\anaconda3\lib\site-packages)
WARNING: Ignoring invalid distribution -ip (d:\anaconda3\lib\site-packages)

在项目中进行配置

在这里插入图片描述
在这里插入图片描述

配置相关的库

pip install transformers datasets evaluate torch

训练脚本(train_bert.py)如下:

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import evaluate
import numpy as np

# 加载数据集
dataset = load_dataset('csv', data_files={
    'train': 'D:/pyx/Five_data/train.csv',
    'validation': 'D:/pyx/Five_data/val.csv',
    'test': 'D:/pyx/Five_data/test.csv'
})

# 加载分词器
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')

# 预处理函数
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

# 应用预处理
tokenized_datasets = dataset.map(preprocess_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])

# 加载模型
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=5,
    ignore_mismatched_sizes=True
)

# 定义评估指标
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# 训练参数
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

# 创建Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
)

# 训练模型
trainer.train()
trainer.save_model("./best_model")

# 评估模型
results = trainer.evaluate(tokenized_datasets["test"])
print(f"Test Results: {results}")

上面这个脚本不够完善,遇到了很多报错如下:
报错一

datasets.table.CastError: Couldn't cast
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>: string
__index_level_0__: string
__index_level_1__: string
__index_level_2__: string
__index_level_3__: string
__index_level_4__: string
__index_level_5__: string
__index_level_6__: string
__index_level_7__: string
-- schema metadata --
pandas: '{"index_columns": ["__index_level_0__", "__index_level_1__", "__' + 1534
to
{'<?xml version="1.0" encoding="UTF-8" standalone="yes"?>': Value(dtype='string', id=None)}
because column names don't match

During handling of the above exception, another exception occurred:

检查训练的数据集train.csv等是不是编码有问题,打开是不是乱码,如果是用表格整理的数据,最好是用WPS保存时选择另存为.csv文件
在这里插入图片描述
报错2:

Traceback (most recent call last):
  File "D:\pyx\pythonProject\pythonProject\bert-classification\train_bert.py", line 33, in <module>
    metric = evaluate.load("accuracy")
  File "D:\anaconda3\envs\bert-classification\lib\site-packages\evaluate\loading.py", line 748, in load
    evaluation_module = evaluation_module_factory(
  File "D:\anaconda3\envs\bert-classification\lib\site-packages\evaluate\loading.py", line 681, in evaluation_module_factory
    raise FileNotFoundError(
FileNotFoundError: Couldn't find a module script at D:\pyx\pythonProject\pythonProject\bert-classification\accuracy\accuracy.py. Module 'accuracy' doesn't exist on the Hugging Face Hub either.

定位到有问题的代码片段:

# 定义评估指标
metric = evaluate.load("accuracy")

方法二尝试提高库的版本可以适用于这些参数

但是我的还是报错所以重新修改参数 不使用这些模块 并且让AI帮我添加了很多调试信息,改进部分如下:

# 5. 定义训练参数(完全兼容transformers旧版本)
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    # 移除所有新版本特有的参数
    logging_steps=50,
    save_steps=len(tokenized_datasets["train"]) // 3,  # 每轮保存3次
    save_total_limit=3,  # 最多保存3个检查点
    logging_dir="./train_logs",
    disable_tqdm=False,
    dataloader_num_workers=4,
)
# 6. 自定义计算指标函数
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = compute_accuracy(predictions, labels)
    return {"accuracy": accuracy}

# 7. 创建带监控的Trainer(修改为普通Trainer)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
)

# 8. 开始训练(手动管理验证过程)
logger.info("\n===== 开始训练 =====")
train_result = trainer.train()

# 9. 手动执行验证
logger.info("\n===== 验证模型 =====")
eval_results = trainer.evaluate()
for key, value in eval_results.items():
    logger.info(f"  {key}: {value:.4f}")

# 10. 保存模型
logger.info(f"\n保存模型至: {training_args.output_dir}")
trainer.save_model(training_args.output_dir)

# 11. 评估测试集
logger.info("\n===== 评估测试集 =====")
test_results = trainer.evaluate(tokenized_datasets["test"])
logger.info(f"测试集准确率: {test_results['eval_accuracy']:.4f}")

# 12. 打印训练统计信息(需要手动计算)
logger.info(f"\n训练总步数: {train_result.global_step}")
logger.info(f"训练总耗时: {train_result.metrics['train_runtime']:.2f}秒")
logger.info(f"训练平均速度: {train_result.metrics['train_samples_per_second']:.2f}样本/秒")

然后还需要

pip install 'accelerate>=0.26.0'

成功开始训练
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2385008.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从数据到智能:openGauss+openEuler Intelligence的RAG架构实战

随着人工智能和大规模语言模型技术的崛起&#xff0c;传统的搜索引擎由于其只能提供简单的关键字匹配结果&#xff0c;已经越来越无法满足用户对于复杂、多样化和上下文相关的知识检索需求。与此相对&#xff0c;RAG&#xff08;Retrieval-Augmented Generation&#xff09;技术…

【Linux】初见,基础指令

前言 本文将讲解Linux中最基础的东西-----指令&#xff0c;带大家了解一下Linux中有哪些基础指令&#xff0c;分别有什么作用。 本文中的指令和选项并不全&#xff0c;只介绍较为常用的 pwd指令 语法&#xff1a;pwd 功能&#xff1a;显示当前所在位置&#xff08;路径&#xf…

什么是实时流数据?核心概念与应用场景解析

在当今数字经济时代&#xff0c;实时流数据正成为企业核心竞争力。金融机构需要实时风控系统在欺诈交易发生的瞬间进行拦截&#xff1b;电商平台需要根据用户实时行为提供个性化推荐&#xff1b;工业物联网需要监控设备状态预防故障。这些场景都要求系统能够“即时感知、即时分…

工业RTOS生态重构:从PLC到“端 - 边 - 云”协同调度

一、引言 在当今数字化浪潮席卷全球的背景下&#xff0c;工业领域正经历着深刻变革。工业自动化作为制造业发展的基石&#xff0c;其技术架构的演进直接关系到生产效率、产品质量以及企业的市场竞争力。传统的PLC&#xff08;可编程逻辑控制器&#xff09;架构虽然在工业控制领…

基于开源链动2+1模式AI智能名片S2B2C商城小程序的社群构建与新型消费迎合策略研究

摘要&#xff1a;随着个性化与小众化消费的崛起&#xff0c;消费者消费心理和模式发生巨大变化&#xff0c;社群构建对商家迎合新型消费特点、融入市场经济发展至关重要。开源链动21模式AI智能名片S2B2C商城小程序的出现&#xff0c;为社群构建提供了创新工具。本文探讨该小程序…

高性能RPC框架--Dubbo(五)

Filter&#xff1a; filter过滤器动态拦截请求&#xff08;request&#xff09;或响应&#xff08;response&#xff09;以转换或使用请求或响应中包含的信息。同时对于filter过滤器不仅适合消费端而且还适合服务提供端。我们可以自定义在什么情况下去使用filter过滤器 Activa…

搭建自己的语音对话系统:开源 S2S 流水线深度解析与实战

网罗开发 &#xff08;小红书、快手、视频号同名&#xff09; 大家好&#xff0c;我是 展菲&#xff0c;目前在上市企业从事人工智能项目研发管理工作&#xff0c;平时热衷于分享各种编程领域的软硬技能知识以及前沿技术&#xff0c;包括iOS、前端、Harmony OS、Java、Python等…

feign调用指定服务ip端口

1 背景 在springcloud开发时候&#xff0c;同时修改了feign接口和调用方的代码&#xff0c;希望直接在某个环境调用修改的代码&#xff0c;而线上的服务又不希望被下线因为需要继续为其他访问页面的用户提供功能后端服务&#xff0c;有时候甚者包含你正在修改的功能。 2 修改…

【深尚想!爱普特APT32F1023H8S6单片机重构智能电机控制新标杆】

在智能家电与健康器械市场爆发的今天&#xff0c;核心驱动技术正成为产品突围的关键。传统电机控制方案面临集成度低、开发周期长、性能瓶颈三大痛点&#xff0c;而爱普特电子带来的APT32F1023H8S6单片机无感三合一方案&#xff0c;正在掀起一场智能电机控制的技术革命。 爆款基…

Unity EventCenter 消息中心的设计与实现

在开发过程中&#xff0c;想要传递信号和数据&#xff0c;就得在不同模块之间实现通信。直接通过单例调用虽然简单&#xff0c;但会导致代码高度耦合&#xff0c;难以维护。消息中心提供了一种松耦合的通信方式&#xff1a;发布者不需要知道谁接收事件&#xff0c;接收者不需要…

MySQL远程连接10060错误:防火墙端口设置指南

问题描述&#xff1a; 如果你通过本机服务器远程连接MySQL&#xff0c;出现10060错误&#xff0c;那可能是你的防火墙的问题 解决&#xff1a; 第一步&#xff1a;查看防火墙规则 通过以下命令查询&#xff0c;看ports是否开放了3306端口&#xff0c;目前只开放了22端口 f…

使用 OpenCV 实现 ArUco 码识别与坐标轴绘制

&#x1f3af; 使用 OpenCV 实现 ArUco 码识别与坐标轴绘制&#xff08;含Python源码&#xff09; Aruco 是一种广泛用于机器人、增强现实&#xff08;AR&#xff09;和相机标定的方形标记系统。本文将带你一步一步使用 Python OpenCV 实现图像中多个 ArUco 码的检测与坐标轴…

canal实现mysql数据同步

目录 1、canal下载 2、mysql同步用户创建和授权 3、canal admin安装和启动 4、canal server安装和启动 5、java 端集成监听canal 同步的mysql数据 6、java tcp同步只是其中一种方式&#xff0c;还可以通过kafka、rabbitmq等方式进行数据同步 1、canal下载 canal实现mysq…

易境通专线散拼系统:全方位支持多种专线物流业务!

在全球化电商快速发展的今天&#xff0c;跨境电商物流已成为电商运营中极为重要的环节。为了确保物流效率、降低运输成本&#xff0c;越来越多的电商卖家选择专线物流服务。专线物流作为五大主要跨境电商物流模式之一&#xff0c;通过固定的运输路线和流程&#xff0c;极大提高…

06 如何定义方法,掌握有参无参,有无返回值,调用数组作为参数的方法,方法的重载

1.调用方法 2.掌握有参函数 3.调用数组作为参数 一个例题&#xff1a;数组参数&#xff0c;返回值 方法的重载 两个例题&#xff1a;冒泡排序和九九乘法表的格式学习

使用vscode MSVC CMake进行C++开发和Debug

使用vscode MSVC CMake进行C开发和Debug 前言软件安装安装插件构建debuug方案一debug方案二其他 前言 一般情况下我都是使用visual studio来进行c开发的&#xff0c;但是由于python用的是vscode&#xff0c;所以二者如果统一的话能稍微提高一点效率。 软件安装 需要安装的软…

提升开发运维效率:原力棱镜游戏公司的 Amazon Q Developer CLI 实践

引言 在当今快速发展的云计算环境中&#xff0c;游戏开发者面临着新的挑战和机遇。为了提升开发效率&#xff0c;需要更智能的工具来辅助工作流程。Amazon Q Developer CLI 作为亚马逊云科技推出的生成式 AI 助手&#xff0c;为开发者提供了一种新的方式来与云服务交互。 Ama…

@Column 注解属性详解

提示&#xff1a;文章旨在说明 Column 注解属性如何在日常开发中使用&#xff0c;数据库类型为 MySql&#xff0c;其他类型数据库可能存在偏差&#xff0c;需要注意。 文章目录 一、name 方法二、unique 方法三、nullable 方法四、insertable 方法五、updatable 方法六、column…

基于 ESP32 与 AWS 全托管服务的 IoT 架构:MQTT + WebSocket 实现设备-云-APP 高效互联

目录 一、总体架构图 二、设备端(ESP32)低功耗设计(适配 AWS IoT) 1.MQTT 设置(ESP32 连接 AWS IoT Core) 2.低功耗策略总结(ESP32) 三、云端架构(基于 AWS Serverless + IoT Core) 1.AWS IoT Core 接入 2.云端 → APP:WebSocket 推送方案 流程: 3.数据存…

unity在urp管线中插入事件

由于在urp下&#xff0c;打包后传统的相机事件有些无法正确执行&#xff0c;这时候我们需要在urp管线中的特定时机进行处理一些事件&#xff0c;需要创建继承ScriptableRenderPass和ScriptableRendererFeature的脚本&#xff0c;示例如下&#xff1a; PluginEventPass&#xf…