Pytorch中预置数据集的加载方式

news2025/5/13 6:04:47

Pytorch中数据集加载方式

数据类型PyTorch 模块是否预置数据集
图像/视频torchvision.datasets✅ 是
音频torchaudio.datasets✅ 是
文本torchtext.datasets✅ 是(需安装)
自定义数据torch.utils.data❌ 否(需手动实现)
多模态/第三方数据Hugging Face 等✅ 是

深度学习任务中常用数据集来源 

        上表是我们在深度学习任务中常用的数据集来源,这里我们以torchvision为例,来说明预置数据集的加载方式。

torchvision内所有数据集 

torchvision加载数据集模版

CIFER-10数据集

CIFAR-10数据集共有60000个样本,每个样本都是一张32*32像素的RGB图像(彩色图像),每个RGB图像又必定分为3个通道(R通道、G通道、B通道)。这60000个样本被分成了50000个训练样本和10000个测试样本。

Cifer-10数据集概览图(来自官网)

加载数据集

'''
root:数据集在你的电脑上的路径,若download=True,root需设置为 './data'
train:训练集还是测试集,train=True,此时加载的是训练集,train=False,此时加载的是测试集
download:是否下载数据集,倘若你手头上没有数据集又想直接训练模型,那么需要将download设置为True
并且将root设置为',/data',因为预置数据集会默认下载到当前工作区域文件夹下,且名称为data
transform:数据预处理函数
'''
import torchvision
#训练集
trainset=torchvision.datasets.数据集名称(root='./data', train=True, download=True, 
transform=transform_train)
#测试集
testset=torchvision.datasets.数据集名称(root='./data', train=False, download=True, transform=transform_test)

 数据预处理

        torchvision中我们常用transforms模块来对图像数据进行预处理,其中ToTensor()函数是必须要有的(只有将图像数据转换为Tensor才能训练),其余操作可以自行添加,这里我使用了RandomCrop和RandomHorizontalFlip这两个操作

        RandomCrop会将图像随机裁剪为32×32像素,先填充4像素的边界(四周各填充4像素,图像变为40×40),再随机裁剪回32×32。

        RandomHorizontalFlip则是按照50%的概率水平翻转图像。

        经过这两个操作数据的多样性会增强,最后为了加快模型训练速度提升训练进度,一般而言我们都是需要对数据进行标准化处理的,Normalize函数传入的参数为均值mean与方差std

        这里,我们的数据是32x32的jpg图像,有三个通道,因此均值与方差都是三元组,这里的均值与方差的数据是CIFAR-10数据集的统计值。

DataLoder加载数据

        DataLoder用于高效加载和预处理数据,并将其批量提供给模型进行训练或推理。 

from torch.utils.data import DataLoader
'''
dataset:数据集,经过torchvision.datasets.数据集名称()后的训练集或数据集
batch_size:每个批次的样本数。
shuffle:是否是否在每个epoch开始时打乱数据顺序。
num_workers:用于数据加载的子进程数(多进程加速)。可以设置为CPU核心数-1(如4或8)
'''
trainloader=DataLoader(dataset=trainset, batch_size=128, shuffle=True, num_workers=8)
testloader=DataLoader(dataset=testset, batch_size=100, shuffle=False, num_workers=8)

 完整代码

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
'''数据预处理'''
transform_train=transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
'''加载CIFAR-10数据集'''
################################################################
#训练集
trainset=torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader=DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
################################################################################
#测试集
testset=torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader=DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
##############################################################################

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2328764.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ARM-----数据处理、异常处理、模式切换

实列一: 1. 异常向量表 area reset, code, readonly code32 entry area reset, code, readonly:定义一个名为reset的代码区域,只读。 code32:指示编译器生成32位ARM指令。 entry:标记程序的入口点。 2. 程序入口…

d202541

目录 一、分隔链表 二、旋转链表 三、删除链表中重复的数字 一、分隔链表 用两个list存一下小于和大于等于 x的节点 最后串起来就行 public ListNode partition(ListNode head, int x) {ListNode ret new ListNode(1);ListNode cur ret;List<ListNode> small new A…

YOLOv12 从预训练迈向自主训练,第一步数据准备

视频讲解&#xff1a; YOLOv12 从预训练迈向自主训练&#xff0c;第一步数据准备 前面复现过yolov12&#xff0c;使用pre-trained的模型进行过测试&#xff0c;今天来讲下如何训练自己的模型&#xff0c;第一步先准备数据和训练格式 https://gitcode.com/open-source-toolkit/…

【UVM学习笔记】更加灵活的UVM—通信

系列文章目录 【UVM学习笔记】UVM基础—一文告诉你UVM的组成部分 【UVM学习笔记】UVM中的“类” 文章目录 系列文章目录前言一、TLM是什么&#xff1f;二、put操作2.1、建立PORT和EXPORT的连接2.2 IMP组件 三、get操作四、transport端口五、nonblocking端口六、analysis端口七…

NSSCTF [HGAME 2023 week1]simple_shellcode

3488.[HGAME 2023 week1]simple_shellcode 手写read函数shellcode和orw [HGAME 2023 week1]simple_shellcode (1) motalymotaly-VMware-Virtual-Platform:~/桌面$ file vuln vuln: ELF 64-bit LSB pie executable, x86-64, version 1 (SYSV), dynamically linked, interpret…

数据集(Dataset)和数据加载器(DataLoader)-pytroch学习3

pytorch网站学习 处理数据样本的代码往往会变得很乱、难以维护&#xff1b;理想情况下&#xff0c;我们希望把数据部分的代码和模型训练部分分开写&#xff0c;这样更容易阅读、也更好维护。 简单说&#xff1a;数据和模型最好“分工明确”&#xff0c;不要写在一起。 PyTor…

数据结构|排序算法(一)快速排序

一、排序概念 排序是数据结构中的一个重要概念&#xff0c;它是指将一组数据元素按照特定的顺序进行排列的过程&#xff0c;默认是从小到大排序。 常见的八大排序算法&#xff1a; 插入排序、希尔排序、冒泡排序、快速排序、选择排序、堆排序、归并排序、基数排序 二、快速…

文件或目录损坏且无法读取:数据恢复的实战指南

在数字化时代&#xff0c;数据的重要性不言而喻。然而&#xff0c;在日常使用电脑、移动硬盘、U盘等存储设备时&#xff0c;我们难免会遇到“文件或目录损坏且无法读取”的提示。这一提示如同晴天霹雳&#xff0c;让无数用户心急如焚&#xff0c;尤其是当这些文件中存储着重要的…

leetcode数组-螺旋矩阵Ⅱ

题目 题目链接&#xff1a;https://leetcode.cn/problems/spiral-matrix-ii/ 给你一个正整数 n &#xff0c;生成一个包含 1 到 n2 所有元素&#xff0c;且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix 。 输入&#xff1a;n 3 输出&#xff1a;[[1,2,3],[8,9,4],[7…

小刚说C语言刷题——第14讲 逻辑运算符

当我们需要将一个表达式取反&#xff0c;或者要判断两个表达式组成的大的表达式的结果时&#xff0c;要用到逻辑运算符。 1.逻辑运算符的分类 (1)逻辑非(!) &#xff01;a&#xff0c;当a为真时&#xff0c;&#xff01;a为假。当a为假时&#xff0c;&#xff01;a为真。 例…

WPS宏开发手册——Excel实战

目录 系列文章5、Excel实战使用for循环给10*10的表格填充行列之和使用for循环将10*10表格中的偶数值提取到另一个sheet页使用for循环给写一个99乘法表按市场成员名称分类&#xff08;即市场成员A、B、C...&#xff09;&#xff0c;统计月内不同时间段表1和表2的乘积之和&#x…

【Cursor】切换主题

右键顶部&#xff0c;把菜单栏勾上 首选项-主题-颜色主题 选择和喜欢的颜色主题即可&#xff0c;一般是“现代深色”

spring druid项目中监控sql执行情况

场景 在 Spring Boot 结合 MyBatis 的服务中&#xff0c;实现 SQL 执行覆盖情况的监控&#xff0c;可以基于Druid提供的内置的 SQL 监控统计功能。 开启监控 在 application.yml 中启用 Druid 的 stat 和 wall 过滤器&#xff0c;并配置监控页面的访问权限 …

Obsidian按下三个横线不能出现文档属性

解决方案: 需要在标题下方的一行, 按下 键盘数字0后面那个横线(英文横线), 然后回车就可以了 然后点击横线即可

pyqt SQL Server 数据库查询-优化2

1、增加导出数据功能 2、增加删除表里数据功能 import sys import pyodbc from PyQt6.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QListWidget, QLineEdit, QPushButton, \QTableWidget, QTableWidgetItem, QLabel, QMessageBox from PyQt6.QtGui i…

Hyperlane:高性能 Rust HTTP 服务器框架评测

Hyperlane&#xff1a;高性能 Rust HTTP 服务器框架评测 在当今快速发展的互联网时代&#xff0c;选择一个高效、可靠的 HTTP 服务器框架对于开发者来说至关重要。最近&#xff0c;我在评估各种服务器框架性能时&#xff0c;发现了一个名为 Hyperlane 的 Rust HTTP 服务器库&a…

Laravel 中使用 JWT 作用户登录,身份认证

什么是JWT&#xff1a; JWT 全名 JSON Web Token&#xff0c;是一种开放标准 (RFC 7519)。 用于在网络应用环境间安全地传输信息作为 JSON 对象。 它是一种轻量级的认证和授权机制&#xff0c;特别适合分布式系统的身份验证。 核心特点 紧凑格式&#xff1a;体积小&#x…

VBA中类的解读及应用第二十二讲:利用类判断任意单元格的类型-5

《VBA中类的解读及应用》教程【10165646】是我推出的第五套教程&#xff0c;目前已经是第一版修订了。这套教程定位于最高级&#xff0c;是学完初级&#xff0c;中级后的教程。 类&#xff0c;是非常抽象的&#xff0c;更具研究的价值。随着我们学习、应用VBA的深入&#xff0…

STM32F103_LL库+寄存器学习笔记13 - 梳理外设CAN与如何发送CAN报文(串行发送)

导言 CAN总线因其高速稳定的数据传输与卓越抗干扰性能&#xff0c;在汽车、机器人及工业自动化中被广泛应用。它采用分布式网络结构&#xff0c;实现多节点间实时通信&#xff0c;确保各控制模块精准协同。在汽车领域&#xff0c;CAN总线连接发动机、制动、车身系统&#xff0c…