基于 ChatGLM2 和 OpenVINO™ 打造中文聊天助手

news2025/6/21 0:46:30

点击蓝字

关注我们,让开发变得更有趣

作者 | 英特尔 AI 软件工程师 杨亦诚

排版 | 李擎

基于ChatGLM2和OpenVINO™打造中文聊天助手

ChatGLM 是由清华大学团队开发的是一个开源的、支持中英双语的类 ChatGPT 大语言模型,它能生成相当符合人类偏好的回答,  ChatGLM2 是开源中英双语对话模型 ChatGLM 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,通过全面升级的基座模型,带来了更强大的性能,更长的上下文,并且该模型对学术研究完全开放,登记后亦允许免费商业使用。接下来我们分享一下如何基于 ChatGLM2-6B 和 OpenVINO™ 工具套件来打造一款聊天机器人。

项目仓库地址:https://github.com/OpenVINO-dev-contest/chatglm2.openvino

0b4b50ba28296589965c27c4c2a00e74.png

注1:由于 ChatGLM2-6B 对在模型转换和运行过程中对内存的占用较高,推荐使用支持 128Gb 以上内存的的服务器终端作为测试平台。

注2:本文仅分享部署 ChatGLM2-6B 原始预训练模型的方法,如需获得自定义知识的能力,需要对原始模型进行 Fine-tune;如需获得更好的推理性能,可以使用量化后的模型版本。

OpenVINO™

236141f115b4663168abb96a1253db51.gif

模型导出

第一步,我们需要下载 ChatGLM2-6B 模型,并将其导出为 OpenVINO™ 所支持的IR格式模型进行部署,由于 ChatGLM 团队已经将 6B 版本的预训练模型发布在 Hugging Face 平台上,支持通过 Transformer 库进行推理,但不支持基于Optimum 的部署方式(可以参考 Llama2 的文章),因此这里我们需要提取 Transformer 中的 ChatGLM2 的 PyTorch 模型对象,并实现模型文件的序列化。主要步骤可以分为:

1.获取 PyTorch 模型对象

通过Transformer库获取PyTorch对象,由于目前Transformer中原生的ModelForCausalLM类并不支持ChatGLM2模型架构,因此需要添加trust_remote_code=True参数,从远程模型仓库中获取模型结构信息,并下载权重。

model = AutoModel.from_pretrained(args.model_id,
                                  trust_remote_code=True).float()

2.模拟并获取模型的输入输出参数

在调用 torch.onnx.export 接口将模型对象导出为 ONNX 文件之前,我们首先需要获取模型的输入和输出信息。由于 ChatGLM2 存在 KV cache 机制,因此这个步骤中会模拟第一次文本生成时不带 cache 的输入,并将其输出作为第二次迭代时的 cache 输入,再通过第二次迭代来验证输入数据是否完整。以下分别第一次和第二次迭代的 PyTorch 代码:

outputs = model.forward(**input_tensors)


outputs2 = model.forward(input_ids=input_ids,
                         attention_mask=attention_mask,
                         position_ids=position_ids,
                         past_key_values=past_key_values)

3.导出为ONNX格式

在获取完整的模型输入输出信息后,我们可以利用 torch.onnx.export 接口将模型导出为 ONNX 文件,如果通过模型结构可视化工具查看该文件的话,不难发现原始模型对象中 attention_mask 这个 input layer 消失了,个人理解是因为 attention_mask 对模型的输出结果没有影响,并且其实际功能已经被 position_ids 代替了,所以 ONNX 在转化模型的过程中自动将其优化掉了。

4.利用 OpenVINO™ Model Optimizer 进行格式转换

最后一步可以利用 OpenVINO™ 的 Model Optimizer 工具将模型文件转化为 IR 格式,并压缩为 FP16 精度,将较原始 FP32 模式,FP16 模型可以在保证模型输出准确性的同时,减少磁盘占用,并优化运行时的内存开销。

6c2a6a54629ec43fd96531f6a19a3538.gif

模型部署

当完成 IR 模型导出后,我们首先需要构建一个简单的问答系统 pipeline,测试效果。如下图所示, Prompt 提示会送入 Tokenizer 进行分词和词向量编码,然后有 OpenVINO™ 推理获得结果(蓝色部分),来到后处理部分,我们会把推理结果进行进一步的采样和解码,最后生成常规的文本信息。这里 Top-K 以及 Top-P作 为答案的筛选方法,最终从筛选后的答案中进行随机采样输出结果。

b967e5a3ca55ca69ebbb34e6fd7e43a2.png

图:ChatGLM2 问答任务流程

整个 pipeline 的大部分代码都可以套用文本生成任务的常规流程,其中比较复杂一些的是 OpenVINO™ 推理部分的工作,由于 ChatGLM2-6B 文本生成任务需要完成多次递归迭代,并且每次迭代会存在 cache 缓存,因此我们需要为不同的迭代轮次分别准备合适的输入数据。接下来我们详细解构一下模型的运行逻辑:

de6c23c51e4cb300f4d545306a5700cb.png

图:ChatGLM2-6B模型输入输出原理

ChatGLM2 的 IR 模型的输入主要由三部分组成:

· input_ids 是向量化后的提示输入

· position_ids 用来描述输入的位置信息,例如原始的 prompt 数据为 “How are you”, 那这是 position_ids 就是[[1,2,3]], 如果输入为原始 prompt 的后的第一个被预测词:”I”, 那 position_ids 则为[[4]], 以此类推。

· past_key_values.x 是由一连串数据构成的集合,用来保存每次迭代过程中可以被共享的 cache.

ChatGLM2 的 IR 模型的输出则由两部分组成:

· Logits 为模型对于下一个词的预测,或者叫 next token

· present_key_values.x 则可以被看作 cache,直接作为下一次迭代的 past_key_values.x 值

整个 pipeline 在运行时会对 ChatGLM2 模型进行多次迭代,每次迭代会递归生成对答案中下一个词的预测,直到最终答案长度超过预设值 max_sequence_length,或者预测的下一个词为终止符 eos_token_id。

· 第一次迭代

如图所示在一次迭代时(N=1)input_ids 为提示语句,此时我们还需要利用 Tokenizer 分词器将原始文本转化为输入向量,而由于此时无法利用 cache 进行加速,past_key_values.x 系列向量均为空值,这里我们会初始化一个维度为[0,1,2,128]的空值矩阵

· 第N次迭代

当第一次迭代完成后,会输出对于答案中第一个词的预测 Logits,以及 cache 数据,我们可以将这个 Logits 作为下一次迭代的 input_ids 再输入到模型中进行下一次推理(N=2), 此时我们可以利用到上次迭代中的 cache 数据也就是 present_key_values.x,而无需每次将完整的“提示+预测词”一并送入模型,从而减少一些部分重复的计算量。这样周而复始,将当前的预测词所谓一次迭代的输入,就可以逐步生成所有的答案。

详细代码如下,这里可以看到如果 past_key_values 等于 None 就是第一次迭代,此时需要构建一个值均为空的 past_key_values 系列,如果不为 None 则会将真实的 cache 数据加入到输入中。

if past_key_values is not None:
                new_position_id += 1
                inputs["position_ids"] = new_position_id
                inputs.update(past_key_values)
            else:
                inputs["position_ids"] = position_ids
                shape_input_ids = input_ids.shape
                for input_name in past_names:
                    model_inputs = self.model.input(input_name)
                    shape = model_inputs.get_partial_shape()
                    if shape[0].is_dynamic:
                        shape[0] = 0
                    if shape[1].is_dynamic:
                        shape[1] = shape_input_ids[0]
                    inputs[input_name] = Tensor(
                        model_inputs.get_element_type(), shape.get_shape())

测试输出如下:

命令:python3 generate_ov.py -m  "THUDM/chatglm2-6b" -p "请介绍一下上海?"

ChatGLM2-6B 回答:

“上海是中国的一个城市,位于东部沿海地区,是中国重要的经济、文化和科技中心之一。

上海是中国的一个重要港口城市,是中国重要的进出口中心之一,也是全球著名的金融中心之一。上海是亚洲和全球经济的中心之一,拥有许多国际知名金融机构和跨国公司总部。

上海是一个拥有悠久历史和丰富文化的城市。上海是中国重要的文化城市之一,拥有许多历史文化名胜和现代文化地标。上海是中国的一个重要旅游城市,吸引了大量国内外游客前来观光旅游。“

上海是一个拥有重要经济功能的现代城市。“

OpenVINO™

2a2bf091ad2329c23b49f463533329b9.gif

聊天助手

官方示例中 ChatGLM2 的主要用途为对话聊天,相较于问答模型模式中一问一答的形式,对话模式则需要构建更为完整的对话,此时模型在生成答案的过程中还需要考虑到之前对话中的信息,并将其作为 cache 数据往返于每次迭代过程中,因此这里我们需要额外设计一个模板,用于构建每一次的输入数据,让模型能够给更充分理解哪些是历史对话,哪些是新的对话问题。

69f9fa7fa4bfed0ff3c3bfe6aa4497c7.png

图:ChatGLM2对话任务流程

这里的 text 模板是由“引导词+历史记录+当前问题(提示)”三部分构成:

· 引导词:描述当前的任务,引导模型做出合适的反馈

· 历史记录:记录聊天的历史数据,包含每一组问题和答案

· 当前问题:类似问答模式中的问题

def build_inputs(history: list[tuple[str, str]], query: str, system: str = ""):
    prompt = "{}\n".format(system)
    for i, (old_query, response) in enumerate(history):
        prompt += "[Round {}]\n问:{}\n答:{}\n".format(i + 1, old_query, response)
        prompt += "[Round {}]\n问:{}\n答:".format(len(history) + 1, query)
    print(prompt)
    return prompt

我们采用 streamlit 框架构建构建聊天机器人的 web UI 和后台处理逻辑,同时希望该聊天机器人可以做到实时交互,实时交互意味着我们不希望聊天机器人在生成完整的文本后再将其输出在可视化界面中,因为这个需要用户等待比较长的时间来获取结果,我们希望在用户在使用过程中可以逐步看到模型所预测的每一个词,并依次呈现。因此我们需要创建一个可以被迭代的方法 generate_iterate,可以依次获取模型迭代过程中每一次的预测结果,并将其依次添加到最终答案中,并逐步呈现。

当完成任务构建后,我们可以通过 streamlit run chat_robot.py 命令启动聊天机器,并访问本地地址进行测试。这里选择了几个常用配置参数,方便开发者根据机器人的回答准确性进行调整:

· 系统提示词:用于引导模型的任务方向

· max_tokens: 生成句子的最大长度。

· top-k: 从置信度对最高的k个答案中随机进行挑选,值越高生成答案的随机性也越高。

· top-p: 从概率加起来为 p 的答案中随机进行挑选, 值越高生成答案的随机性也越高,一般情况下,top-p 会在 top-k 之后使用。

· Temperature: 从生成模型中抽样包含随机性, 高温意味着更多的随机性,这可以帮助模型给出更有创意的输出。如果模型开始偏离主题或给出无意义的输出,则表明温度过高。

26e30ee089e43fc4bd8f5f4cd826cf9b.png

注3:由于 ChatGLM2-6B 模型比较大,首次硬件加载和编译的时间会相对比较久

OpenVINO™

a9d6fc10dd80eb1d2a33a9cd3d101ba4.gif

总结

作为当前最热门的双语大语言模型之一,ChatGLM2  凭借在各大基准测试中出色的成绩,以及支持微调等特性被越来越多开发者所认可和使用。利用  OpenVINO™ 构建 ChatGLM2 系列任务可以进一步提升其模型在英特尔平台上的性能,并降低部署门槛。

参考资料 

1.Hugging Face Transformer: 

  https://huggingface.co/docs/transformers

2.ChatGLM26B Hugging Face:

  https://huggingface.co/THUDM/chatglm2-6b

3.ChatGLM2-6B-TensorRT

  https://github.com/Tlntin/ChatGLM2-6B-TensorRT

OpenVINO™

--END--

你也许想了解(点击蓝字查看)⬇️➡️ 基于 Llama2 和 OpenVINO™ 打造聊天机器人➡️ OpenVINO™ DevCon 2023重磅回归!英特尔以创新产品激发开发者无限潜能➡️ 5周年更新 | OpenVINO™  2023.0,让AI部署和加速更容易➡️ OpenVINO™5周年重头戏!2023.0版本持续升级AI部署和加速性能➡️ OpenVINO™2023.0实战 | 在 LabVIEW 中部署 YOLOv8 目标检测模型➡️ 开发者实战系列资源包来啦!➡️ 以AI作画,祝她节日快乐;简单三步,OpenVINO™ 助你轻松体验AIGC
➡️ 还不知道如何用OpenVINO™作画?点击了解教程。➡️ 几行代码轻松实现对于PaddleOCR的实时推理,快来get!➡️ 使用OpenVINO 在“端—边—云”快速实现高性能人工智能推理➡️ 图片提取文字很神奇?试试三步实现OCR!➡️【Notebook系列第六期】基于Pytorch预训练模型,实现语义分割任务➡️使用OpenVINO™ 预处理API进一步提升YOLOv5推理性能
扫描下方二维码立即体验 
OpenVINO™ 工具套件 2023.0

点击 阅读原文 立即体验OpenVINO 2023.0

c927a4a8c93ac4c32300602a0c139e8c.png

文章这么精彩,你有没有“在看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/862052.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

160. 相交链表 题解

题目描述:160. 相交链表 - 力扣(LeetCode) 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 注:本题中链表相交是“Y”型的&am…

KDD 2023 | 美团技术团队精选论文解读

本文精选了美团技术团队被KDD 2023收录的7篇论文进行解读,论文覆盖了Feed流推荐、多模态数据、实例分割、用户意图预测等多个方向。这些论文也是美团技术团队与国内多所高校、科研机构合作的成果。希望给从事相关研究工作的同学带来一些启发或者帮助。 ACM SIGKDD&a…

(el-switch)操作(不使用 ts):Element-plus 中 Switch 将默认值修改为 “true“ 与 “false“(字符串)来控制开关

Ⅰ、Element-plus 提供的 Switch 开关组件与想要目标情况的对比: 1、Element-plus 提供 Switch 组件情况: 其一、Element-ui 自提供的 Switch 代码情况为(示例的代码): // Element-plus 自提供的代码: // 此时是使用了 ts 语言环…

如何理解MySQL隔离性---3个记录隐藏字段、undo日志、Read View

目录 一、3个记录隐藏字段 二、undo 日志 三、read view 一、3个记录隐藏字段 本片文章是帮助理解上篇文章Mysql隔离性的辅助知识。 mysql在建表时,不仅仅创建了表的结构,还创建了3个隐藏字段。 DB_TRX_ID :6 byte,最近修改( 修…

uniapp-原生地图截屏返回base64-进行画板编辑功能

一、场景 vue写uniapp打包安卓包,实现原生地图截屏(andirod同事做的)-画板编辑功能 实现效果: 二、逻辑步骤简略 1. 由 原生地图nvue部分,回调返回 地图截屏生成的base64 数据, 2. 通过 uni插件市场 im…

Go异常处理机制panic和recover

recover 使用panic抛出异常后, 将立即停止当前函数的执行并运行所有被defer的函数,然后将panic抛向上一层,直至程序crash。但是也可以使用被defer的recover函数来捕获异常阻止程序的崩溃,recover只有被defer后才是有意义的。 func main() { p…

如何让ES低成本、高性能?滴滴落地ZSTD压缩算法的实践分享

前文分别介绍了滴滴自研的ES强一致性多活是如何实现的、以及如何提升ES的性能潜力。由于滴滴ES日志场景每天写入量在5PB-10PB量级,写入压力和业务成本压力大,为了提升ES的写入性能,我们让ES支持ZSTD压缩算法,本篇文章详细展开滴滴…

Ceph集群安装部署

Ceph集群安装部署 目录 Ceph集群安装部署 1、环境准备 1.1 环境简介1.2 配置hosts解析(所有节点)1.3 配置时间同步2、安装docker(所有节点)3、配置镜像 3.1 下载ceph镜像(所有节点执行)3.2 搭建制作本地仓库(ceph-01节点执行)3.3 配置私有仓库(所有节点执行)3.4 为 Docker 镜像…

C语言可变数组 嵌套的可变数组,翻过了山跨过了河 又掉进了坑

可变数组 ​专栏内容: postgresql内核源码分析 手写数据库toadb 并发编程 个人主页:我的主页 座右铭:天行健,君子以自强不息;地势坤,君子以厚德载物. 概述 数组中元素是顺序存放,这一特性让我们…

Java获取路径时Class.getResource()和ClassLoader.getResource()区别

Java中取资源时,经常用到Class.getResource()和ClassLoader.getResource(),Class.getResourceAsStream()和ClassLoader().getResourceAsStream(),这里来看看他们在取资源文件时候的路径有什么区别的问题。 环境信息: 系统&#…

css3瀑布流布局遇见截断下一列展示后半截现象

css3 瀑布流布局遇见截断下一列展示后半截现象 注:css3实现瀑布流布局简直不要太香~~~~~ 场景-在uniapp项目中 当瀑布流布局column-grap:10px 相邻两列之间的间隙为10px,column-count:2,2列展示…

基于k8s的devOps自动化运维平台架构设计(中英文版本)

▲ 点击上方"DevOps和k8s全栈技术"关注公众号 In the rapidly evolving landscape of software development and IT operations, DevOps has emerged as a transformative approach to bridge the gap between development and operations teams. One of the key ena…

第五期(2022-2023)传统行业云原生技术落地调研报告——央国企篇

随着国务院国资委印发《关于加快推进国有企业数字化转型工作的通知》,开启了国有企业数字化转型的新篇章。大型央、 国企纷纷顺应趋势,加速云化布局,将数字化转型工作定位为“十四五”时期重点任务。同时,越来越多的企业通过云原生…

【Leetcode】155. 最小栈、JZ31 栈的压入、弹出序列

作者:小卢 专栏:《Leetcode》 喜欢的话:世间因为少年的挺身而出,而更加瑰丽。 ——《人民日报》 155. 最小栈 155. 最小栈 题目描述; 设计一个支持 push ,pop ,top …

C语言笔记7

#include <stdio.h> int main(void) {int a123;int b052;//十进制42int c0xa2;//十进制162printf("a%d b%o c%x \n",a,b,c);//分别是十进制 八进制 十六进制printf("a%d b%d c%d \n",a,b,c);printf("Hello 凌迟老头\n");return …

uniapp 使用canvas画海报(微信小程序)

效果展示&#xff1a; 项目要求&#xff1a;点击分享绘制海报&#xff0c;并实现分享到好友&#xff0c;朋友圈&#xff0c;并保存 先实现绘制海报 <view class"data_item" v-for"(item,index) in dataList" :key"index"click"goDet…

并发——线程池,Executor 框架

文章目录 1 简介2 Executor 框架结构(主要由三大部分组成)1) 任务(Runnable /Callable)2) 任务的执行(Executor)3) 异步计算的结果(Future) 3 Executor 框架的使用示意图 1 简介 Executor 框架是 Java5 之后引进的&#xff0c;在 Java 5 之后&#xff0c;通过 Executor 来启动…

vue+springboot基于web的火车高铁铁路订票管理系统

铁路订票管理系统按照权限的类型进行划分&#xff0c;分为用户和管理员两个模块。管理员模块主要针对整个系统的管理进行设计&#xff0c;提高了管理的效率和标准。主要功能包括个人中心、用户管理、火车类型管理、火车信息管理、车票预订管理、车票退票管理、系统管理等&#…

解决遥感技术在生态、能源、大气等领域的碳排放监测及模拟问题

以全球变暖为主要特征的气候变化已成为全球性环境问题&#xff0c;对全球可持续发展带来严峻挑战。2015年多国在《巴黎协定》上明确提出缔约方应尽快实现碳达峰和碳中和目标。2019年第49届 IPCC全会明确增加了基于卫星遥感的排放清单校验方法。随着碳中和目标以及全球碳盘点的现…

单源最短路

无负环 Dijkstra 迪杰斯特拉算法 采用的贪心的策略 每次遍历到始点距离最近且未访问过的顶点的邻接节点&#xff0c;直到扩展到终点为止 Dijkstra求最短路 I 给定一个 n 个点 m 条边的有向图&#xff0c;图中可能存在重边和自环&#xff0c;所有边权均为正值。 请你求出 1 …