Pytorch中KL loss

news2025/7/15 3:47:49

1. 概念

KL散度可以用来衡量两个概率分布之间的相似性,两个概率分布越相近,KL散度越小。
KL
上述公式表示P为真实事件的概率分布,Q为理论拟合出来的该事件的概率分布。D(P||Q)(P拟合Q)和D(Q||P)(Q拟合P)是不一样的。

2. 举例

班里男生人数占40%,女生占60%,则班里随机抽取一个人的性别的概率分布是Q = [0.4, 0.6]。作为真实事件的概率分布。
小明猜测班里男生占30%,女生占70%,则小明拟合的概率分布P1 = [0.3, 0.7]。
小红猜测班里男生占20%,女生占80%,则小红拟合的概率分布P2 = [0.2, 0.8].
那么现在,小明和小红谁预测的概率分布离真实分布比较近?这时候就可以用KL散度来衡量P1与Q的相似性、P2与Q的相似性,然后对比可得谁更相似。

小明是模拟概率分布(对应Q1),真实概率分布对应P,所以 KL1 = KL(P||Q) = KL([0.4, 0.6] | [0.3, 0.7]) = (0.4log0.4 - 0.4log0.3) + (0.6log0.6 - 0.6log0.7) = 0.0226;同理小红是模拟概率分布(对应Q2),真实概率分布对应PKL2=KL(P||Q2) = KL([0.4, 0.6] | [0.2, 0.8]) = (0.4log0.4 - 0.4log0.2) + (0.6log0.6 - 0.6log0.8) = 0.1046。
KL1比KL2小,说明Q1与P更接近。

这个例子很直观,不用计算就可以猜测出结果,但是当分布复杂的情况下,用KL散度就比较好度量。如一个数据集分布未知,想用数学公式来表达,比如高斯分布、泊松分布、韦伯分布等,这些分布哪个更适合用来表示数据集的分布。则可以计算拟合曲线与数据集真实分布的KL散度,选择KL散度最小的作为数据集的概率分布表达式。
如:用高斯分布拟合数据集分布时,统计均值μ,标准差σ,则可得到高斯分布表达式:
再用高斯分布表达式不同自变量x1,x2,…计算出不同类别的概率q1,q2…,即概率分布Q=[q1, q2,…],与真实的概率分布P = [p1,p2,…]通过上面公式计算得到KL散度。
同理,计算其他拟合分布与真实分布的KL散度,对比得到最优用来拟合真实数据的概率分布表达式。

3. Pytorch计算KL散度

现在,明白了什么是KL散度,可以用pytorch自带的库函数来计算KL散度。
使用pytorch进行KL散度计算,可以使用pytorch的kl_div函数,假设Y_true为真实分布,Y_pred为预测分布。

import torch.nn.functional as F
kl = F.kl_div(Y_pred.log_softmax(dim=-1).log(), Y_true.softmax(dim=-1), reduction='sum')

其中kl_div接收三个参数,第一个为预测分布,第二个为真实分布,第三个为reduction。(其实还有其他参数,只是基本用不到)

这里有一些细节需要注意,第一个参数与第二个参数都要进行softmax(dim=-1),目的是使两个概率分布的所有值之和都为1,若不进行此操作,如果x或y概率分布所有值的和大于1,则可能会使计算的KL为负数。softmax接收一个参数dim,dim=-1表示在最后一维进行softmax操作。除此之外,第一个参数还要进行log()操作(至于为什么,大概是为了方便pytorch的代码组织,pytorch定义的损失函数都调用handle_torch_function函数,方便权重控制等),才能得到正确结果

第三个参数reduction有三种取值,为 none 时,各点的损失单独计算,输出损失与输入(x)形状相同;为 mean 时,输出为所有损失的平均值;为 sum 时,输出为所有损失的总和。

需要清晰的一点解释是:D(P||Q)中P和Q的实际意义,P代表真实概率,也就是对应的是ground truth归一化+log(是否进行log由kl_div()的最后一个参数log_target确定,默认为False即认为输入kl_div()的第二个参数target未进行log)。那么Q就是对应的log(softmax(logit))。这两点才是实际中的定义,所以并没有相反一说,并且调用kl_div()是参数名称也非常明确了,第一个参数是input,第二个参数是target。

代码举例:

#target没有log
import torch
import torch.nn as nn
import torch.nn.functional as F
kl_loss = nn.KLDivLoss(reduction="batchmean")
# input should be a distribution in the log space
input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
# Sample a batch of distributions. Usually this would come from the dataset
target = F.softmax(torch.rand(3, 5), dim=1)
output = kl_loss(input, target)

target没有log输出结果:

输出结果:tensor(0.3441, grad_fn=<DivBackward0>)
#target有log
import torch
import torch.nn as nn
import torch.nn.functional as F
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
log_target = F.log_softmax(torch.rand(3, 5), dim=1)
output = kl_loss(input, log_target)

target有log输出结果:

tensor(0.4346, grad_fn=<DivBackward0>)

4. 我理解的交叉熵和KL

交叉熵作为深度学习常用的损失函数,可以理解为是KL散度的一个特例。当概率分布中的值只取1或0时,可以看作KL散度。但是两者又有区别,KL散度中概率分布所有值之和为1,而交叉熵则可以大于1,如[0,1,0,1,0,0,]。

从概念上讲,KL 散度通常用来度量两个概率分布之间的差异
交叉熵用来求目标与预测值之间的差距,数据分布不一定是概率分布

设数据的真实分布为 P(x),而Q(x)表示我们模型预测出来的数据分布,那么KL散度则为:
KL
化简就是:
KL

因为P(x)是真实分布,也即是由上面公式可知D(P||Q)前面一项是固定的,所以只要后面的项越小,KL散度就越小,也就是损失越小

而交叉熵是KL的一个特例,也用上面的公式计算loss,因为label是采用one-hot格式,即是正确label处的值为1,其余label处的值为0,因此D(P||Q)前面一项是0,就只剩后面一项,因此定义了一个计算loss的交叉熵损失函数,也就是,因此KL散度等于KL前面一项(熵)加上交叉熵,一定程度上优化kl散度和优化交叉熵是等价的
KL

5.参考链接

KL散度理解以及使用pytorch计算KL散度
为什么 不用KL散度作为损失函数? 感觉这个问题描述得不怎么准确???

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/38174.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ajax之Content-Type示例

参考资料: Content-Type详解【SpringBoot】SpringBoot接收请求的n种姿势 目录前期准备0. Content-Type概念解释1. application/x-www-form-urlencoded1.1 form表单示例1.2 jQuery的ajax示例2. application/json2.1 指定contentType为json,不使用RequestBody接收2.2 不指定cont…

01 OSI七层网络排查 troubleshooting 思路及对应工具

文章目录1 .前言2. OSI 的七层模型&#xff0c;和 TCP/IP 的四层 / 五层模型区别2.1 网络专业术语2.2 TLS 解释2.3 什么是TCP 流&#xff1f;3. 网络各层排查工具3.1 应用层3.1.1 浏览器的开发者工具3.1.1.1 找到有问题的服务端IP3.1.1.2 辅助排查网页慢的问题3.1.1.3 解决失效…

嵌入式数据库sqlite3

一、数据库 数据库的基本概念 常用的数据库 大型数据库 Oracle公司是最早开发关系数据库的厂商之一&#xff0c;其产品支持最广泛的操作系统平台。目前Oracle关系数据库产品的市场占有率名列前茅。 IBM 的DB2是第一个具备网上功能的多媒体关系数据库管理系统&#xff0c;支…

bootstrap学习(四)

bootstrap中图片、按钮、表单 按钮&#xff1a; 不加样式的按钮&#xff1a; 在bootstrap中a标签也可以生成按钮&#xff1a; 默认按钮尺寸可以不加&#xff0c;它是自动显示默认尺寸 加btn-block&#xff1a;class 图片&#xff1a; 表单&#xff1a; 垂直表单&#xff1a;…

【语音识别】MFCC+VAD端点检测智能语音门禁系统【含GUI Matlab源码 451期】

⛄一、MFCC简介 1 引言 语音识别是一种模式识别, 就是让机器通过识别和理解过程把语音信号转变为相应的文本或命令的技术。语音识别技术主要包括特征提取技术、模式匹配准则及模型训练技术3个方面。目前一些语音识别系统的适应性比较差, 主要体现在对环境依赖性强, 因此要提高…

[mysql] 深入分析MySQL版本控制MVCC规则--实测 (mysql 8.0 innodb引擎)

背景&#xff1a;基于之前的一篇文章 可重复读&#xff1a;可重复读隔离级别的实现是每个事务在打开时都会生成一个一致的视图。 当其他事务提交时&#xff0c;不会影响当前事务中的数据。 为了保证这一点&#xff0c;MySQL是通过多版本控制机制MVCC来实现的&#xff1b; 我们亲…

Go语言面试题合集(2022)

基础语法 Go 支持默认参数或可选参数吗&#xff1f; 不支持。但是可以利用结构体参数&#xff0c;或者…传入参数切片数组。 // 这个函数可以传入任意数量的整型参数 func sum(nums ...int) {total : 0for _, num : range nums {total num}fmt.Println(total) }Go 语言 tag…

pandas算术运算、逻辑运算、统计运算describe()函数、统计函数、累计统计函数及自定义函数运算

一、pandas算术运算 直接对数据进行加、减、乘、除等运算&#xff0c;可使用函数add()、sub()、mul()、div()或、-、、 代码如下 数据生成 import pandas as pd import numpy as np# 数据生成代码 num np.random.randint(50, 100, (3, 5))# 传入标签索引 column [第一列, …

[hadoop全分布部署]安装Hadoop、配置Hadoop 配置文件①

&#x1f468;‍&#x1f393;&#x1f468;‍&#x1f393;博主&#xff1a;发量不足 个人简介&#xff1a;耐心&#xff0c;自信来源于你强大的思想和知识基础&#xff01;&#xff01; &#x1f4d1;&#x1f4d1;本期更新内容&#xff1a;安装Hadoop、配置Hadoop 配置文件…

基于SSM的高校课程评价系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

一文带你深入理解【Java基础】· 注解

写在前面 Hello大家好&#xff0c; 我是【麟-小白】&#xff0c;一位软件工程专业的学生&#xff0c;喜好计算机知识。希望大家能够一起学习进步呀&#xff01;本人是一名在读大学生&#xff0c;专业水平有限&#xff0c;如发现错误或不足之处&#xff0c;请多多指正&#xff0…

多线程编程【条件变量】

条件变量&#x1f4d6;1. 为什么需要条件变量&#xff1f;&#x1f4d6;2. 条件变量概念&#x1f4d6;3. 发信号时总是持有锁&#x1f4d6;4. 生产者消费者问题&#x1f4d6;5. 基于阻塞队列的生产者消费者模型&#x1f4d6;1. 为什么需要条件变量&#xff1f; 在很多情况下&a…

Android开发音效增强中铃声播放Ringtone及声音池调度SoundPool的讲解及实战(超详细 附源码)

需要源码请点赞关注收藏后评论区留下QQ~~~ 一、铃声播放 虽然媒体播放器MediaPlayer既可用来播放视频&#xff0c;也可以用来播放音频&#xff0c;但是在具体的使用场合&#xff0c;MediaPlayer存在某些播音方面的不足之处 包括以下几点 1&#xff1a;初始化比较消耗资源 尤其…

软件开发工程师笔试记录--关键路径,浮点数计算,地址变换,中断向量,I/O接口,海明码

时间&#xff1a;2022年11月26日 10&#xff1a;00 -11&#xff1a;00 &#xff08;可提前登录15分钟&#xff09; 公司&#xff1a;XX&#xff08;rongyu&#xff09; 岗位&#xff1a;软件开发工程师&#xff08;我的简历语言是Java&#xff09; 题型&#xff1a;选择题&…

一次应用多次fgc原因的排查及解决

应用多次fgc性能排查&#xff08;一次抢购引起的性能问题&#xff09; 大家好我是魔性的茶叶&#xff0c;今天分享一个项目jvm多次fgc的整个排查流程 上班后不久运维突然通知我们组&#xff0c;有一个应用在短时间内多次fgc&#xff0c;即将处于挂掉的状态。 首先我登录skyw…

客户听不进去,很强势,太难沟通了,怎么办?

案例 最近接手一项目,项目范围蔓延,成本超支,进度延期,问项目经理怎么回事? 项目经理C无奈诉苦到:用户A大领导,是业主B的领导,咱们业主B不敢反驳他,让我直接与用户A对接,说A提的需求,默认答应,我有什么办法啊,只能接了。 分析 经过复盘,导致项目蔓延的主要原因…

【参赛经历总结】第五届“传智杯”全国大学生计算机大赛(初赛B组)

成绩 比赛链接 比赛过程 被虐4h 比赛体验不是很好我开始五分后才在qq上看到题第一题签到题第二题调了1h吧&#xff0c;算法题做的不多&#xff0c;题目没读全&#xff0c;wa了几发&#xff0c;有几发是网络问题&#xff0c;交了显示失败&#xff0c;但还是判wa了第三题知道…

mulesoft What‘s the typeOf(payload) of Database Select

Whats the typeOf payload of Database SelectQuestionOptionExplanationMule ApplicationDebugQuestion Refer to the exhibit. The Database Select operation returns five rows from a database. What is logged by the Logger component? Option A “Array” B “Objec…

第五届传智杯-初赛【B组-题解】

A题 题目背景 在宇宙射线的轰击下&#xff0c;莲子电脑里的一些她自己预定义的函数被损坏了。 对于一名理科生来说&#xff0c;各种软件在学习和研究中是非常重要的。为了尽快恢复她电脑上的软件的正常使用&#xff0c;她需要尽快地重新编写这么一些函数。 你只需输出fun(a,…

数据库错误知识集3(摘)

&#xff08;摘&#xff09; 逻辑独立性是外模式不变&#xff0c;模式改变时&#xff0c;如增加新的关系&#xff0c;新的属性&#xff0c;改变属性的数据类型&#xff0c;由数据库管理员对各个外模式/模式的映像做相应改变&#xff0c;可以使得外模式不变&#xff0c;因为应用…