模型量化笔记--KL散度量化

news2025/7/14 3:19:28

KL散度量化

前面介绍的非对称量化中,是将数据中的min值和max值直接映射到[-128, 127]。
同样的,前面介绍的对称量化是将数据的最大绝对值 ∣ m a x ∣ |max| max直接映射到127。
上面两种直接映射的方法比较粗暴,而TensorRT中的int8量化是基于KL散度来选取最佳的阈值T来映射到127中。超出阈值 ± ∣ T ∣ \pm|T| ±T的数据会直接映射为阈值(类似于截断映射)。

KL散度定义

KL散度常用来衡量两个分布P和Q之间的差异,KL散度越小,两个分布越相似,其公式定义如下:
D K L = ∑ i P ( x i ) l o g ( P ( x i ) Q ( x i ) ) D_{KL} = \sum_{i}P(x_i)log(\frac{P(x_i)}{Q(x_i)}) DKL=iP(xi)log(Q(xi)P(xi))

TensorRT实现KL散度量化的步骤

  1. 基于原始输入数据生成拥有2048个bin的直方图
hist, bin_edges = np.histogram(P, bins = 2048)
  1. 在[128, 2048]返回内循环执行3-5步,寻找最佳的划分 b i n i bin_{i} bini
  2. [ 0 , b i n i ] [0,bin{i}] [0,bini]范围内的直方图数据作为原始P, 并将 b i n i bin_{i} bini之后的直方图数据进行求和,并累加到 b i n i − 1 bin_{i-1} bini1中形成以 b i n i bin_{i} bini作为划分的最终P分布
  3. 对P分布进行量化形成Q分布(一般是划分和合并bins,计算合并后的平均值作为Q分布对应bins的值)
  1. 计算P分布和Q分布的KL散度
  2. 根据最小的KL散度来选取最佳的 b i n b e s t bin_{best} binbest,将bin_edges[ b i n b e s t bin_{best} binbest]作为最终的阈值threshold,即映射到127的阈值T
  3. 根据最佳的阈值T来计算scale
    s c a l e = T i n t m a x = T 127 scale = \frac{T}{int_{max}} = \frac{T}{127} scale=intmaxT=127T
  4. 根据对称量化来量化原始数据(权重、激活值等等)

TensorRT使用KL散度量化的目的

通过KL散度选取合适的阈值T,根据阈值计算对应的缩放系数scale,力求int8量化后的数值能更准确表示出量化前的FP32数值。

代码案例

import random
import numpy as np
import matplotlib.pyplot as plt  
import copy
import scipy.stats as stats

# 随机生成测试数据
def generator_P(size):
    walk = []
    avg = random.uniform(3.000, 600.999)
    std = random.uniform(500.000, 1024.959)
    for _ in range(size):
        walk.append(random.gauss(avg, std)) # 生成符合高斯分布的随机数
    return walk

# 平滑p和q,防止出现nan值,因为KL散度会计算log(p/q), 当q为0值时会出现nan
def smooth_distribution(p, eps=0.0001):
    is_zeros = (p == 0).astype(np.float32)
    is_nonzeros = (p != 0).astype(np.float32)
    n_zeros = is_zeros.sum()
    n_nonzeros = p.size - n_zeros
    if not n_nonzeros:
        raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
    eps1 = eps * float(n_zeros) / float(n_nonzeros)
    assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
    hist = p.astype(np.float32)
    hist += eps * is_zeros + (-eps1) * is_nonzeros
    assert (hist <= 0).sum() == 0
    return hist

def threshold_distribution(distribution, target_bin = 128):
    distribution = distribution[1:]
    length = distribution.size # 2047
    threshold_sum = sum(distribution[target_bin:]) # [128: ]
    kl_divergence = np.zeros(length - target_bin) # 初始化 2047 - 128 = 1919 个KL散度值
    
    for threshold in range(target_bin, length): # 遍历threshold寻找KL散度最低的阈值
        sliced_nd_hist = copy.deepcopy(distribution[:threshold]) # [0, threshold)内的作为P
        p = sliced_nd_hist.copy() # 生成p

        p[threshold - 1] += threshold_sum # 把 [threshold:] 后的累加和加到 p[threshold - 1] 中
        threshold_sum = threshold_sum - distribution[threshold] # 更新下一轮的累加和,即上一轮的累加和减去即将移入P分布的区间数据

        is_nonzeros = (p != 0).astype(np.int64) # [0:threshold]内不为0的区间
        
        quantized_bins = np.zeros(target_bin, dtype = np.int64) # 初始化量化后的bins
        num_merged_bins = sliced_nd_hist.size // target_bin # 计算多少个区间需要合并来计算平均值,例如最初有8个bins,需要合并到4个bins,则每两个bins需要进行合并

        # 合并bins
        for j in range(target_bin): 
            start = j * num_merged_bins # 合并开始的bins
            stop = start + num_merged_bins # 合并结束的bins
            quantized_bins[j] = sliced_nd_hist[start:stop].sum() # 计算区间内bins的总和
        quantized_bins[-1] += sliced_nd_hist[target_bin * num_merged_bins:].sum()

        # 计算q
        q = np.zeros(sliced_nd_hist.size, dtype = np.float64) # 初始化量化后的q
        for j in range(target_bin):
            start = j * num_merged_bins
            if j == target_bin - 1:
                stop = -1
            else:
                stop = start + num_merged_bins # 每num_merged_bins个bins进行合并组成q
            norm = is_nonzeros[start:stop].sum() # 看看合并区间里,不为0的区间个数
            if norm != 0:
                q[start:stop] = float(quantized_bins[j]) / float(norm) # 用均值(假如区间内都不为0)填充q
        
        # 平滑p和q
        p = smooth_distribution(p)
        q = smooth_distribution(q)
        # 计算p和q之间的KL散度
        kl_divergence[threshold - target_bin] = stats.entropy(p, q)

    # 寻找最小KL散度对应threshold的索引
    min_kl_divergence = np.argmin(kl_divergence)
    threshold_value = min_kl_divergence + target_bin # 计算真正的threshold, 基于最初的128, 因为一开始就是从128开始不断向外计算来扩大P的范围

    return threshold_value

if __name__ == '__main__':
    # 随机初始化测试数据
    size = 20480 
    P = generator_P(size) 
    P = np.array(P)
    P = P[P > 0] # 保留大于0的数
    # print("maximum activation value", max(np.absolute(P))) # 最大的激活值

    hist, bin_edges = np.histogram(P, bins = 2048) # 生成直方图 hist表示每一个bins对应的数量, bins表示截止 
    threshold = threshold_distribution(hist, target_bin = 128) # 返回KL散度最小的划分bins
    print("threshold: ", threshold)
    print("threshold edges:", bin_edges[threshold]) # 截止到threshold对应的bins, 能够表示的范围 bin_edges[-1]表示上面最大的激活值,即能够表示所有数

    # 计算scale
    # scale = bin_edges[threshold] / int_max # 即bin_edges[threshold] / 127 
    # 在最初的对称量化中,我们是用绝对值最大的数值作为bin_edges[threhold], 而TensorRT就是利用KL散度来评估最佳的bin_edges[threshold]

    # 分成 split_zie 组, density表示是否要normed
    plt.title("Relu activation value Histogram")
    plt.xlabel("Activation values")
    plt.ylabel("Normalized number of Counts")
    plt.hist(P, bins=2047)
    plt.vlines(bin_edges[threshold], 0, 30, colors = "r", linestyles = "dashed") # 红线向左就是能够表示的所有范围
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1103211.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

有没有好用的扫描爆破工具?来看这里

工具介绍 goon集合了fscan和kscan等优秀工具功能的扫描爆破工具。功能包含&#xff1a;ip探活、port扫描、web指纹扫描、title扫描、压缩文件扫描、fofa获取、ms17010、mssql、mysql、postgres、redis、ssh、smb、rdp、telnet、tomcat等爆破以及如netbios探测等功能。 参数说…

Linux高性能服务器编程 学习笔记 第十六章 服务器调制、调试和测试

Linux平台的一个优秀特性是内核微调&#xff0c;即我们可以通过修改文件的方式来调整内核参数。 服务器开发过程中&#xff0c;可能会碰到意想不到的错误&#xff0c;一种调试方法是用tcpdump抓包&#xff0c;但这种方法主要用于分析程序的输入和输出&#xff0c;对于服务器的…

Atlassian Confluence OGNL表达式注入RCE CVE-2021-26084

影响版本 All 4.x.x versions All 5.x.x versions All 6.0.x versions All 6.1.x versions All 6.2.x versions All 6.3.x versions All 6.4.x versions All 6.5.x versions All 6.6.x versions All 6.7.x versions All 6.8.x versions All 6.9.x versions All 6.1…

uniapp使用uQRCode绘制二维码,下载到本地,调起微信扫一扫二维码核销

1.效果 2.在utils文件夹下创建uqrcode.js // uqrcode.js //--------------------------------------------------------------------- // github https://github.com/Sansnn/uQRCode //---------------------------------------------------------------------let uQRCode {…

小学数学题AI自动出题系统源码,支持在线打印及导出PDF!

今天给大家开发了个好东西&#xff0c;小学数学作业练习册AI自动出题网站源码&#xff0c;全面支持打印机打印机转成PDF文件&#xff0c;快给你家娃娃整一套吧&#xff0c;AI自动出题&#xff0c;让娃练习算数&#xff0c;解放双手&#xff0c;让您的孩子成绩蒸蒸日上&#xff…

【微服务 SpringCloudAlibaba】实用篇 · Nacos注册中心

微服务&#xff08;5&#xff09; 文章目录 微服务&#xff08;5&#xff09;1. 认识和安装Nacos2. 服务注册到nacos和拉取服务1&#xff09;引入依赖2&#xff09;配置nacos地址3&#xff09;重启 3. 服务分级存储模型3.1 给user-service配置集群3.2 同集群优先的负载均衡 4. …

高程DEM-等高线生成-AutoCAD等高线

高程DEM-等高线生成-AutoCAD等高线 发布时间&#xff1a;2018-01-17 版权&#xff1a; 同步视频教程&#xff1a;卫星地图_高清卫星地图_卫星地图视频_下载高程等高线使用视频教程 专题地图制作视频教程&#xff1a;卫星地图_高清卫星地图_卫星地图视频_地图数据应用&#xf…

【C++】--遇到抛异常没有及时释放的空间该怎么办??---智能指针来帮你解决(以及定制删除器)

&#x1f496;作者&#xff1a;小树苗渴望变成参天大树&#x1f388; &#x1f389;作者宣言&#xff1a;认真写好每一篇博客&#x1f4a4; &#x1f38a;作者gitee:gitee✨ &#x1f49e;作者专栏&#xff1a;C语言,数据结构初阶,Linux,C 动态规划算法&#x1f384; 如 果 你 …

04、RocketMQ -- 核心基础使用

目录 核心基础使用1、入门案例生产者消费者 2、消息发送方式方式1&#xff1a;同步消息方式2&#xff1a;异步消息方式3&#xff1a;一次性消息管控台使用过程中可能出现的问题 3、消息消费方式集群模式&#xff08;默认&#xff09;广播模式 4、顺序消息分析图&#xff1a;代码…

[uni-app] canvas绘制圆环进度条

文章目录 需求参考链接基本问题的处理1:画布旋转的问题2:注意arc()的起始位置是3点钟方向3: 如果绘制1.9*Matn.PI的圆环, 要保证其实位置在0点方向?4:小线段怎么画, 角度怎么处理? 源码 需求 要绘制一个如此的进度条 参考链接 uni-app使用canvas绘制时间刻度以及不显示问…

线段树【java实现】

一、解决问题 区间最值和区间求和问题 力扣相关题目&#xff1a; ​​​​​​303. 区域和检索 - 数组不可变 729. 我的日程安排表 I 二、线段树定义 平衡二叉树&#xff0c;数组中的元素都存储在叶子结点中&#xff0c;如图是一个求区间最大值的线段树。 已知数组arr[11…

电源特性测试之电源模块负载调整率测试方法及测试条件

负载调整率是衡量电源好坏的重要指标&#xff0c;它反映的是当负载电流变化时&#xff0c;稳压电源输出电压相应的变化情况。好的电源负载变化时引起的输出变化较小&#xff0c;通常是在3%-5%。负载调整率是电源模块测试的一个重要步骤&#xff0c;今天纳米软件将为大家介绍负载…

Yakit工具篇:综合目录扫描与爆破的使用

简介&#xff08;来自官方文档&#xff09; 目录扫描是一种常用的Web应用程序安全测试技术&#xff0c;用于发现Web应用程序中存在的可能存在的漏洞和弱点。其原理是通过对Web应用程序中的目录和文件进行遍历&#xff0c;来发现可能存在的安全漏洞和风险。 具体来说&#xff…

大语言模型在推荐系统的实践应用

本文从应用视角出发&#xff0c;尝试把大语言模型中的一些长处放在推荐系统中。 01 背景和问题 传统的推荐模型网络参数效果较小(不包括embedding参数)&#xff0c;训练和推理的时间、空间开销较小&#xff0c;也能充分利用用户-物品的协同信号。但是它的缺陷是只能利用数据…

《进化优化》第5章 进化规划

文章目录 5.1 连续进化规划5.2 有限状态机优化5.3 离散进化规划5.4 囚徒困境5.5 人工蚂蚁问题 5.1 连续进化规划 目的&#xff1a;最小化f(x)&#xff0c; 这里的x是一个n维向量&#xff0c;假定对所有的x, f(x)>0。 进化规划从随机生成的一个个体种群{xi}开始, 按如下方式…

Umi3实战教程

一、框架介绍 umi是蚂蚁金服的前端开发框架&#xff0c;它内置了路由、web/移动端UI库、数据流、权限控制、常用hooks库、构建、部署、测试、等等一些工具&#xff0c;几乎涵盖了正常前端开发要用到的所有工具。 二、环境准备 pnpm 相比npm、yarn&#xff0c;pnpm更小更快扁平…

为大模型而生!顶流大佬发起成立学术会议 COLM,或成为未来 NLP 最强顶会?!

夕小瑶科技说 原创 作者 | 智商掉了一地、ZenMoore 前段时间&#xff0c;ACL 2024 的主席公开抨击称“ arXiv是科研的毒瘤”&#xff0c;这引发了大范围的争论。 一时间&#xff0c;大家对 *CL 的抵触情绪愈发高涨&#xff0c;绝大多数学界都在这场辩论中站在了支持 arXivTwit…

PreparedStatement

使用参数化查询&#xff1a;使用预编译的语句和参数化查询来执行SQL语句&#xff0c;而不是将用户输入直接嵌入到SQL语句中。这将帮助防止恶意输入注入SQL语句。

Zoho WorkDrive荣获专业研究机构评定的“Leader”称号

近年&#xff0c;在云计算、大数据、移动互联网、社交所引领的数字化转型变革中&#xff0c;企业对于数字资产的保护和利用愈加重视。相较于结构化数据&#xff0c;企业对于非结构化数据&#xff08;文档、图片、音视频等&#xff09;管理的需求更强、难度更大。 同时&#xf…