Prompt-Tuning 提示词微调

news2025/5/17 2:30:59

1. Hard Prompt

定义: Hard prompt 是一种更为具体和明确的提示,要求模型按照给定的信息生成精确的结果,通常用于需要模型提供准确答案的任务.

原理: Prompt Tuning原理如下图所示:冻结主模型全部参数,在训练数据前加入一小段Prompt,只训练Prompt的表示层,即一个Embedding模块。论文实验表明,只要模型规模够大,简单加入 Prompt tokens 进行微调,就能取得很好的效果。

优点: 资源消耗比 Soft Prompt 小,容易收敛。

2. Soft Prompt

**定义:**Soft prompt 通常指的是一种较为宽泛或模糊的提示,允许模型在生成结果时有更大的自由度,通常用于启发模型进行创造性的生成。

缺点: 资源消耗大,参数都是随机初始化的,Loss 会有震荡。

在这里插入图片描述

from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import pipeline, TrainingArguments, Trainer
from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit, PeftModel
 
# 分词器
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")
 
# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
 
if __name__ == "__main__":
    # 加载数据集
    dataset = load_from_disk("./PEFT/data/alpaca_data_zh")
    
    # 处理数据
    tokenized_ds = dataset.map(process_func, remove_columns = dataset.column_names)
    # print(tokenizer.decode(tokenized_ds[1]["input_ids"]))
    # print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))
    
    # 创建模型
    model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
    
    # 设置 Prompt-Tuning
    # Soft Prompt
    # config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化
    # Hard Prompt
    config = PromptTuningConfig(task_type = TaskType.CAUSAL_LM,
                                prompt_tuning_init = PromptTuningInit.TEXT,
                                prompt_tuning_init_text = "下面是一段人与机器人的对话。", # 设置hard_prompt的具体内容
                                num_virtual_tokens = len(tokenizer("下面是一段人与机器人的对话。")["input_ids"]),
                                tokenizer_name_or_path = "Langboat/bloom-1b4-zh")
    model = get_peft_model(model, config) # 生成Prompt-Tuning对应的model
    print(model.print_trainable_parameters())
    
    # 训练参数
    args = TrainingArguments(
        output_dir = "/tmp_1203",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 8,
        logging_steps = 10,
        num_train_epochs = 1
    )
    
    # trainer
    trainer = Trainer(
        model = model,
        args = args,
        train_dataset = tokenized_ds,
        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, padding = True)
    )
    
    # 训练模型
    trainer.train()
    
    # 模型推理
    model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
    peft_model = PeftModel.from_pretrained(model = model, model_id = "/tmp_1203/checkpoint-500/")
    peft_model = peft_model.cuda()
    ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(peft_model.device)
    print(tokenizer.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2342106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

asp.net core webapi+efcore

简洁的restfull风格 目前c#提供了多种风格的web编程,因为微软有自己的前端,所以集成了很多内容,不过基于现在编程前后端分离的模式,webapi是合适的。 webapi 目前网络上有很多介绍,不反复说这个了。在建立控制器时&…

前端渲染pdf文件解决方案-pdf.js

目录 一、前言 二、简介 1、pdf.js介绍 2、插件版本参数 三、通过viewer.html实现预览(推荐) 1、介绍 2、部署 【1】下载插件包 【2】客户端方式 【3】服务端方式(待验证) 3、使用方法 【1】预览PDF文件 【2】外部搜索…

从边缘到云端,如何通过时序数据库 TDengine 实现数据的全局洞

在当今数字化转型加速的背景下,海量的数据生成和实时处理需求已成为企业面临的关键挑战。无论是物联网设备、工业自动化系统,还是智能城市的各类传感器,数据的采集、传输与分析效率,直接影响企业的决策与运营。为此,TD…

2025.04.23【Treemap】树状图数据可视化指南

Multi-level treemap How to build a treemap with group and subgroups. Customization Customize treemap labels, borders, color palette and more 文章目录 Multi-level treemapCustomization Treemap 数据可视化指南Treemap 的基本概念为什么使用 TreemapTreemap 的应用…

蓝桥杯 15.小数第n位

小数第n位 原题目链接 题目描述 我们知道,整数做除法时,有时会得到有限小数,有时会得到无限循环小数。 如果我们把有限小数的末尾加上无限多个 0,它们就具有了统一的形式。 本题的任务是:在上述约定下&#xff0c…

用高斯溅射技术跨越机器人模拟与现实的鸿沟:SplatSim 框架解析

在机器人领域,让机器人在现实世界中精准执行任务是大家一直追求的目标。可模拟环境和现实世界之间存在着不小的差距,特别是基于 RGB 图像的操作策略,从模拟转移到现实时总是状况百出。 今天咱们就来聊聊 SplatSim 框架,看看它是怎…

AI大模型学习十一:‌尝鲜ubuntu 25.04 桌面版私有化sealos cloud + devbox+minio,实战运行成功

一、说明 用了ubuntu 25.04,内核为GNU/Linux 6.14.0-15-generic x86_64,升级了部分image,过程曲折啊 sealos 能干啥 对集群生命周期进行管理,一键安装高可用 Kubernetes 集群,增删节点清理集群自恢复等 通过 sealos…

如何在 Python 项目中引入 Rust 函数

目录 1. 初始化 Python 项目2. 添加 Rust 开发工具3. 初始化 Rust 项目4. 开发模式构建5. 验证模块是否成功安装6. 测试 Rust 函数总结 (封面pid: 129416070) Python 是一门非常流行的编程语言,具有易于使用和开发的特点。然而,随着项目需求的增长和性能…

聊聊SpringAI流式输出的底层实现?

在 Spring AI 中,流式输出(Streaming Output)是一种逐步返回 AI 模型生成结果的技术,允许服务器将响应内容分批次实时传输给客户端,而不是等待全部内容生成完毕后再一次性返回。 这种机制能显著提升用户体验&#xff…

MySQL 8 自动安装脚本(CentOS-7 系统)

文章目录 一、MySQL 8 自动安装脚本脚本说明📌 使用脚本前提条件1. 操作系统2. 用户权限3. 网络要求 📌 脚本的主要功能1. 环境检查2. MySQL 自动安装3. 自动配置 MySQL4. 防火墙配置5. 验证与输出 📌 适用场景 二、执行sh脚本1. 给予脚本执行…

在Notepad++中使用NppAtyle插件格式化代码

参考链接:Artistic Style 使用教程(中文版) 1.下载NppAStyle插件(根据版本,选择32位或者64位) https://github.com/ywx/NppAStyle/releases 2.菜单栏中选择:插件->打开插件文件夹 创建文件夹…

拼多多面经,暑期实习Java一面

项目中的设计模式 mysql连接过程,索引,分库分表场景,路由策略 redis使用场景,分片集群怎么搭建与路由,数据一致性 分布式锁怎么用的,具体使用参数 线程池怎么用的,过程 sql having 分布式事务 如…

FramePack:让视频生成更高效、更实用

想要掌握如何将大模型的力量发挥到极致吗?叶梓老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具(限时免费)。 1小时实战课程,您将学习到如何轻松上手并有效利用 Llama Factory 来微调您的模型,以发挥其…

ctfshow web8

前言 学习内容:简单的盲注脚本的书写 web8 这个题目题目手动注入很麻烦 主要是他过滤了 union 空格和 过滤了union的解决方法 1、使用盲注(报错注入和盲注) 2、使用时间盲注 3、堆叠注入 盲注脚本的书写 首先他是有注入点的 然后熟悉requests包的使用 …

gem5-gpu教程03 当前的gem5-gpu软件架构(因为涉及太多专业名词所以用英语表达)

Current gem5-gpu Software Architecture 这是当前gem5-gpu软件架构的示意图。 Ruby是在gem5-gpu上下文中用于处理CPU和GPU之间内存访问的高度可配置的内存系统 CudaCore (src/gpu/gpgpu-sim/cuda_core.*, src/gpu/gpgpu-sim/CudaCore.py) Wrapper for GPGPU-Sim shader_cor…

TDengine 查询引擎设计

简介 TDengine 作为一个高性能的时序大数据平台,其查询与计算功能是核心组件之一。该平台提供了丰富的查询处理功能,不仅包括常规的聚合查询,还涵盖了时序数据的窗口查询、统计聚合等高级功能。这些查询计算任务需要 taosc、vnode、qnode 和…

【官方正版,永久免费】Adobe Camera Raw 17.2 win/Mac版本 配合Adobe22-25系列软

Adobe Camera Raw 2025 年 2 月版(版本 17.2)。目前为止最新版新版已经更新2个月了,我看论坛之前分享的还是2024版,遂将新版分享给各位。 Adobe Camera Raw,支持Photoshop,lightroom等Adobe系列软件&#…

【论文精读】Reformer:高效Transformer如何突破长序列处理瓶颈?

目录 一、引言:当Transformer遇到长序列瓶颈二、核心技术解析:从暴力计算到智能优化1. 局部敏感哈希注意力(LSH Attention):用“聚类筛选”替代“全量计算”关键步骤:数学优化: 2. 可逆残差网络…

jmeter中监控服务器ServerAgent

插件下载: 将ServerAgent上传至需要监控的服务器,mac/liunx启动startAgent.sh(启动命令:./startAgent.sh) 在jmeter中添加permon监控组件 配置需要监控的服务器IP地址,添加需要监控的资源 注意&#xf…

网络结构及安全科普

文章目录 终端联网网络硬件基础网络协议示例:用户访问网页 OSI七层模型网络攻击(Hack)网络攻击的主要类别(一)按攻击目标分类(二)按攻击技术分类 网络安全防御 典型攻击案例相关名词介绍网络连接…