深度理解PyTorch的WeightedRandomSampler处理图像分类任务的类别不平衡问题

news2025/6/16 9:41:59

最近做活体检测任务,将其看成是一个图像二分类问题,然而面临的一个很大问题就是正负样本的不平衡问题,也就是正样本(活体)很多,而负样本(假体)很少,如何处理好数据集的类别不平衡问题有很多方法,如使用加权的交叉熵损失(nn.CrossEntropyLoss(weight=weight)),但是更加有效的一个实践是在模型训练的过程中过采样少数类样本,增加这些少数类样本被模型看到的频率。

pytorch提供了一个WeightedRandomSampler 帮助完成以上任务。

torch.utils.data — PyTorch 2.0 documentation

通用的使用方法如下:


[步骤1] class_sample_count = [10, 1, 20, 3, 4] # dataset has 10 class-1 samples, 1 class-2 samples, etc.
[步骤2] weights = 1 / torch.Tensor(class_sample_count)
[步骤3] # 将weights赋予所有的训练样本,作为每个训练样本的权重,标记为 samples_weight
[步骤4] sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, num_samples=len(samples_weight), replacement=True) 
[步骤5] trainloader = data_utils.DataLoader(train_dataset, batch_size = 20, sampler = sampler) 

特别注意: 第5步骤中,一旦设置了sampler, 就不能再设置shuffle,参考:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

怎么理解这些东西呢?

首先第一步需要计算所有类别的个数,这样可以知道哪些类别数量大,而哪些类别数量少。

其次,计算每个类别的权重,只需要简单的取每个类别的个数之倒数即可。

第三,对于每个训练样本,根据其标签,获取其对应类别的权重,作为训练时样本采样的概率。

第四,就是创建采样器,其中num_samples通常取为weights的个数,即训练样本的个数。

大家可以想想,对于样本少的类别,以二分类为例,

[假, 真, 真, 真, 真, 真]

假的权重 = 1/1 , 真的权重= 1/5, 

于是上面的samples_weight取为:

[1, 0.2,0.2, 0.2,0.2, 0.2],

因此取得假样本的期望值 = 1*1=1, 取得真样本的期望值=0.2*5=1, 这样二者就平衡了。因此假样本被过采样,相对地,真样本被欠采样了。

试验

 由此看来总体是多次采样的样本基本是真假平衡的。

那么也许很多人会问:

1. 按照上述概率采样方式,是否所有的训练样本都被模型看到了呢?

2. 如果我不想让数据集均匀分布,而是想达到其他比例呢?

关于这两个问题,博文《Demystifying PyTorch’s WeightedRandomSampler by example》给了一个很详细的回答。

我总结在这里:

1. 在一轮(epoch)训练中,确实可能存在部分样本没有被模型看到,增加num_samples 为训练数据集的样本数量的两倍,会使得一轮迭代过程中看到更多的图像,但是一般仍然推荐设置num_samples 为训练数据集的样本数量,并且相信,当我们训练更多轮以后,所有的图像都将在某一个点处被看到。

2.  对于类别不平衡的数据集,一般在9-10轮以后就会看全所有的样本,而对于类别均衡的数据集,采用上述方法采样,需要大致经过5轮才能看完所有的样本(这种情况下就不用采取这种采样策略了)。

3.  看博文吧。

注:该方法多用于分类问题。即一个训练样本对应一个标签。对于分割问题,一个样本中有很多标签,用该方法就不太方便。分割问题推荐给损失函数添加权重,如nn.CrossEntropyLoss(weight=weight)。

参考文献:

1. Demystifying PyTorch’s WeightedRandomSampler by example
https://gist.github.com/Chris-hughes10/260c70650c5a6f322d273a8a8728b91a

2. Pytorch样本比例不均衡时采用WeightedRandomSampler进行采样

3. torch.utils.data.WeightedRandomSampler样本不均衡情况下带权重随机采样

4. WeightedRandomSampler 理解了吧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/412054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springboot实现邮箱验证码功能

引言 邮箱验证码是一个常见的功能,常用于邮箱绑定、修改密码等操作上,这里我演示一下如何使用springboot实现验证码的发送功能; 这里用qq邮箱进行演示,其他都差不多; 准备工作 首先要在设置->账户中开启邮箱POP…

ChatAudio 通过TTS + STT + GPT 实现语音对话(低仿微信聊天)

效果图什么是 STT 和 TTS?STT 是语音转文字(Speech To Text)TTS 是文字转语音(Text To Speech)为什么要使用 SST TTS 如果用户直接输入音频,OpenAI 的 API 中并没有直接使用语音和 GPT 进行对话的功能。所…

(C++)模板分离编译面对的问题

什么是分离编译模板的分离编译什么是分离编译 一个程序(项目)由若干个源文件共同实现,而每个源文件单独编译生成目标文件,最后将所有目标文件链接起来形成单一的可执行文件的过程称为分离编译模式。 模板的分离编译 假如有以下…

Java锁机制

Java锁机制1. 什么是锁JVM运行时内存结构2. 对象、对象头结构Mark Word中的字段3. synchronizedMonitor原理四种锁状态的由来4. 锁的4种状态4.1 无锁CAS(Compare and Swap)4.2 偏向锁实现原理4.3 轻量级锁如何判断线程和锁之间的绑定关系自旋4.4 重量级锁…

【计算机视觉·OpenCV】使用Haar+Cascade实现人脸检测

前言 人脸检测的目标是找出图像中所有的人脸对应的位置,算法的输出是人脸的外接矩形在图像中的坐标。使用 haar 特征和 cascade 检测器进行人脸检测是一种传统的方式,下面将给出利用 OpenCV 中的 haarcascade 进行人脸检测的代码。 程序流程 代码 impo…

摩兽Pesgo plus首发爆卖,全网关注度破亿!中国潮玩跨骑电自浪潮已至?

2023年4月11日,TROMOX摩兽圆满举办了“跨骑潮电,大有所玩”Pesgo plus新品发布会。发布会在抖音、天猫、视频号平台进行了同步直播并开启线上预定。发布会直播当晚,摩兽Pesgo plus即狂揽线上订单,全网各大平台相关话题累计热度已破…

XXL-JOB分布式任务调度平台详细介绍

一、概述 在平时的业务场景中,经常有一些场景需要使用定时任务,比如: 时间驱动的场景:某个时间点发送优惠券,发送短信等等。 批量处理数据:批量统计上个月的账单,统计上个月销售数据等等。 固…

用SQL语句操作oracle数据库--数据查询(上篇)

SQL操作Oracle数据库进行数据查询 Oracle 数据库是业界领先的关系型数据库管理系统之一,广泛应用于企业级应用和数据仓库等场景中。本篇博客将介绍如何使用 SQL 语句对 Oracle 数据库进行数据查询操作。 1.连接到数据库 在开始查询之前,需要使用合适的…

素材管理系统概念导入

引言 由于工作上的调整安排,有幸参加营销素材管理系统的产品建设工作中,营销宣传领域一直是我的知识盲区,所以素材管理系统的产品建设对我来说是个富有挑战性的工作,在这过程中,我也秉持着“好记性不如烂笔头”的原则&…

Golang每日一练(leetDay0033) 二叉树专题(2)

目录 97. 交错字符串 Interleaving String 🌟🌟 98. 验证二叉搜索树 Validate Binary Search Tree 🌟🌟 99. 恢复二叉搜索树 Recover Binary Search Tree 🌟🌟 🌟 每日一练刷题专栏 &am…

中国人工智能企业CIMCAI世界前三大船公司落地,智能船公司产品20秒AI自动验箱,箱信息箱况+精确地点报备智慧港航中国人工智能企业

中国人工智能企业CIMCAI世界前三大船公司落地,智能船公司产品20秒AI自动验箱,箱信息箱况精确地点报备智慧港航。小程序全时全域自动化箱况检测信息识别,CIMCAI全球领先新一代集装箱管理方案,人工智能AI自动化箱信息识别箱况检测地…

Python 小型项目大全 21~25

二十一、DNA 可视化 原文:http://inventwithpython.com/bigbookpython/project21.html 脱氧核糖核酸是一种微小的分子,存在于我们身体的每个细胞中,包含着我们身体如何生长的蓝图。它看起来像一对核苷酸分子的双螺旋结构:鸟嘌呤、…

【跟着陈七一起学C语言】今天总结:C语言的函数相关知识

友情链接:专栏地址 知识总结顺序参考C Primer Plus(第六版)和谭浩强老师的C程序设计(第五版)等,内容以书中为标准,同时参考其它各类书籍以及优质文章,以至减少知识点上的错误&#x…

太阳能电池板AI视觉检测:不良品全程阻断,高效助力光伏扩产

2022年,面对复杂严峻的国内外形势,我国光伏行业依然实现高速增长,多晶硅、硅片、电池片、组件产量稳居全球首位。2023年以来,扩产项目已多点开花。光伏装机量天花板将不断提升,分布式电站占比也将逐年上升。中国光伏行…

4月软件测试面试太难,吃透这份软件测试面试笔记后,成功跳槽涨薪30K

4 月开始,生活工作渐渐步入正轨,但金三银四却没有往年顺利。昨天跟一位高级架构师的前辈聊天时,聊到今年的面试。有两个感受,一个是今年面邀的次数比往年要低不少,再一个就是很多面试者准备明显不足。不少候选人能力其…

python学籍管理系统

1,创建登陆的首页面,且封装起来。LoginPage.py import tkinter as tk#导入tk模块 from tkinter import messagebox#导入消息提示模块 from tkinter import messagebox from db import db #导入数据库db class LoginPage:#把整个登陆页面创建一个class类…

搭建自己的饥荒Don‘t Starve服务器-饥荒Don‘t Starve开服教程

前言 饥荒这个游戏,虽然首发于2016年,但是贵在好玩呀。和Minecraft一样,可玩性很高,并且有很多mods,最近和小伙伴玩的过程中,就想着搭建一个服务器,方便在主机玩家不在线时候,也可以…

Linux软件安装---Tomcat安装

安装Tomcat 操作步骤: 使用xftp上传工具将tomcat的 二进制发布包上传到Linux解压安装包,命令为tar -zxvf apache-tomcat*** -C /usr/local进入Tomcat的bin的启动目录,命令为sh startup.sh或者./startup.sh 验证Tomcat启动是否成功&#xff0…

LeetCode:376. 摆动序列——说什么贪心和动规~

🍎道阻且长,行则将至。🍓 🌻算法,不如说它是一种思考方式🍀算法专栏: 👉🏻123 一、🌱376. 摆动序列 题目描述:如果连续数字之间的差严格地在正数和…

Python 小型项目大全 46~50

# 四十六、百万骰子投掷统计模拟器 原文:http://inventwithpython.com/bigbookpython/project46.html 当你掷出两个六面骰子时,有 17%的机会掷出 7。这比掷出 2 的几率好得多:只有 3%。这是因为只有一种掷骰子的组合给你 2(当两个…