3.1 PTQ与QAT的介绍

news2025/7/19 9:05:24

1. 前言

TensorRT有两种量化模式,分别是implicitly量化(隐式量化)以及explicitly量化(显性量化)。

隐式量化(trt7 版本之前)

  • 只具备 PTQ 一种量化形式
  • 各层精度不可控
  • 显示量化

显性量化(trt8 版本之后)

  • 支持带 QDQ 节点的 PTQ 以及 支持带 QDQ 节点的 QAT 两种量化形式
  • 带 QDQ 节点的 PTQ 是没有进行 Finetune 的,只是插入了对应的 QDQ 节点
  • 带 QDQ 节点的 QAT 是进行了 Finetune 的
  • 显示量化是必须带 QDQ 节点的
  • 各层精度可控

2. Post-Training Quantizationn(训练后的量化)

PTQ(Post-Training Quantization)是隐式量化,因为它在训练之后对模型权重进行量化,而不是在训练过程中进行显式量化。PTQ量化不需要训练,在INT8的量化中,只需要提供一些样本图片,然后在已经训练好的模型上进行校验,统计模型每一层的scale就可以量化,大致流程如下:

  • 准备好的校准数据集
  • 使用校准数据集来校准模型(校准数据可以是训练集的子集)
  • 计算网络模型中权重和激活的动态范围,用于计算得到量化的参数
  • 使用量化参数进行模型量化

目前 tensorRT 提供了多种校准方法,分别适合于不同的任务:

  • EntropyCalibratorV2—适合于基于 CNN 的网络
  • MinMaxCalibrator—适合于 NLP 任务,如BERT
  • EntropyCalibrator
  • LegacyCalibrator

注释:通过上述这些算法量化时,TensorRT会在优化网络的时候尝试INT8精度,假如某一层在INT8精度下速度优于默认精度(FP32或者FP16)则优先使用INT8。这个时候我们无法控制某一层的精度,因为TensorRT是以速度优化为优先的(很有可能某一层你想让它跑int8结果却是fp32)。即使我们使用API去设置也不行,比如set_precision这个函数,因为TensorRT还会做图级别的优化,它如果发现这个op(显式设置了INT8精度)和另一个op可以合并,就会忽略你设置的INT8精度。

3. Quantization Aware Training

3.1 QAT的相关介绍

QAT即训练中量化也叫显式量化.它是 tensorRT8 的一个新特性,这个特性其实是指 tensorRT 有直接加载 QAT 模型的能力。而 QAT 模型在这里是指包含 QDQ 操作的量化模型,而 QDQ 操作就是指量化和反量化操作。

QAT 量化中最重要的就是 FQ(Fake-Quan) 量化算子即 QDQ 算子

QAT 量化需要插入 QAT 算子且需要训练进行微调,大概流程如下:

  • 准备一个预训练模型
  • 在模型中添加QAT算子
  • 微调带有QAT算子的模型
  • 将微调后模型的量化参数,即q-params存储下来
  • 量化模型执行推理

带有 QAT 量化信息的模型如下图所示:

在这里插入图片描述
从上图中我们可以看到带 QAT 量化信息的模型中有 QuantizeLinear 和 DequantizeLinear 模块,也就是对应的 QDQ 模块,它包含了该层和该激活值的量化 scale 和 zero-point。什么是 QDQ 呢?QDQ 其实就是 Q(量化) 和 DQ(反量化)两个 op

QDQ 模块会参与训练,负责将输入的 FP32 张量量化为 INT8,随后再进行反量化将 INT8 的张量再变为 FP32。值得注意的是,实际网络中训练使用的精度还是 FP32,只不过这个量化算子在训练中可以学习到量化和反量化的尺度信息,这样训练的时候就可以让模型权重和量化参数更好地适应量化过程(scale参数也是可以学习的),量化后地精度也相对更高一些。

QDQ 的用途主要体现在两方面:

  1. 可以存储量化信息,比如 scale 和 zero_point,这些信息可以放在 Q 和 DQ 操作中
  2. 可以当作是显示指定哪一层是量化层,我们可以默认认为包在 QDQ 操作中间的 op 都是 INT8 类型的 op,也就是我们需要量化的 op

QDQ的优势

对比显式量化(explicitly),tensorRT的隐式量化(implicitly)就没有那么直接,在 tensorRT-8 版本之前我们一般都是借助 tensorRT 的内部量化算法去量化(闭源),在构建 engine 的时候传入图像进行校准,执行的是训练后量化(PTQ)的过程。

而有了 QDQ 信息,tensorRT 在解析模型的时候会根据 QDQ 的位置找到可量化的 op,然后与 QDQ 融合(吸收尺度信息到 op 中):
在这里插入图片描述
融合后的算子就是实打实的 INT8 算子,经过一系列的融合优化后,最终生成量化版的 engine:
在这里插入图片描述
因为 tensorRT8 可以直接加载通过 QAT 量化后导出为 onnx 的模型,官方也提供了 Pytorch 量化配套工具,可谓是一步到位。

tensorRT 的量化性能是非常好的,可能有些模型或者 op 已经被其他库超越(比如openppl或者tvm),不过tensorRT 胜在支持的比较广泛,用户很多,大部分模型都有前人踩过坑,经验相对较多些,而且支持动态 shape,适用的场景也较多。

不过 tensorRT 也有缺点,就是自定义的 INT8 插件支持度不高,很多坑要踩,也就是自己添加新的 op 难度稍大一些。对于某些层不支持或者有 bug 的情况,除了在 issue 中催一下官方尽快更新之外,也没有其它办法了。

3.2 QAT实战

3.2.1 环境配置

本次代码参考自https://github.com/NVIDIA/TensorRT/tree/release/8.6/tools/pytorch-quantization

需要安装 pytorch-quantization 包来用于后续的工作,安装指令如下:

pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com

3.2.2 整个模型自动插入QDQ节点

我们使用 pytorch-quantization 的 API 来实现对 resnet 网络所有节点的 QDQ 算子插入,其示例代码如下:

import torch
import torchvision
from pytorch_quantization import tensor_quant, quant_modules
from pytorch_quantization import nn as quant_nn

quant_modules.initialize()

model = torchvision.models.resnet18()
model.cuda()

inputs = torch.randn(1, 3, 224, 224, device='cuda')
quant_nn.TensorQuantizer.use_fb_fake_quant = True
torch.onnx.export(model, inputs, 'quant_resnet18.onnx', opset_version=13)

上述示例代码通过指定 quant_nn.TensorQuantizer.use_fb_fake_quant 来将 resnet18 模型中的所有节点替换为 QDQ 算子,并导出为 ONNX 格式的模型文件,实现了模型的量化。值得注意的是:

  • quant_modules.initialize() 函数会把 PyTorch-Quantization 库中所有的量化算子按照数据类型、位宽等特性进行分类,并将其保存在全局变量 _DEFAULT_QUANT_MAP 中
  • 导出的带有 QDQ 节点的 ONNX 模型中,对于输入 input 的整个 tensor 是共用一个 scale,而对于权重 weight 则是每个 channel 共用一个 scale
  • 导出的带有 QDQ 节点的 ONNX 模型中,x_zero_point 是之前量化课程中提到的偏移量,其值为0,因为整个量化过程是对称量化,其偏移量 Z 为0

3.2.3 模型中手动控制某个QDQ节点的插入

如果某些层量化后对整体精度影响大,我们不希望该层插入 QDQ 节点,而是正常用 FP16 去跑,我们应该如何去做控制呢?

在上节课中使用 quant_modules.initialize() 自动插入 QDQ 节点,如何使能某些层插入 QDQ 节点,某些层不插入 QDQ 节点呢?在代码层面我们通过 disable_quantization 以及 enable_quantization 两个类来进行控制。示例代码如下:

import torch
import torchvision
from pytorch_quantization import tensor_quant
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.nn.modules import _utils as quant_nn_utils
from pytorch_quantization import calib
from typing import List, Callable, Union, Dict

class disable_quantization:
    def __init__(self, model):
        self.model = model

    def apply(self, disabled=True):
        for name, module in self.model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module._disabled = disabled

    def __enter__(self):
        self.apply(True)    
    
    def __exit__(self, *args, **kwargs):
        self.apply(False)


class enable_quantization:
    def __init__(self, model):
        self.model = model
    
    def apply(self, enabled=True):
        for name, module in self.model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module._disabled = not enabled
    
    def __enter__(self):
        self.apply(True)
        return self
    
    def __exit__(self, *args, **kwargs):
        self.apply(False)

def quantizer_state(module):
    for name, module in module.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            print(name, module)

quant_modules.initialize()  # 对整个模型进行量化
model = torchvision.models.resnet50()
model.cuda()

disable_quantization(model.conv1).apply() # 关闭某个节点的量化
# enable_quantization(model.conv1).apply() # 开启某个节点的量化
inputs = torch.randn(1, 3, 224, 224, device='cuda')
quant_nn.TensorQuantizer.use_fb_fake_quant = True
torch.onnx.export(model, inputs, 'quant_resnet50_disabelconv1.onnx', opset_version=13)

上述示例代码演示了如何在 Pytorch 中开启或禁用量化器(Quantizer)对指定节点的量化过程。

我们对模型的 conv1 模块禁用量化器对该模块的量化,在导出的 ONNX 模型中可以看到该节点没有被插入 QDQ 节点量化

以 Conv 层为例,量化前后模型的属性会发生变化。量化后的 Conv 层会在原有属性的基础上新增两个属性即 input_quantizer 以及 weight_quantizer,用于记录输入和权重的量化信息。而量化后的 Conv 层属于 quant_nn.TensorQuantizer 类型,表示这个层被量化了。而在 TensorQuantizer 类中是通过 _disabled 这个属性来控制是否进行量化的,因此我们就可以利用这个属性来控制某些层是否插入 QDQ 节点。

3.2.4 自定义QDQ节点

对通用的网络层如conv、bn和relu等手动插入QDQ节点,tensorRT提供了常用的QDQ节点的实现。而对于自己设计的网络层,需要我们自己对自定义层进行手动插入QDQ节点。其中,自定义层分为两种,一种是只有input,一种是包含input和weight。

下面是只包含 input 自定义层 MultiAdd 量化的示例代码:

import torch
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

class QuantMultiAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._input_quantizer = quant_nn.TensorQuantizer(QuantDescriptor(num_bits=8, calib_method="histgoram"))
    
    def forward(self, x, y, z):
        return self._input_quantizer(x) + self._input_quantizer(y) + self._input_quantizer(z)

model = QuantMultiAdd()
model.cuda()
input_a = torch.randn(1, 3, 224, 224, device='cuda')
input_b = torch.randn(1, 3, 224, 224, device='cuda')
input_c = torch.randn(1, 3, 224, 224, device='cuda')
quant_nn.TensorQuantizer.use_fb_fake_quant = True
torch.onnx.export(model, (input_a, input_b, input_c), 'quantMultiAdd.onnx', opset_version=13)

在上述示例代码中,首先定义了 QuantMultiAdd 自定义层,它包含一个输入量化器 _input_quantizer 基于 pytorch_quantization 库中的 TensorQuantizer 类来创建的,使用 8 位量化位数,并采用直方图作为校准方法进行模型量化,然后在前向传播过程中,将三个输入都通过输入量化器进行量化操作,并返回它们的量化结果之和。
在这里插入图片描述

3.2.5 手动实现quant_modules.initialize()

对整个模型插入 QDQ 节点我们是通过 quant_modules.initialize() 来实现的,我们能否自定义实现整个模型的 QDQ 节点插入呢?而不用上述方法,官方提供的接口可控性、灵活度较差,我们自己来实现整个过程。示例代码如下:

import torch
import torchvision
from pytorch_quantization import tensor_quant
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.nn.modules import _utils as quant_nn_utils
from pytorch_quantization import calib
from typing import List, Callable, Union, Dict


def transfer_torch_to_quantization(nninstace : torch.nn.Module, quantmodule):
    
    quant_instance = quantmodule.__new__(quantmodule)
    for k, val in vars(nninstace).items():
        setattr(quant_instance, k, val) # 继承所有的属性
    
    def __init__(self):

        if isinstance(self, quant_nn_utils.QuantInputMixin): # 只有input,没有weight
            quant_desc_input = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__, input_only=True)
            self.init_quantizer(quant_desc_input)

            # Turn on torch hist to enable higher calibration speeds
            if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator):
                self._input_quantizer._calibrator._torch_hist = True
        else:
            quant_desc_input, quant_desc_weight = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__)
            self.init_quantizer(quant_desc_input, quant_desc_weight)

            # Turn on torch_hist to enable higher calibration speeds
            if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator):
                self._input_quantizer._calibrator._torch_hist  = True  # 提速
                self._weight_quantizer._calibrator._torch_hist = True  #
    
    __init__(quant_instance)
    return quant_instance
            

def replace_to_quantization_module(model : torch.nn.Module, ignore_policy : Union[str, List[str], Callable] = None):
    
    module_dict = {}
    for entry in quant_modules._DEFAULT_QUANT_MAP:
        module = getattr(entry.orig_mod, entry.mod_name)
        module_dict[id(module)] = entry.replace_mod
    
    def recursive_and_replace_module(module, prefix=""):
        for name in module._modules:
            submodule = module._modules[name]
            path      = name if prefix == "" else prefix + "." + name
            recursive_and_replace_module(submodule, path)

            submodule_id = id(type(submodule))
            if submodule_id in module_dict:
                module._modules[name] = transfer_torch_to_quantization(submodule, module_dict[submodule_id])
        
    recursive_and_replace_module(model)
            

# quant_modules.initialize() # 如何实现自定义QDQ节点插入?
model = torchvision.models.resnet50()
model.cuda()

replace_to_quantization_module(model)
inputs = torch.randn(1, 3, 224, 224, device='cuda')
quant_nn.TensorQuantizer.use_fb_fake_quant = True
torch.onnx.export(model, inputs, 'quant_resnet50_replace_to_quantization.onnx', opset_version=13)

上述示例代码实现了自定义整个模型的 QDQ 节点插入。主要包括两个函数即 transfer_torch_to_quantization 和 replace_to_quantization_module。

其中,replace_to_quantization_module 函数的作用是将原始模型中的指定层替换成对应的量化层,并返回替换后的模型。具体来说,该函数遍历整个模型的层,如果当前层是被替换层,则调用 transfer_torch_to_quantization 函数将其转换为量化层。

transfer_torch_to_quantization 函数的作用是将原始模型的一个层转换成对应的量化层。该函数首先创建一个新的量化层实例 quant_instance,然后将原始层的所有属性复制到这个实例中。接着根据不同的 OP 算子类型来进行初始化,具体根据原始层是否有 weight,来初始化 quant_instance 的 input_quantizer 和 weight_quantizer 两个属性。最后,将 quant_instance 返回。

这两个函数的组合实现了一个自定义的 QDQ 节点插入函数,它不依赖于 quant_modules.initialize() 接口,而是通过遍历模型层并替换成对应的量化层来实现。如果你想只让某些层进行量化,则可以加入一些过滤条件,通过这样的方式灵活控制,实现手动插入 QDQ 节点。

上述代码中将 self._weight_quantizer._calibrator._torch_hist 设置为 True 是为了提高权重量化时的校准速度。当使用直方图来确定数据分布时,由于直方图的计算量较大,所以开启 _torch_hist 可以使用 PyTorch 内置的直方图函数来提高校准速度。因此,当使用 HistogramCalibrator 进行校准时,将 _torch_hist 设置为 True 可以提高校准速度。

参考链接:
https://blog.csdn.net/qq_40672115/article/details/130489067

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/619158.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从零开始,申请开通微信小程序全流程

本系列文章适合三类同学:1.希望学习小程序开发;2.希望无代码、低代码拥有自己的小程序;3.快速搭建小程序交作业、交毕设的大学生 本系列文章将推出配套桌面端软件,配合软件,可实现傻瓜式开发小程序,请有需求…

计算机体系结构-期末复习

计算机体系结构-期末复习 第一章 量化设计与分析基础 | 1.2.6 并行度与并行体系结构的分类 应用程序中主要有两种并行: 数据级并行:同时操作许多数据项实现的并行任务级并行:创建能够单独处理并大量采用并行方式执行的工作任务 所有计算…

如何理解网络—网络框架介绍

目录 前言 一.计算机网络背景 二.局域网和广域网 三.网络协议 3.1产生的背景 3.2分层实现 四.OSI七层模型 4.1OSI七层模型的结构 4.2如何理解OSI七层模型 五.TCP/IP五层(或四层)模型 六.网络传输基本流程 7.网络中的地址管理 7.1IP地址 7.2MAC地址 7.3MAC地址和IP地址的区别和联…

2021-2023浙江省内八大MBA项目招生情况:注意大小年啊~

现如今国内的MBA教育呈现出一片繁华景象,过去的这些年来每年几乎都有新增加的MBA招生院校,浙江省内目前共有九大MBA招生院校,除了浙大独领风骚之外,其余八个MBA项目也都有自己的一席之地。纵观2023年的招生录取,小立老…

ValueError: Object arrays cannot be loaded when allow_pickle=False

一、问题 使用numpy读取数据时出现错误,ValueError: Object arrays cannot be loaded when allow_pickleFalse。 查了一下numpy.load()函数 用法 numpy.load(file, mmap_modeNone, allow_pickLeFalse, fix_mportsTrue, encoding‘ASCII’) 参数 file:…

【Android】AMS(三)APP启动流程

启动方式 在 Android 系统中,启动一个应用程序可以分为三种启动方式:热启动、冷启动和温启动。它们分别表示了不同的启动方式和启动过程。 热启动 热启动是指在已经打开并处于后台运行的应用程序中,再次通过图标进入应用程序的启动方式。这…

Spring Security OAuth停更了?探索官方进化版Spring Authorization Server的革新之处!

1、背景 Spring Security OAuth(spring-security-oauth2)停更 主要意思是:生命周期终止通知 Spring Security OAuth(spring-security-oauth2)项目已达到生命周期结束,不再由VMware,Inc.积极维护。 此项目已被Spring Security和Spring Author…

信创办公–基于WPS的EXCEL最佳实践系列 (设置多级列表)

信创办公–基于WPS的EXCEL最佳实践系列 (设置多级列表) 目录 应用背景操作步骤1、删除重复项2、部门绑定3、填入相关信息 应用背景 当我们在使用电子表格时,很多类型重复输入很麻烦,看起来也很复杂,我们就可以设置多级…

关于输入输出格式符的测试

对输出%m.nf的测试 m代表宽度,表示数据可以占m列n代表精确,表示小数占n列 以下用%6.3f进行测试,有两个问题: 1、这个m是包括小数点位数吗?(todo未果) 2、精确度n超过了是怎么处理的&#xff1f…

2023年第六届广西大学生程序设计竞赛(正式赛)题解

比赛题目链接,可以继续提交代码: 2023年第六届广西大学生程序设计竞赛(正式赛) | 知乎:如何评价第六届广西大学生程序设计竞赛? 难度题号备注签到题A J K已给出题解和代码普通题B D E H已给出题解和代码中等题C G–难题F I L M–…

机器学习方法在生态经济学领域中的应用

查看原文>>>基于R语言机器学习方法在生态经济学领域中的实践技术 近年来,人工智能领域已经取得突破性进展,对经济社会各个领域都产生了重大影响,结合了统计学、数据科学和计算机科学的机器学习是人工智能的主流方向之一&#xff0c…

moviepy快速切分视频并保存片段

文章目录 1、直接使用ffmepg2、使用moviepy本身 moviepy安装最新版本: pip install moviepy --pre --upgrade版本是v2.0.0.dev2。 有两种方法一种快速的: 1、直接使用ffmepg from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip ffmpeg…

Lecture 12 Discourse

目录 Discourse 语篇三个关键的语篇任务Discourse Segmentation 语篇分段Unsupervised Approaches 无监督方法Supervised Approaches 有监督方法有监督语篇分段器Discourse Analysis 语篇解析语篇解析RST: Discourse UnitsRST: Discourse RelationsNucleus vs. Satellite 核心 …

2022计算机系统期末

直接导入无图片,先不改了,不过图片都是原题目,不影响对答案。 计算机系统2022期末 本课程的复习请以知识点复习为重,全部内容共有大小280个知识点,都可能在期末考试出现,仅通过往年试卷复习是远远不够的&…

从小白到大神之路之学习运维第35天---第三阶段---mysql数据库之主从复制配置

第三阶段基础 时 间:2023年6月7日 参加人:全班人员 内 容: Mysql数据库之主从复制配置 目录 前提环境配置 MySQL 5.7 版本的主从复制配置步骤 主 库 1. 在主库上开启二进制日志记录功能 2. 在主库上创建一个用于从库访问的备份用户 3. 在主…

VTK学习之光照和相机

目录 一、VTK光照 1、关于vtkLight常用的方法 2、最终效果 二、相机设置 1、相机设置 2、效果 一、VTK光照 通过设置光照,可以达到不同颜色的目的,参考博客: VTK修炼之道7_三维场景基本要素:光照_vtk 光照_沈子恒的博客-CSDN博客 1…

吴恩达 ChatGPT Prompt Engineering for Developers 系列课程笔记--02 Guidelines

02 Guidelines 本节将配合代码,介绍一些构建Prompt的基本原则和策略。 1) OpenAI API 首先开发者需要在OpenAI网站(https://platform.openai.com/account/api-keys)注册一个key,然后通过pip install openai安装openai三方库,再将key导入当…

Express路由

一、目标 能够使用express.static()快速托管静态资源能够使用express路由精简项目结构能够使用常见的express中间件能够使用express创建API接口能够在express中启用cors跨域资源共享 二、目录 初始ExpressExpress路由Express中间件使用Express写接口 1.1Express简介 1.什么…

微信小程序原生开发功能合集二十:导航栏、tabbar自定义及分包功能介绍

本章实现导航栏及tabbar的自定义处理的相关方法介绍及效果展示。   另外还提供小程序开发基础知识讲解课程,包括小程序开发基础知识、组件封装、常用接口组件使用及常用功能实现等内容,具体如下:    1. CSDN课程: https://edu.csdn.net/course/detail/37977    2. 5…

uCOSii系统的中断管理

uCOSii系统的中断管理 1、在使用uCOSii系统时,中断服务程序需要调用两个函数OSIntEnter()和OSIntExit()。 OSIntEnter() 进入中断时,用OSIntNesting来统计中断嵌套次数,告知uCOSii系统,当前中断服务程序正在执行; OS…