ONNX模型的动态和静态量化

news2025/6/2 9:15:35

引言
 通常我们将模型转换为onnx格式之后,模型的体积可能比较大,这样在某些场景下就无法适用。最近想在移动端部署语音识别、合成模型,但是目前的效果较好的模型动辄几个G,于是便想着将模型压缩一下。本文探索了两种压缩方法,适用的场景也不相同。

对比
 我觉得额两者之间最大的区别有两点,支持的模型类型、是否需要校准数据。

 结合我的实践谈一下这两点,①在做语音合成模型量化的时候,模型中大量的卷积操作,使用netron打开后的结果如下图所示。因为一开始直接就用动态量化,但是搞了一圈发现各种报错。原因就是动态量化对卷积操作的支持力度不够,导致各种问题频出。于是兜兜转转又开始用静态量化的方法,最后才成功实现量化。 ②但是我当时量化的时候使用的校准数据是函数生成的,不是真是的输入数据,导致虽然模型量化成功了,但是效果却大打折扣。
在这里插入图片描述

 下面给出两种方法的对比:

特性动态量化静态量化
量化时机推理时动态量化(运行时对权重/激活量化)推理前离线量化(事先量化权重/激活)
是否需要校准数据不需要需要标定数据进行校准
支持的模型类型RNN、Transformer 等支持良好CNN、稳定数据分布模型
部署复杂度简单,一键量化复杂,需要数据流过模型进行校准
加速效果中等(通常1.2–2x)最优(通常2–4x)
加速效果小(但一般略大于静态)最小(通过校准降低损失)

动态量化
 直接贴代码进行分析

import onnx
from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic

name = 'model-steps-3'
# 加载原始 ONNX 模型
model_path = f"matcha-icefall-zh-baker/{name}.onnx"
quantized_model_path = f"matcha-icefall-zh-baker/{name}_quant.onnx"

# 动态量化模型
quantize_dynamic(
    model_input=model_path,  # 输入的原始模型路径
    model_output=quantized_model_path,  # 输出的量化模型路径
    weight_type=QuantType.QInt8,  # 权重量化类型,使用 INT8
    nodes_to_exclude=['/encoder/prenet/conv_layers.0/Conv', '/encoder/prenet/conv_layers.1/Conv_quant', 'Transpose'] 
    op_types_to_quantize=  [
    "MatMul", "Mul", "Add", 
    "Constant", "Shape", "Unsqueeze", "Reshape", 
    "Relu", "Sigmoid", "Softmax", "Tanh", 
    "InstanceNormalization", "Softplus", "Slice", 
    "Where", "RandomNormalLike", "Pad"
]         # ['MatMul', 'Attention', 'LSTM', 'Gather', 'EmbedLayerNormalization', 'Conv']
)
# 如果不想量化某个算子,可以直接通过nodes_to_exclude忽略掉
# 也可以通过op_types_to_quantize选择要对哪些算子进行量化

print(f"Quantized model saved to {quantized_model_path}")

静态量化
 这个代码贴的比较长,其实主要就是针对量化前后的模型进行了分析。

import onnx
import numpy as np
import onnxruntime as ort
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, QuantFormat
import os
import time

class MelCalibrationDataReader(CalibrationDataReader):
    """用于校准的数据读取器,提供MEL频谱图数据"""
    def __init__(self, batch_size=8, num_batches=10, input_shape=(80, 100)):
        """
        初始化校准数据读取器
        Args:
            batch_size: 每批数据的大小
            num_batches: 批次数量
            input_shape: MEL频谱图的形状 (特征维度, 时间步长)
        """
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.input_shape = input_shape
        self.data_counter = 0
        
        # 创建随机数据,在实际应用中应替换为真实数据集
        self.data = []
        for _ in range(num_batches):
            # 随机生成FLOAT32类型的MEL频谱图数据
            mel_data = np.random.randn(batch_size, *input_shape).astype(np.float32) * 1.5
            self.data.append(mel_data)

    def get_next(self):
        """返回下一批校准数据"""
        if self.data_counter >= self.num_batches:
            return None
        
        mel_batch = self.data[self.data_counter]
        input_feed = {'mel': mel_batch}  # 'mel'是输入节点的名称,需要与模型匹配
        self.data_counter += 1
        
        return input_feed

    def rewind(self):
        """重置数据计数器"""
        self.data_counter = 0

def load_real_calibration_data(data_path, batch_size=8, num_batches=10):
    """
    加载真实的校准数据(如果有)
    Args:
        data_path: 数据目录路径
        batch_size: 每批数据的大小
        num_batches: 批次数量
    Returns:
        CalibrationDataReader实例
    """
    # 这里实现从文件加载真实MEL数据的逻辑
    # 如果有特定格式的数据,应该在这里进行加载和预处理
    # 为简化示例,这里仍使用随机数据
    return MelCalibrationDataReader(batch_size, num_batches)

def quantize_onnx_model(model_path, quantized_model_path, calibration_data_reader=None):
    """
    对ONNX模型进行静态量化
    Args:
        model_path: 原始ONNX模型路径
        quantized_model_path: 量化后模型保存路径
        calibration_data_reader: 校准数据读取器
    """
    # 加载原始模型
    model = onnx.load(model_path)
    
    # 输出原始模型信息
    print(f"原始模型输入: {[i.name for i in model.graph.input]}")
    print(f"原始模型输出: {[o.name for o in model.graph.output]}")
    
    # 设置量化参数
    # QuantType.QInt8 - 8位整数量化
    # QuantType.QUInt8 - 无符号8位整数量化
    quant_type = QuantType.QInt8
    
    # 量化格式选择
    # QuantFormat.QDQ - 使用QuantizeLinear/DequantizeLinear节点对
    # QuantFormat.QOperator - 使用量化算子

    quant_format = QuantFormat.QDQ
    
    # 指定需要量化的算子类型
    op_types_to_quantize = ['Conv' , 'ConvTranspose']  # 
    
    # 指定不需要量化的节点名称(如果有的话)
    # nodes_to_exclude = ['某些不需要量化的节点名称']
    
    # 执行静态量化
    print(f"开始对模型 {model_path} 进行静态量化...")
    start_time = time.time()
    
    quantize_static(
        model_input=model_path,
        model_output=quantized_model_path,
        calibration_data_reader=calibration_data_reader,
        quant_format=quant_format,
        # op_types_to_quantize=op_types_to_quantize,
        per_channel=True,  # 使用每通道量化可以提高精度
        weight_type=quant_type,
        activation_type=quant_type,
        # optimize_model=True  # 在量化前优化模型
    )
    
    quantization_time = time.time() - start_time
    print(f"量化完成,耗时: {quantization_time:.2f} 秒")
    
    return quantized_model_path

def compare_models(original_model_path, quantized_model_path):
    """
    比较原始模型与量化模型的大小和性能
    Args:
        original_model_path: 原始模型路径
        quantized_model_path: 量化模型路径
    """
    # 比较文件大小
    original_size = os.path.getsize(original_model_path) / (1024 * 1024)  # MB
    quantized_size = os.path.getsize(quantized_model_path) / (1024 * 1024)  # MB
    
    print(f"原始模型大小: {original_size:.2f} MB")
    print(f"量化模型大小: {quantized_size:.2f} MB")
    print(f"压缩比: {original_size / quantized_size:.2f}x")
    
    # 创建测试数据
    test_input = np.random.randn(1, 80, 100).astype(np.float32)
    
    # 测试原始模型推理性能
    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    
    # 原始模型推理
    print("测试原始模型推理性能...")
    original_session = ort.InferenceSession(original_model_path, session_options)
    input_name = original_session.get_inputs()[0].name
    
    start_time = time.time()
    num_runs = 100
    for _ in range(num_runs):
        original_output = original_session.run(None, {input_name: test_input})
    original_time = (time.time() - start_time) / num_runs
    
    # 量化模型推理
    print("测试量化模型推理性能...")
    quantized_session = ort.InferenceSession(quantized_model_path, session_options)
    input_name = quantized_session.get_inputs()[0].name
    
    start_time = time.time()
    for _ in range(num_runs):
        quantized_output = quantized_session.run(None, {input_name: test_input})
    quantized_time = (time.time() - start_time) / num_runs
    
    print(f"原始模型平均推理时间: {original_time*1000:.2f} ms")
    print(f"量化模型平均推理时间: {quantized_time*1000:.2f} ms")
    print(f"速度提升: {original_time/quantized_time:.2f}x")

def evaluate_model_accuracy(original_model_path, quantized_model_path, test_data=None):
    """
    评估量化模型的精度损失
    Args:
        original_model_path: 原始模型路径 
        quantized_model_path: 量化模型路径
        test_data: 测试数据,如果为None则生成随机数据
    """
    # 创建测试数据
    if test_data is None:
        num_samples = 10
        test_data = np.random.randn(num_samples, 80, 100).astype(np.float32)
    
    # 创建会话
    original_session = ort.InferenceSession(original_model_path)
    quantized_session = ort.InferenceSession(quantized_model_path)
    
    # 获取输入输出名称
    input_name = original_session.get_inputs()[0].name
    
    # 计算均方误差和平均相对误差
    total_mse = 0
    total_rel_err = 0
    
    for i, sample in enumerate(test_data):
        # 添加批次维度
        sample = np.expand_dims(sample, axis=0)
        
        # 运行推理
        original_output = original_session.run(None, {input_name: sample})[0]
        quantized_output = quantized_session.run(None, {input_name: sample})[0]
        
        # 计算均方误差
        mse = np.mean((original_output - quantized_output) ** 2)
        total_mse += mse
        
        # 计算相对误差
        # 避免除以零
        epsilon = 1e-10
        rel_err = np.mean(np.abs((original_output - quantized_output) / (np.abs(original_output) + epsilon)))
        total_rel_err += rel_err
        
        if i < 3:  # 只打印前几个样本的结果
            print(f"样本 {i+1} - MSE: {mse:.6f}, 相对误差: {rel_err:.6f}")
    
    avg_mse = total_mse / len(test_data)
    avg_rel_err = total_rel_err / len(test_data)
    
    print(f"平均均方误差 (MSE): {avg_mse:.6f}")
    print(f"平均相对误差: {avg_rel_err:.6f}")
    
    return avg_mse, avg_rel_err

def main():
    # 模型路径
    original_model_path = "./my_model.onnx"  # 替换为实际模型路径
    quantized_model_path = "./my_model_quant.onnx"  # 量化后模型的保存路径
    
    # 创建校准数据读取器
    # 在实际应用中,应使用真实数据替代随机数据
    calibration_data_reader = MelCalibrationDataReader(
        batch_size=8,
        num_batches=10,
        input_shape=(80, 100)  # 根据实际输入形状调整
    )
    
    # 执行静态量化
    quantized_model_path = quantize_onnx_model(
        original_model_path,
        quantized_model_path,
        calibration_data_reader
    )
    
    # 比较模型大小和性能
    compare_models(original_model_path, quantized_model_path)
    
    # 评估量化精度
    evaluate_model_accuracy(original_model_path, quantized_model_path)
    
    print("ONNX模型静态量化完成!")

if __name__ == "__main__":
    main()

总结
 在对onnx进行量化时,我们要根据自己的模型的结构类型和是否能得到真实的校准数据来选择量化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2393502.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何用Python抓取Google Scholar

文章目录 [TOC](文章目录) 前言一、为什么要抓取Google Scholar&#xff1f;二、Google Scholar 抓取需要什么三、为什么代理对于稳定的抓取是必要的四、一步一步谷歌学者抓取教程4.1. 分页和循环4.2. 运行脚本 五、完整的Google Scholar抓取代码六、抓取Google Scholar的高级提…

Wireshark对usb设备进行抓包找不到USBPcap接口的解决方案

引言 近日工作需要针对usb设备进行抓包&#xff0c;但按照wireshark安装程序流程一步步走&#xff0c;即使勾选了安装USBPcap安装完成后开启wireshark依然不显示USBPcap接口&#xff0c;随设法进行解决。 最终能够正常显示USBPcap接口并能够正常使用进行抓包 解决方案&#x…

Socket 编程 UDP

目录 1. UDP网络编程 1.1 echo server 1.1.1 接口 1.1.1.1 创建套接字 1.1.1.2 绑定 1.1.1.3 bzero 1.1.1.4 htons&#xff08;主机序列转网络序列&#xff09; 1.1.1.5 inet_addr&#xff08;主机序列IP转网络序列IP&#xff09; 1.1.1.6 recvfrom&#xff08;让服务…

Jenkins实践(8):服务器A通过SSH调用服务器B执行Python自动化脚本

Jenkins实践(8):服务器A通过SSH调用服务器B执行Python自动化脚本 1、需求: 1、Jenkins服务器在74上,Python脚本在196服务器上 2、需要在服务器74的Jenkins上调用196上的脚本执行Python自动化测试 2、操作步骤 第一步:Linux Centos7配置SSH免密登录 Linux Centos7配置S…

lua的注意事项2

总之&#xff0c;下面的返回值不是10&#xff0c;a&#xff0c;b 而且

前端八股之HTML

前端秘籍-HTML篇 1. src和href的区别 src 用于替换当前元素&#xff0c;href 用于在当前文档和引用资源之间确立联系。 &#xff08;1&#xff09;src src 是 source 的缩写&#xff0c;指向外部资源的位置&#xff0c;指向的内容将会嵌入到文档中当前标签所在位置&#xff1…

鲲鹏Arm+麒麟V10,国产化信创 K8s 离线部署保姆级教程

Rainbond V6 国产化部署教程&#xff0c;针对鲲鹏 CPU 麒麟 V10 的离线环境&#xff0c;手把手教你从环境准备到应用上线&#xff0c;所有依赖包提前打包好&#xff0c;步骤写成傻瓜式操作指南。别说技术团队了&#xff0c;照着文档一步步来&#xff0c;让你领导来都能独立完成…

【C++ Qt】认识Qt、Qt 项目搭建流程(图文并茂、通俗易懂)

每日激励&#xff1a;“不设限和自我肯定的心态&#xff1a;I can do all things。 — Stephen Curry” 绪论​&#xff1a; 本章将开启Qt的学习&#xff0c;Qt是一个较为古老但仍然在GUI图形化界面设计中有着举足轻重的地位&#xff0c;因为它适合嵌入式和多种平台而被广泛使用…

IoT/HCIP实验-1/物联网开发平台实验Part2(HCIP-IoT实验手册版)

文章目录 概述产品和设备实例的产品和设备产品和设备的关联单个产品有多个设备为产品创建多个设备产品模型和物模型设备影子&#xff08;远程代理&#xff09; 新建产品模型定义编解码插件开发编解码插件工作原理消息类型与二进制码流添加消息&#xff08;数据上报消息&#xf…

Replacing iptables with eBPF in Kubernetes with Cilium

source: https://archive.fosdem.org/2020/schedule/event/replacing_iptables_with_ebpf/attachments/slides/3622/export/events/attachments/replacing_iptables_with_ebpf/slides/3622/Cilium_FOSDEM_2020.pdf 使用Cilium&#xff0c;结合eBPF、Envoy、Istio和Hubble等技术…

数学建模之最短路径问题

1 问题的提出 这个是我们的所要写的题目&#xff0c;我们要用LINGO编程进行编写这个题目&#xff0c;那么就是需要进行思考这个怎么进行构建这个问题的模型 首先起点&#xff0c;中间点&#xff0c;终点我们要对这个进行设计 2 三个点的设计 起点的设计 起点就是我们进去&am…

测试概念 和 bug

一 敏捷模型 在面对在开发项目时会遇到客户变更需求以及合并新的需求带来的高成本和时间 出现的敏捷模型 敏捷宣言 个人与交互重于过程与工具 强调有效的沟通 可用的软件重于完备的文档 强调轻文档重产出 客户协作重于合同谈判 主动及时了解当下的要求 相应变化…

zynq 级联多个ssd方案设计(ECAM BUG修改)

本文讲解采用zynq7045芯片如何实现200T容量高速存储方案设计&#xff0c;对于大容量高速存储卡&#xff0c;首先会想到采用pcie switch级联方式&#xff0c;因为单张ssd的容量是有限制的&#xff08;目前常见的m.2接口容量为4TB&#xff0c;U.2接口容量为16TB&#xff09;&…

brep2seq 论文笔记

Brep2Seq: a dataset and hierarchical deep learning network for reconstruction and generation of computer-aided design models | Journal of Computational Design and Engineering | Oxford Academic 这段文本描述了一个多头自注意力机制&#xff08;MultiHead Attenti…

【运维实战】Linux 中设置 sudo ,8个有用的 sudoers 配置!

在Linux及其他类Unix操作系统中&#xff0c;只有 root 用户能够执行所有命令并进行关键系统操作&#xff0c;例如安装更新软件包、删除程序、创建用户与用户组、修改重要系统配置文件等。 但担任 root 角色的系统管理员可通过配置sudo命令&#xff0c;允许普通系统用户执行特定…

江科大SPI串行外设接口hal库实现

hal库相关函数 初始化结构体 typedef struct {uint32_t Mode; /*SPI模式*/uint32_t Direction; /*SPI方向*/uint32_t DataSize; /*数据大小*/uint32_t CLKPolarity; /*时钟默认极性控制CPOL*/uint32_t CLKPhase; /*…

[网页五子棋][对战模块]前后端交互接口(建立连接、连接响应、落子请求/响应),客户端开发(实现棋盘/棋子绘制)

文章目录 约定前后端交互接口建立连接建立连接响应针对"落子"的请求和响应 客户端开发实现棋盘/棋子绘制部分逻辑解释 约定前后端交互接口 对战模块和匹配模块使用的是两套逻辑&#xff0c;使用不同的 websocket 的路径进行处理&#xff0c;做到更好的耦合 建立连接 …

【ArcGIS Pro微课1000例】0071:将无人机照片生成航线、轨迹点、坐标高程、方位角

文章目录 一、照片预览二、生成轨迹点三、照片信息四、查看方位角五、轨迹点连成线一、照片预览 数据位于配套实验数据包中的0071.rar,解压之后如下: 二、生成轨迹点 地理标记照片转点 (数据管理),用于根据存储在地理标记照片文件(.jpg 或 .tif)元数据中的 x、y 和 z 坐…

Ubuntu Zabbix 钉钉报警

文章目录 概要Zabbix警监控脚本技术细节配置zabbix告警 概要 提示&#xff1a;本教程用于Ubuntu &#xff0c;zabbix7.0 Zabbix警监控脚本 提示&#xff1a;需要创建一个脚本 #检查是否有 python3 和版本 rootzabbix:~# python3 --version Python 3.12.3在/usr/lib/zabbix/…

threejs顶点UV坐标、纹理贴图

1. 创建纹理贴图 通过纹理贴图加载器TextureLoader的load()方法加载一张图片可以返回一个纹理对象Texture&#xff0c;纹理对象Texture可以作为模型材质颜色贴图.map属性的值。 const geometry new THREE.PlaneGeometry(200, 100); //纹理贴图加载器TextureLoader const te…