数据增强方法汇总

news2025/8/12 9:51:19

数据增强

  • 1.有监督数据增强
    • 1.1 单样本数据增强
      • augly安装
      • augly使用方法
    • 1.2 多样本数据增强
      • 1.2.1 SMOTE
        • python实现
      • 1.2.2 SamplePairing
        • python实现
      • 1.2.3 mixup
        • python实现
  • 2.无监督数据增强
    • 2.1 GAN
    • 2.2 Diffunsion
    • 2.3 Autoaugmentation

1.有监督数据增强

1.1 单样本数据增强

augly安装

AugLy是一个数据增强库,目前支持四种模式(音频、图像、文本和视频)和100多种增强。每个模态的增强包含在自己的子库中。这些子库包括基于函数和基于类的变换、组合运算符,并可以选择提供有关所应用转换的元数据,包括其强度。

在这里插入图片描述
该库基于Python,至少需要Python 3.6+版本、

官网地址 :https://github.com/facebookresearch/AugLy
安装该库方法

pip install augly[all]

也可以只安装某一个应用,比如只安装音频

pip install augly[audio]

也可以克隆git

git clone git@github.com:facebookresearch/AugLy.git && cd AugLy
[Optional, but recommended] conda create -n augly && conda activate augly && conda install pip
pip install -e .[all]

augly使用方法

augly.image所有函数都接受要作为输入增强的图像或PIL图像对象的路径,并返回增强的PIL图像对象。如果指定了输出路径,图像也将保存到文件中。

import augly.image as imaugs

image_path = "your_img_path.png"
output_path = "your_output_path.png"

aug_image = imaugs.overlay_emoji(image_path, opacity=1.0, emoji_size=0.15)

#增强功能也可以接受PIL图像作为输入
aug_image = imaugs.pad_square(aug_image)

#如果指定了输出路径,图像也将保存到文件中
aug_image = imaugs.overlay_onto_screenshot(aug_image, output_path=output_path)

🎵🕺🗣🏀使用augly与pytorch transforms集成

import torchvision.transforms as transforms
import augly.image as imaugs

COLOR_JITTER_PARAMS = {
    "brightness_factor": 1.2,
    "contrast_factor": 1.2,
    "saturation_factor": 1.4,
}

AUGMENTATIONS = [
    imaugs.Blur(),
    imaugs.ColorJitter(**COLOR_JITTER_PARAMS),
    imaugs.OneOf(
        [imaugs.OverlayOntoScreenshot(), imaugs.OverlayEmoji(), imaugs.OverlayText()]
    ),
]

#一种是应用到PIL图像
image = Image.open("your_img_path.png")
TRANSFORMS = imaugs.Compose(AUGMENTATIONS)
aug_image = TRANSFORMS(image)

#也可以应用到tensor张量
TENSOR_TRANSFORMS = transforms.Compose(AUGMENTATIONS + [transforms.ToTensor()])
aug_tensor_image = TENSOR_TRANSFORMS(image)

增强函数方法很多 也可以查看文档
https://augly.readthedocs.io/en/latest/augly.image.html

1.2 多样本数据增强

1.2.1 SMOTE

SMOTE全称“Synthetic Minority Over-sampling Technique”
过度采样技术

主要是针对不平衡的数据集使用。可以采用随机欠采样或过采样技术,对数据集进行调整,使其趋于均衡。

原理
它是基于随机过采样算法的一种改进方案,由于随机过采样采取简单复制样本的策略来增加少数类样本,这样容易产生模型过拟合的问题,即使得模型学习到的信息过于特别(Specific)而不够泛化(General),算法的基本思想是对少数类样本进行分析并根据少数类样本人工合成新样本添加到数据集中,具体如下图所示:

在这里插入图片描述

算法流程

  1. 将每一个图像看成是一个特征空间,将特征映射到空间上类似于聚类算法将各数据点展示出来
  2. 对于小样本的图像 ( x , y ) (x,y) (x,y),按欧式距离找出k个最近邻的样本,从中随机选择一个样本点,选择为近邻点 ( x n , y n ) (x_n,y_n) (xn,yn)
  3. 将特征空间中的样本点与最近邻的样本点连线。随机选取线上一点作为新的样本(这里的特征空间样本点,也可以是样本的中心聚类中心这种)
  4. 重复1,2,3直到数据集平衡

算法公式
( x n e w , y n e w = ( x , y ) + r a n d ( 0 , 1 ) ∗ ( ( x n − x ) , ( y n − y ) ) (x_{new},y_{new}=(x,y)+rand(0,1)*((x_n -x),(y_n-y)) (xnew,ynew=(x,y)+rand(0,1)((xnx),(yny))

python实现

安装

pip install imbalanced-learn

我们将通过一个不平衡的二分类问题来实现SMOTE算法

from collections import Counter
from sklearn.datasets import make_classification
from matplotlib import pyplot
from numpy import where
# 定义数据集
X, y = make_classification(n_samples=10000, n_features=2, n_redundant=0,
	n_clusters_per_class=1, weights=[0.99], flip_y=0, random_state=1)
# 汇总类分布
counter = Counter(y)
print(counter)
# 绘制散点图
for label, _ in counter.items():
	row_ix = where(y == label)[0]
	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

Counter输出分类汇总({0: 9900, 1: 100})
在这里插入图片描述
接下来,使用SMOTE对少数类进行过采样,并绘制转换后的数据集。

oversample = SMOTE()
X, y = oversample.fit_resample(X, y)
counter = Counter(y)
print(counter)
for label, _ in counter.items():
	row_ix = where(y == label)[0]
	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

Counter输出分类汇总({0: 9900, 1: 9900})
在这里插入图片描述

1.2.2 SamplePairing

Sample是通过样本配对的方法实现数据增强,顾名思义使用几张样本图像,进行图像组合融合来实现数据增强的功能

Sample Pairing 采用最简单的方法混合两个图像;取两张图片的平均值。它只使用第一个标签并丢弃第二个标签

python实现

class SamplePairing(object):
    def __init__(self, train_data_dir, p=1):
        self.train_data_dir = train_data_dir
        self.pool = self.load_dataset()
        self.p = p

    def load_dataset(self):
        dataset_train_raw = vdatasets.ImageFolder(self.train_data_dir)
        return dataset_train_raw

    def __call__(self, img):
        toss = np.random.choice([1, 0], p=[self.p, 1 - self.p])

        if toss:
            img_a_array = np.asarray(img)

            # pick one image from the pool
            img_b, _ = random.choice(self.pool)
            img_b_array = np.asarray(img_b.resize((197, 197)))

            # mix two images
            mean_img = np.mean([img_a_array, img_b_array], axis=0)
            img = Image.fromarray(np.uint8(mean_img))
            
            # could have used PIL.Image.blend

        return img

1.2.3 mixup

mixup也是将两个样本图像配对混合在一起,创建新的图像样本

算法原理,
假设数据和标签的配对 ( x 1 , y 1 ) , ( x 2 , y 2 ) (x1,y1),(x2,y2) (x1,y1),(x2,y2)
新的训练样本。 ( x , y ) (x,y) (x,y) 制作在这里标签 y 1 , y 2 y1,y2 y1y2

假设是one-hot表示的向量。 x 1 , x 2 x1,x2 x1,x2是任意的向量或张量

x = λ x 1 + ( 1 − λ ) x 2 x=\lambda x_1+(1−\lambda)x_2 x=λx1+(1λ)x2 y = λ y 1 + ( 1 − λ ) y 2 y=\lambda y_1+(1−\lambda)y_2 y=λy1+(1λ)y2
在这里 λ ∈ [ 0 , 1 ] \lambda \in[0,1] λ[0,1]

python实现

import numpy as np


class MixupGenerator():
    def __init__(self, X_train, y_train, batch_size=32, alpha=0.2, shuffle=True, datagen=None):
        self.X_train = X_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.alpha = alpha
        self.shuffle = shuffle
        self.sample_num = len(X_train)
        self.datagen = datagen

    def __call__(self):
        while True:
            indexes = self.__get_exploration_order()
            itr_num = int(len(indexes) // (self.batch_size * 2))

            for i in range(itr_num):
                batch_ids = indexes[i * self.batch_size * 2:(i + 1) * self.batch_size * 2]
                X, y = self.__data_generation(batch_ids)

                yield X, y

    def __get_exploration_order(self):
        indexes = np.arange(self.sample_num)

        if self.shuffle:
            np.random.shuffle(indexes)

        return indexes

    def __data_generation(self, batch_ids):
        _, h, w, c = self.X_train.shape
        _, class_num = self.y_train.shape
        X1 = self.X_train[batch_ids[:self.batch_size]]
        X2 = self.X_train[batch_ids[self.batch_size:]]
        y1 = self.y_train[batch_ids[:self.batch_size]]
        y2 = self.y_train[batch_ids[self.batch_size:]]
        l = np.random.beta(self.alpha, self.alpha, self.batch_size)
        X_l = l.reshape(self.batch_size, 1, 1, 1)
        y_l = l.reshape(self.batch_size, 1)

        X = X1 * X_l + X2 * (1 - X_l)
        y = y1 * y_l + y2 * (1 - y_l)

        if self.datagen:
            for i in range(self.batch_size):
                X[i] = self.datagen.random_transform(X[i])

        return X, y

2.无监督数据增强

2.1 GAN

最常见的就是GAN,其原理可以阅读鄙人的这篇文章
Generative Model - 李宏毅笔记

2.2 Diffunsion

还有以及现在最火的ai 扩散算法,原理也可以阅读下面这篇文章

Diffusion Model算法

2.3 Autoaugmentation

AutoAugment是Google提出的自动选择最优数据增强方案的研究,这是无监督数据增强的重要研究方向。它的基本思路是使用增强学习从数据本身寻找最佳图像变换策略

在这里插入图片描述

例如指导基本图像转换操作的选择,例如水平/垂直翻转图像、旋转图像、更改图像颜色等。AutoAugment不仅预测要组合哪些图像转换,还可以预测所用变换的每个图像概率和大小,因此图像并不总是以相同的方式进行操作。AutoAugment能够大量的图像转换可能性的搜索空间中选择最佳策略。

AutoAugment根据运行的数据集学习不同的转换。例如,对于涉及房屋号码街景(SVHN)的图像,包括数字的自然场景图像,AutoAugment专注于剪切和平移等几何变换,这些变换代表了该数据集中常见的失真。此外,鉴于世界上不同建筑和房屋编号材料的多样性,AutoAugment已经学会了完全反转原始SVHN数据集中自然出现的颜色。
在这里插入图片描述
简单说就是AutoAugment,自己不需要去设定需要增强的元素,算法会自动结合所有的增强策略选择最优的策略进行输出,输出一个增强后的图像

DeepAugment是一个专注于数据增强的AutoML工具。它利用贝叶斯优化来发现您的图像数据集的特点,并量身定制的数据增强策略。DeepAugment的主要优势和特点是:

  • 降低CNN模型的错误率
  • 通过自动化流程节省时间
  • 比谷歌之前的解决方案快50倍-AutoAugment

安装

pip install deepaugment

使用方法

from deepaugment.deepaugment import DeepAugment

deepaug = DeepAugment(my_images,my_labels)

best_policies = deepaug.optimize(300)

通过配置DeepAugment,实现增强用法

from keras.datasets import cifar10

# my configuration
my_config = {
    "model": "basiccnn",
    "method": "bayesian_optimization",
    "train_set_size": 2000,
    "opt_samples": 3,
    "opt_last_n_epochs": 3,
    "opt_initial_points": 10,
    "child_epochs": 50,
    "child_first_train_epochs": 0,
    "child_batch_size": 64
}

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# X_train.shape -> (N, M, M, 3)
# y_train.shape -> (N)
deepaug = DeepAugment(x_train, y_train, config=my_config)

best_policies = deepaug.optimize(300)

CIFAR-10在WRN-28-10上测试的最佳政策

方法:Wide-ResNet-28-10通过最佳策略使用CIFAR-10增强图像和未增强图像(其他一切相同)进行训练。

结果:DeepAugment的误差减少60%(准确率提高8.5%)

在这里插入图片描述
在这里插入图片描述

有关更详细的安装/使用信息 参考官方地址
https://github.com/barisozmen/deepaugment

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/15158.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MongoDB(4.0.9)数据从win迁移到linux

服务器从win迁移导了linux上了,对应的md里面的数据也需要做全量迁移,在网上找了一大堆方案,不是缺胳膊就是少腿,没有一个是完整的,最终加以分析和整理,得出这套方案,希望对你有用 第一步&#…

Java集合框架【二容器(Collection)[Vector容器类]】

文章目录三 Vector容器类3.5 Vector容器类3.5.1 Vector的使用3.5.2 Stack容器3.5.3.1 Stack容器介绍3.5.3.2 操作栈方法Stack的使用案例三 Vector容器类 3.5 Vector容器类 Vector底层是用数组实现的,相关的方法都加了同步检查,因此“线程安全&#xff…

D. Divide and Sum(组合数学)

Problem - 1445D - Codeforces 题意: 给你一个长度为2n的数组a。考虑将数组a划分为两个子序列p和q,每个子序列的长度为n(数组a的每个元素应该正好在一个子序列中:要么在p中,要么在q中)。 让我们以非递减顺序对p进行排…

matplotlib笔记

一、安装matplotlib总是超时导致失败 鉴于公司内网服务器上直接pip install matplotlib容易超时退出的问题,可以采用下面的方法解决: 方法一:指定更新源 pip install -i Simple Index matplotlib3.2.2 注意选择3.2.2,因为最新版本…

AP22615AWU-7、SLG5NT1758V配电开关 驱动器 IC资料

AP22615配电开关具有输出过压保护 (OVP) 功能,设计用于USB和其他热插拔应用。该器件提供输出过压保护,可保护这些应用的系统。具有输出过压保护、反向电流阻断、过流、过热和短路保护功能。其他功能包括受控上升时间和欠压锁定功能。 AP22615具有可调限…

【Java篇】备战面试——你真的了解“基本数据类型”吗?

目录 基本介绍: 整数类型 浮点类型 布尔类型和char类型 自动类型转换 数据类型转换必须满足如下规则: 基本介绍: Java是一门强类型语言,这就意味着必须为每一个变量声明一种类型。Java为我们提供了八种基本类…

[附源码]计算机毕业设计JAVA归元种子销售管理系统

[附源码]计算机毕业设计JAVA归元种子销售管理系统 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM my…

【毕业设计】大数据分析的航空公司客户价值分析 - python

文章目录0 前言1 数据分析背景2 分析策略2.1 航空公司客户价值分析的LRFMC模型2.2 数据2.3 分析模型3 开始分析3.1 数据预处理3.1.1 数据预览3.1.2 数据清洗3.2 变量构建3.3 建模分析4 数据分析结论4.1 整体结论4.2 重要保持客户4.3 重要挽留客户4.4 一般客户与低价值客户5 最后…

Cadence Allegro PCB设计88问解析(十七) 之 Allegro中焊盘的全连接和花焊盘

一个学习信号完整性仿真的layout工程师 上一篇文章和大家分享了关于铜皮shape的一些基本操作。我们进行铺铜是为了连接网络(焊盘、过孔等),一般都是GND或者电源网络。Shape和走线还是不一样的,走线直接从焊盘或者过孔等直接拉出一根layout,但…

【MySQL数据库笔记 - 进阶篇】(三)SQL优化

✍个人博客:https://blog.csdn.net/Newin2020?spm1011.2415.3001.5343 📚专栏地址:暂定 📝视频地址:黑马程序员 MySQL数据库入门到精通 📣专栏定位:这个专栏我将会整理 B 站黑马程序员的 MySQL…

LeetCode[剑指Offer54]二叉搜索树的第K大节点

难度:简单 题目: 给定一棵二叉搜索树,请找出其中第 k 大的节点的值。 示例 1: 输入: root [3,1,4,null,2], k 13/ \1 4\2 输出: 4 示例 2: 输入: root [5,3,6,2,4,null,null,1], k 35/ \3 6/ \2 4/1 输出: 4 限制: …

[附源码]java毕业设计旅游网站

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

Nodejs编写接口

编写接口 1.自定义路由模块 const expressrequire(express) const routerexpress.Router()// 挂载对应的路由 router.get(/get,(req,res)>{// 通过req.query获取客户端通过查询字符串,发送到服务器的数据const queryreq.query// 调用res.send()方法&#xff0c…

集成学习-Bagging和Boosting算法

文章目录集成学习Bagging随机森林BostingAdaboostGBDTXGBoost集成学习 集成学习(ensemble learning)博采众家之长,通过构建并结合多个学习器来完成学习任务。“三个臭皮匠顶个诸葛亮”,一个学习器(分类器、回归器&…

【微服务】SpringCloud微服务续约源码解析

目录 一、前言 二、客户端续约 1、入口 1.1、构造初始化 1.2、initScheduledTasks() 调度执行心跳任务 2、TimedSupervisorTask组件 2.1、构造初始化 2.2、TimedSupervisorTask#run()任务逻辑 3、心跳任务 3.1、HeartbeatThread私有内部类 3.2、发送心跳 3、发送心…

使用OpenAPI提升网关安全的开源软件,诚邀小伙伴参与

看过我博客的人都知道,我们是一家推广OpenAPI的企业。 OpenAPI是一种用于定义API结构的规范,在Java里我们可以使用swagger进行自动生成。其他语言也可以(Golang等)。通过这种对开发人员零成本的工具,我们可以高效的获…

典型的偏微分方程数值解法

马上要参加亚太杯啦,听说今年亚太杯有经典的物理题,没什么好说的,盘它! 偏微分方程的数值解十分重要 椭圆型偏微分方程(不含时) 数值解法 二维拉普拉斯方程 例 边界条件 import numpy as np import mat…

教你如何使用云服务器搭建我的世界Minecraft服务器(超级简单-10分钟完成)

一个人玩游戏没啥意思,和朋友一块联机呢,距离太远,家庭局域网宽带又没有公网ip,你的朋友没办法与你联机,然而你只需要一台服务器即可搞定了;但是很多用户没没接触过相关的内容,具体的该怎么操作…

怎样做音乐相册怎样制作?手把手教你制作

大家平时出门游玩的时候,会拍摄一些好看的照片吗?那你们会将这些照片分享在社交平台上吗?普通的照片分享,有时会显得比较枯燥单调,其实我们可以将这些照片制作成音乐相册,这样就可以丰富照片的内容&#xf…

传输层-用户数据报协议(UDP)

UDP协议概述 用户数据报协议 UDP 是 Internet 传输层协议,提供无连接、不可靠、数据报尽力传输服务。 无连接:因此在支持两个进程间通信时,没有握手过程。不可靠:当应用进程将一个报文发送近 UDP 套接字时,UDP 并不能…