InternLM 论文分类微调实践(XTuner 版)

news2025/7/17 14:08:39

1.环境安装

我创建开发机选择镜像为Cuda12.2-conda,选择GPU为100%A100的资源配置

Conda 管理环境

conda create -n xtuner_101 python=3.10 -y
conda activate xtuner_101
pip install torch==2.4.0+cu121 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
pip install xtuner timm flash_attn datasets==2.21.0 deepspeed==0.16.1
conda install mpi4py -y
#为了兼容模型,降级transformers版本
pip uninstall transformers -y
pip install transformers==4.48.0 --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple

检验环境安装

xtuner list-cfg

2.数据获取

数据为sftdata.jsonl,已上传。

3.训练

 链接模型位置命令

ln -s /root/share/new_models/Shanghai_AI_Laboratory/internlm2_5-7b-chat ./

3.1 微调脚本

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (
    CheckpointHook,
    DistSamplerSeedHook,
    IterTimerHook,
    LoggerHook,
    ParamSchedulerHook,
)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (
    DatasetInfoHook,
    EvaluateChatHook,
    VarlenAttnArgsToMessageHubHook,
)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = "./internlm2_5-7b-chat"
use_varlen_attn = False

# Data
alpaca_en_path = "/root/xtuner/datasets/train/sftdata.jsonl"#换成自己的数据路径
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 1
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = ["请给我介绍五个上海的景点", "Please tell me five scenic spots in Shanghai"]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side="right",
)

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        ),
    ),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    ),
)

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=alpaca_en_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_map_fn,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn,
)

sampler = SequenceParallelSampler if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn),
)

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale="dynamic",
    dtype="float16",
)

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True,
    ),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True,
    ),
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template,
    ),
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit,
    ),
    # set sampler seed in distributed evrionment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method="fork", opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend="nccl"),
)

# set visualizer
visualizer = None

# set log level
log_level = "INFO"

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)

 将模型和地址改为自己的路径

3.2 启动微调

cd /root/101
conda activate xtuner_101
xtuner train internlm2_5_chat_7b_qlora_alpaca_e3_copy.py --deepspeed deepspeed_zero1

3.3 合并

3.3.1  将PTH格式转换为HuggingFace格式

xtuner convert pth_to_hf internlm2_5_chat_7b_qlora_alpaca_e3_copy.py ./work_dirs/internlm2_5_chat_7b_qlora_alpaca_e3_copy/iter_195.pth ./work_dirs/hf

3.3.2  合并adapter和基础模型

xtuner convert merge \
/root/internlm2_5-7b-chat \
./work_dirs/hf \
./work_dirs/merged \
--max-shard-size 2GB \

完成这两个步骤后,合并好的模型将保存在./work_dirs/merged目录下,你可以直接使用这个模型进行推理了。 

3.4 推理

from transformers import AutoModelForCausalLM, AutoTokenizer
import time

# 加载模型和分词器
# model_path = "./lora_output/merged"
model_path = "./internlm2_5-7b-chat"
print(f"加载模型:{model_path}")

start_time = time.time()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype="auto", device_map="auto"
)

def classify_paper(title, authors, abstract, additional_info=""):
    # 构建输入,包含多选选项
    prompt = f"Based on the title '{title}', authors '{authors}', and abstract '{abstract}', please determine the scientific category of this paper. {additional_info}\n\nA. astro-ph\nB. cond-mat.mes-hall\nC. cond-mat.mtrl-sci\nD. cs.CL\nE. cs.CV\nF. cs.LG\nG. gr-qc\nH. hep-ph\nI. hep-th\nJ. quant-ph"

    # 设置系统信息
    messages = [
        {"role": "system", "content": "你是个优秀的论文分类师"},
        {"role": "user", "content": prompt},
    ]

    # 应用聊天模板
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)

    # 生成回答
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,  # 减少生成长度,只需要简短答案
        temperature=0.1,  # 降低温度提高确定性
        top_p=0.95,
        repetition_penalty=1.0,
    )

    # 解码输出
    response = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
    ).strip()

    # 如果回答中包含选项标识符,只返回该标识符
    for option in ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]:
        if option in response:
            return option

    # 如果回答不包含选项,返回完整回答
    return response

# 示例使用
if __name__ == "__main__":
    title = "Outilex, plate-forme logicielle de traitement de textes 'ecrits"
    authors = "Olivier Blanc (IGM-LabInfo), Matthieu Constant (IGM-LabInfo), Eric Laporte (IGM-LabInfo)"
    abstract = "The Outilex software platform, which will be made available to research, development and industry, comprises software components implementing all the fundamental operations of written text processing: processing without lexicons, exploitation of lexicons and grammars, language resource management. All data are structured in XML formats, and also in more compact formats, either readable or binary, whenever necessary; the required format converters are included in the platform; the grammar formats allow for combining statistical approaches with resource-based approaches. Manually constructed lexicons for French and English, originating from the LADL, and of substantial coverage, will be distributed with the platform under LGPL-LR license."

    result = classify_paper(title, authors, abstract)
    print(result)

    # 计算并打印总耗时
    end_time = time.time()
    total_time = end_time - start_time
    print(f"程序总耗时:{total_time:.2f}秒")

 推理结果如下:

微调前模型推理

微调后模型推理

3.5 部署

pip install lmdeploy
python -m lmdeploy.pytorch.chat ./work_dirs/merged \
--max_new_tokens 256 \   
--temperture 0.8 \   
--top_p 0.95 \   
--seed 0

4.评测(跳过)

5.上传模型到魔搭

pip install modelscope

 使用脚本

from modelscope.hub.api import HubApi
YOUR_ACCESS_TOKEN='自己的令牌'
api=HubApi()
api.login(YOUR_ACCESS_TOKEN)

from modelscope.hub.constants import Licenses, ModelVisibility
owner_name='Raven10086'
model_name='InternLM-gmz-camp5'
model_id=f"{owner_name}/{model_name}"
api.create_model(
    model_id,
    visibility=ModelVisibility.PUBLIC,
    license=Licenses.APACHE_V2,
    chinese_name="gmz文本分类微调端侧小模型"
)
api.upload_folder(
    repo_id=f"{owner_name}/{model_name}",
    folder_path='/root/swift_output/InternLM3-8B-SFT-Lora/v5-20250517-164316/checkpoint-62-merged',
    commit_message='fast commit',
    )

 上传成功截图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2379863.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PC:使用WinSCP密钥文件连接sftp服务器

1. 打开winscp工具,点击“标签页”->“新标签页” 2. 点击“高级"->“高级” 3. 点击"验证"->“选择密钥文件” 选择ppk文件,如果没有ppk文件选择pem文件,会自动生成ppk文件 点击确定 4. 输入要连接到的sftp服务器的…

1688正式出海,1688跨境寻源通接口接入,守卫的是国内工厂资源

在1688平台的跨境招商直播中,许多想要进入跨境市场的初学者商家纷纷提问:货通全球的入口在哪里?小白商家应该如何操作?商品为何上传失败? 从表面上看,这似乎是1688平台在拓展海外市场的一次积极“进攻”。…

力扣303 区域和检索 - 数组不可变

文章目录 题目介绍题解 题目介绍 题解 不用管第一个null,从第二个开始看就可以 法一:暴力解法 class NumArray {private int[] nums;public NumArray(int[] nums) {this.nums nums;}public int sumRange(int left, int right) {int res 0;for (int i…

Spring的后置处理器是干什么用的?扩展点又是什么?

Spring 的后置处理器和扩展点是其框架设计的核心机制,它们为开发者提供了灵活的扩展能力,允许在 Bean 的生命周期和容器初始化过程中注入自定义逻辑。 1. 后置处理器(Post Processors) 后置处理器是 Spring 中用于干预 Bean 生命…

[ linux-系统 ] 进程地址空间

验证地址空间 父子进程的变量值不同但是地址相同,说明该地址绝对不是物理地址 我们叫这种地址为虚拟地址/线性地址 分析与结论 上述实验表明,父子进程的变量地址相同但内容不同,说明地址为虚拟地址,且父子进程有各自独立的物理…

文件名是 ‪E:\20250512_191204.mp4, EV软件录屏,未保存直接关机损坏, 如何修复?

去github上下载untrunc 工具就能修复 https://github.com/anthwlock/untrunc/releases 如果访问不了 本机的 hosts文件设置 140.82.112.3 github.com 199.232.69.194 github.global.ssl.fastly.net 就能访问了 实在不行,从这里下载,传上去了 https://do…

Java常见API文档(下)

格式化的时间形式的常用模式对应关系如下: 空参构造创造simdateformate对象,默认格式 练习.按照指定格式展示 package kl002;import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.Date;public class Date3 {publi…

DRIVEGPT4: 通过大语言模型实现可解释的端到端自动驾驶

《DriveGPT4: Interpretable End-to-End Autonomous Driving via Large Language Model》 2024年10月发表,来自香港大学、浙江大学、华为和悉尼大学。 多模态大型语言模型(MLLM)已成为研究界关注的一个突出领域,因为它们擅长处理…

构建共有语料库 - Wiki 语料库

中文Wiki语料库主要指的是从中文Wikipedia(中文维基百科)提取的文本数据。维基百科是一个自由的、开放编辑的百科全书项目,覆盖了从科技、历史到文化、艺术等广泛的主题。 对于基于RAG的应用来说,把Wiki语料作为一个公有的语料库…

苍穹外卖项目中的 WebSocket 实战:实现来单与催单提醒功能

🚀 苍穹外卖项目中的 WebSocket 实战:实现来单与催单提醒功能 在现代 Web 应用中,实时通信成为提升用户体验的关键技术之一。WebSocket 作为一种在单个 TCP 连接上进行全双工通信的协议,被广泛应用于需要实时数据交换的场景&#…

Win10 安装单机版ES(elasticsearch),整合IK分词器和安装Kibana

一. 先查看本机windows是否安装了ES(elasticsearch),检查方法如下: 检查进程 按 Ctrl Shift Esc 组合键打开 “任务管理器”。在 “进程” 选项卡中,查看是否有 elasticsearch 相关进程。如果有,说明系统安装了 ES。 检查端口…

【Redis】List 列表

文章目录 初识列表常用命令lpushlpushxlrangerpushrpushxlpop & rpoplindexlinsertllen阻塞操作 —— blpop & brpop 内部编码应用场景 初识列表 列表类型,用于存储多个字符串。在操作和实现上,类似 C 的双端队列,支持随机访问(O(N)…

JUC入门(四)

ReadWriteLock 代码示例: package com.yw.rw;import java.util.HashMap; import java.util.Map; import java.util.concurrent.locks.ReentrantReadWriteLock;public class ReadWriteDemo {public static void main(String[] args) {MyCache myCache new MyCache…

【HarmonyOS 5】鸿蒙mPaaS详解

【HarmonyOS 5】鸿蒙mPaaS详解 一、mPaaS是什么? mPaaS 是 Mobile Platform as a Service 的缩写,即移动开发平台。 蚂蚁移动开发平台mPaaS ,融合了支付宝科技能力,可以为移动应用开发、测试、运营及运维提供云到端的一站式解决…

无法加载文件 E:\Program Files\nodejs\npm.ps1,因为在此系统上禁止运行脚本

遇到“无法加载文件 E:\Program Files\nodejs\npm.ps1,因为在此系统上禁止运行脚本”这类错误,通常是因为你的 PowerShell 执行策略设置为不允许运行脚本。在 Windows 系统中,默认情况下,出于安全考虑,PowerShell 可能会阻止运行未…

【C++模板与泛型编程】实例化

目录 一、模板实例化的基本概念 1.1 什么是模板实例化? 1.2 实例化的触发条件 1.3 实例化的类型 二、隐式实例化 2.1 隐式实例化的工作原理 2.2 类模板的隐式实例化 2.3 隐式实例化的局限性 三、显式实例化 3.1 显式实例化声明(extern templat…

什么是RDMA?

什么是RDMA? RDMA(RemoteDirect Memory Access)技术全称远程直接内存访问,就是为了解决网络传输中服务器端数据处理的延迟而产生的。它将数据直接从一台计算机的内存传输到另一台计算机,无需双方操作系统的介入。这允许高吞吐、低延迟的网络…

ASIC和FPGA,到底应该选择哪个?

ASIC和FPGA各有优缺点。 ASIC针对特定需求,具有高性能、低功耗和低成本(在大规模量产时);但设计周期长、成本高、风险大。FPGA则适合快速原型验证和中小批量应用,开发周期短,灵活性高,适合初创企…

Python学习笔记--使用Django操作mysql

注意:本笔记基于python 3.12,不同版本命令会有些许差别!!! Django 模型 Django 对各种数据库提供了很好的支持,包括:PostgreSQL、MySQL、SQLite、Oracle。 Django 为这些数据库提供了统一的调…

计算机视觉设计开发工程师学习路线

以下是一条系统化的计算机视觉(CV)学习路线,从基础到进阶,涵盖理论、工具和实践,适合逐步深入,有需要者记得点赞收藏哦: 相关学习:python深度学习,python代码定制 python…