样本不均衡问题一直是深度学习领域一个不可忽略的问题,常说的长尾效应,说的就是这个问题。一类占据了主导地位,导致其他类无论怎么优化,都不能好转。
无论是纯纯的分类任务,还是稍微复杂一些的目标检测任务和分割任务,这种不均衡都一直存在,表现为三种形式:
- 类与类之间的数量不均衡;
- 前景与背景之间数量不均衡;
- 简单样本和难样本之间的不均衡。
如何在计算损失的时候,得到的最终的loss值,对各个数量不均衡的部分,回归的效果更佳均衡,让模型在更新的时候,不会侧重于某一面,而忽略了另一面的问题?值得探讨
本文就对类似这种不均衡的问题展开讨论,这也是在工作中常常会遇到的问题,同时,也是面试官最喜欢问到的问题。
与此同时,在UNet分割任务中,把目标分割任务的mask,转化为对像素点的分类任务。所以在计算损失的时候,论文里面的损失函数采用的就是交叉熵损失函数。
在后续的损失改进中,多引入dice loss或focal loss。我们就从交叉熵损失函数开始,探讨下它为什么可以应用在分割任务中。
一、CE(CrossEntropy)交叉熵损失
在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题。
在使用nn.CrossEntropyLoss()其内部会自动加上Sofrmax层,对输入pd进行归一化操作。
nn.CrossEntropyLoss()的pytorch文档计算公式,如下:

其中:
- C:number of classes
- N:batch size
- wc:对应class的weight权重
Pytorch中CrossEntropyLoss()函数的主要是将softmax+log+NLLLoss合并到一块得到的结果。(下面会给出代码验证,参考:https://blog.csdn.net/qq_39506862/article/details/126625515)
- Softmax后的数值都在0~1之间,所以
ln之后值域是负无穷到0;- log中0-1的输入,输出是0-正无穷;
- 乘以w加负号,值域是负无穷到0。
- 然后将
Softmax之后的结果取log,将乘法改成加法,减少计算量,同时保障函数的单调性; NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,去掉负号,再求均值。
log以e为底,nn.CrossEntropyLoss(pd, gt),有两个参数:
- 第一个参数:
pd是网络的最后一层的输出,其shape为[batchsize, class](函数要求第一个参数,也就是最后一层的输出为二维数据,每个向量中的值为不同种类的概率值) - 第二个参数:
gt是传入的标签,也就是某个类别的索引值,在上面公式没有参与计算。batch_size如果是1,那么就只有一个数字,0,1或者2,表示的就是此时这个样本对应的真实类别,如果为batch_size是2,那么就有两个数字,例如(0,1),分别表示这两条样本对应的真实类别。
对于(minibatch,C) 的情况:例如,输入(也就是最后一层的输出)为pd=[[4,8,3]],shape=(1,3),即batchsize=1,class=3。
第一步, 我们首先计算一下各个类别的损失,这里的类似是3,默认各类的权重都是1,对类别0、1、2类别对应的损失分别是:
loss(pd, 0) # class=0
=-pd[0]+log(exp(pd[0])+exp(pd[1])+exp(pd[2]))
=-4+log(exp(4)+exp(8)+exp(3))
=-4+8.0247
=4.0247
loss(pd, 1) # class=1
=-pd[1]+log(exp(pd[0])+exp(x[1])+exp(x[2]))
=-8+log(exp(4)+exp(8)+exp(3))
=-8+8.0247
=0.0247
loss(pd, 2) # class=2
=-pd[2]+log(exp(pd[0])+exp(pd[1])+exp(pd[2]))
=-3+log(exp(4)+exp(8)+exp(3))
=-3+8.0247
=5.0247
第二步, 用第二个参数(一维数据)也就是根据类别对应的索引值,在上面结果中选择对应的损失,例如若网络的输出为x=[[4,8,3]],而对应的标签为1,则得到损失函数loss=0.0247,其他类不累积计算。
在one-hot里面,也就是为one的计算损失,其他为0的,不计算损失。
换言之:
- 对于预测输出维度为
(batch, C)的情况,CrossEntropyLoss在计算损失时只会对应gt类别的那个值进行计算,其他类别的预测输出不参与计算。 - 在
PyTorch的实现中,这个过程被封装在了nn.functional.cross_entropy函数中。 - 具体地说,对于
batch内的每个样本,cross_entropy会根据gt标签的值在预测输出的张量中找到对应的位置,然后计算该位置上的预测输出和对应的对数概率,最终将这些计算结果取平均得到整个batch的损失值。
举个例子:
- 对应一个
minibatch=8, classes=15的多分类任务,传入nn.CrossEntropyLoss(pd, gt),计算损失。 - 其中:pd的shape是[8,15],gt的shape是[8],也就是一个minibatch的各个图对应的类别,如果传入的gt是one-hot形式,那么one-hot形式的shape是[8,15],需要经过
torch.max(label, 1)[1]处理,意思是说在label的维度为1上做max操作。 - 返回
tensors = (max, max_indices),取max_indices,得到对应的类别数。
如下案例:
>>> import torch
>>> a = torch.eye(8,15)
>>> a
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
>>> a.shape
torch.Size([8, 15])
>>> gt = torch.max(a, 1)[1]
>>> gt
tensor([0, 1, 2, 3, 4, 5, 6, 7])
>>> gt.shape
torch.Size([8])
再把前面手算的loss案例,拿过来计算下,如下:
>>> pd=torch.tensor([[4.0,8.0,3.0]])
>>> gt=torch.tensor([1])
>>> loss = c(pd, gt.long())
>>> loss
tensor(0.0247)
>>> gt=torch.tensor([2])
>>> loss = c(pd, gt.long())
>>> loss
tensor(5.0247)
可以看到,
- gt为1的,对应的loss是0.0247;
- gt为2的,对应的loss是5.0247;
- 与我们前面手算的一致,也符合只对one hot中为1的部分计算损失,其他不管。
对于 (minibatch,C,d1,d2,···,dk)的情况,其中k≥1的 K-dimensional case,同样适用。这里就是UNet分割网络计算损失时候的重点。前面说那么多也就是为了引出这个部分。案例如下:
>>> gt=torch.rand(8, 24, 24)
>>> pd=torch.rand(8, 5, 24, 24)
>>> loss = c(pd, gt.long())
>>> loss
tensor(1.6401)
>>>
换言之:
- 对于预测输出
pd的维度为(minibatch,C,d1,d2)的情况,输入的gt的shape应该为(minibatch,d1,d2),即每个样本的标签是一个d1*d2的二维矩阵。 - 其中,每个像素的值表示对应位置的像素属于哪个类别,通常使用整数表示类别的编号。
- 例如,如果有3个类别,那么像素可能被标记为0、1或2。在计算交叉熵损失时,会将预测
输出pd和对应的gt进行比较,计算每个像素的分类损失,然后对所有像素的损失求平均得到最终的损失值。 - 对于输入的预测张量,通常会在C维度上进行
softmax操作,使得每个通道(类别)的输出值都在[0,1]范围内,并且所有通道的输出值之和为1。 - 这样做的目的是将预测结果转换成概率分布,方便计算交叉熵损失。
- 在PyTorch中,
torch.nn.CrossEntropyLoss()函数会自动将输入进行softmax操作。
二、Balanced Cross Entropy
再Focal Loss for Dense Object Detection里面,引入了Balanced Cross Entropy的描述。在上文中pytorch的Cross Entropy里面的weight就是自己定义的类别权重,这是一个超参数。

其中:
- for class 1 ,α 属于 [0, 1]
- for class -1,weight=1- α
- α 设定是class 数量的倒数,数量越多,α 越小
三、Dice loss
Dice系数中的"Dice"实际上是一位科学家名字的缩写,其全名是Sørensen–Dice coefficient,常被称为Dice similarity coefficient或者F1 score。它由植物学家Thorvald Sørensen 和Lee Raymond Dice独立研制,分别于 1948 年和 1945 年发表。
Dice系数是一种常见的相似度计算方法,主要用于计算两个集合的相似度。在 Dice Loss 中,用 Dice 系数来计算预测结果和真实标签的相似度,因此得名 Dice Loss。
dice coefficient定义如下:

如果看作是对像素点类别的分类任务,也可以写成:

于是,dice loss就可以表示为:

Dice系数的中文名称为“Dice相似系数”或“Dice相似度”,因此 Dice Loss 也可以称为“Dice相似度损失”或“Dice相似系数损失”。
对于输出只有一个类别的分割任务,输出的channel=1。在计算dice loss的时候,需要先使用sigmoid将预测结果,归一化到0-1的范围。具体的代码如下所示:
def dice_loss(pred, target):
smooth = 1.
iflat = pred.contiguous().view(-1)
tflat = target.contiguous().view(-1)
intersection = (iflat * tflat).sum()
return 1 - ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))
# example usage:
batch_size, num_classes, height, width = 2, 1, 256, 256
pred = torch.randn(batch_size, num_classes, height, width)
# apply sigmoid to predicted output
pred = torch.sigmoid(pred)
print(pred)
target = torch.randint(0, 2, size=(batch_size, height, width)).float()
loss = dice_loss(pred, target)
print(loss)
对于多类别的loss,定义如下:
def dice_loss_2(pred, target, weight=None, num_classes=5):
smooth = 1.
# 各个类别分别计算损失
dice = torch.zeros(num_classes, dtype=torch.float32, device=pred.device)
for c in range(num_classes):
# 取出对应类的pred和target
p_c = pred[:, c].contiguous().view(-1)
print(p_c)
# target 标记是0、1、2、3、4,表示的是背景和各个类,下面操作类似于one-hot
t_c = (target == c).float().view(-1)
intersection = torch.sum(p_c * t_c)
union = torch.sum(p_c) + torch.sum(t_c)
dice[c] = (2. * intersection + smooth) / (union + smooth)
# apply class weights
if weight is not None:
dice = dice * weight.view(1, -1)
return 1 - dice.mean()
batch_size, num_classes, height, width = 2, 5, 32, 32
pred = torch.randn(batch_size, num_classes, height, width)
pred = torch.softmax(pred, dim=1) # 在num_classes维度上,进行softmax操作
target = torch.randint(0, num_classes, (batch_size, height, width))
print(target)
loss = dice_loss_2(pred, target)
print(loss)
打印target和p_c,可以看到下面的结果,其中红框区域内的和是1,是pred = torch.softmax(pred, dim=1) 的操作结果:

这个时候我就产生了一个疑惑:在沿着num_classes的维度上,进行softmax操作后,带入dice loss公式时候,是否需要将最大值置为1,其他置为0后再计算吗?还是直接0-1范围内的数值,直接和target的0 or 1进行计算呢?像上面的函数定义一样。为什么?
chatGPT给出了答案,如下:
在这个函数中,已经对预测值进行了softmax操作,使得每个像素点在每个类别上的预测值都在0到1之间,表示该像素点属于每个类别的概率。因此,在计算dice loss时,无需将最大值置为1,其他置为0。
该函数中,对于每个类别c,通过计算交集和并集,得到该类别的dice系数。然后将所有类别的dice系数取平均,并用1减去平均值,得到最终的dice loss。
因此,对于每个像素点,其在所有类别上的预测值都会对最终的dice loss产生影响,而不仅仅是最大值。
我继续追问:意思是说,如果取最大值置为0,其他置为0,那么就会导致只会在为1的维度上进行计算loss,而其他的损失未参与计算,是这样一个意思吗?
chatGPT回答:是的,如果将最大值置为1,其他置为0,那么只有在最大值对应的类别上计算的dice系数才会对最终的dice loss产生影响,而其他类别的dice系数则不会参与计算。这样会导致损失函数无法充分利用所有类别的信息,可能会影响模型的训练效果。因此,在计算dice loss时,应该将每个像素点在所有类别上的预测值都考虑进去。
更多内容,可以参考这里:图像分割必备知识点 | Dice损失 理论+代码
四、Focal Loss
在分类任务中,容易分类的negative构成了大部分的损失,并主导了梯度。但是这部分并不是最重要的,并且是比较容易区分的。Balanced Cross Entropy中的α balances the importance of positive/negative examples, it does not differentiate between easy/hard examples。
进而,我们提出了一种loss function去降低weight of easy example,让训练集中到hard negative。
Focal loss是最初由何恺明在论文:Focal Loss for Dense Object Detection论文中提出来的,最初用于图像领域解决数据不平衡造成的模型性能问题。
在论文中:

翻译过来就是:
直观地说,这个比例因子可以在训练过程中,自动降低简单示例的权重,并迅速将模型集中在困难示例上。实验表明,我们提出的Focal Loss使我们能够训练一个高精度的单阶段检测器,其性能明显优于使用抽样启发式或硬示例挖掘训练的替代方案,这是先前用于训练单阶段检测器的最先进技术。最后,我们注意到focal loss的确切形式并不重要,并且我们展示了其他实例可以实现类似的结果
Focal loss是基于二分类交叉熵CE的。它是一个动态缩放的交叉熵损失,通过一个动态缩放因子,可以动态降低训练过程中易区分样本的权重,从而将重心快速聚焦(focal)在那些难区分的样本(有可能是正样本,也有可能是负样本,但都是对训练网络有帮助的样本)。


论文中提到,我们注意到焦损的两个特性(We note two properties of the focal loss.)
- 对于预测结果,与真实的标注,会存在两种情况:
- 当样本被误分类时,pt较小,趋于0,调制因子接近1,权重不变,
loss不受影响; - 当样本被正确分类时,pt趋于1,调制因子接近0,对于分类较好的,降低权重。
(When an example is misclassified and pt is small, the modulating factor is near 1 and the loss is unaffected.As pt --> 1, the factor goes to 0 and the loss for well-classified examples is down-weighted)。 - 直观讲,就是计算损失时,降低正确分类的loss在总loss中的权重,预测差的保持loss,这样变相的提高了错误分类样本在loss中的权重。
- 当样本被误分类时,pt较小,趋于0,调制因子接近1,权重不变,
- 参数
γ降低了简单样本的降低权重操作( The focusing parameter γ smoothly adjusts the rate at which easy examples are down weighted. )
当γ = 0时,Focal Loss等于CE,随着γ的增大,调制因子的作用也随之增大(When γ = 0, FL is equivalent to CE, and as γ is increased the effect of the modulating factor is likewise increased )
对这块的解释,网上资料太多了,关键就在于你该如何理解。我搜罗了一些好理解的,发散开的,建议参考这里:何恺明大神的「Focal Loss」,如何更好地理解?
他从硬截断,去除,不参与total loss的计算:正样本的预测值大于 0.5 的,或者负样本的预测值小于 0.5 的,我都不更新了,把注意力集中在预测不准的那些样本,当然这个阈值可以调整。
导致模型只知道更新难的,错误的样本损失,不知道保持正确的分类。等等,相信上面的文章能给你启发。但是最后还是建议去看看原论文,介绍的很详细。
五、总结
除了上面介绍的几个分类常用到的损失函数,还一直会有更加优秀的loss出现。本篇只是开了一个头,并且上述介绍的损失函数,在各个深度学习框架内基本都包含了这些函数,也无需我们自己手写。
在后续如果有更优秀的loss出现,我也会把它归类到这里。喜欢就点赞收藏不迷路。
在最后的最后,如果你也在学习新的内容,一定要去翻看原论文。我一般把参考的论文链接,都放到了文章里面。在原论文里面,因为是第一次出现,所以一般描述的都是比较的详尽,包括:
- 为什么会有这个概念的出现?
- 他解决了之前存在的什么问题?
- 有什么优势和具体如何实现的?
都非常的详尽,是一个学习的第一手资料。希望对你有用,其余其中不正确和不详尽的地方,欢迎评论区留言,一起讨论,我也会学习后,分享补充进来。



















