刘二大人《Pytorch深度学习实践》第八讲加载数据集

news2025/9/14 0:59:20

文章目录

  • Epoch、Batch-Size、Iterations
  • Dataset、DataLoader
  • 课上代码
  • torchvision中数据集的加载

Epoch、Batch-Size、Iterations

在这里插入图片描述
1、所有的训练集进行了一次前向和反向传播,叫做一个Epoch
2、在深度学习训练中,要给整个数据集分成多份,即mini-batch,每个mini-batch所包含的样本的数量叫做Batch-Size
3、因为数据集分成了多个mini-batch,有多少份mini-batch就有多少个Iteration,每进行一次mini-batch的前向和后向传播,就会进行一次权重参数的更新,在一个Epoch中,有多少个Iteration,就更新了多少次权重参数

Dataset、DataLoader

在这里插入图片描述
1、DataSet 是抽象类,不能实例化对象,需要自己定义类继承该抽象类并实现其中的方法
2、init()函数里面主要用来加载数据集,分成x_data,y_data
3、__getitem()__主要根据下表来获取数据集
4、len() 主要用来返回数据集的个数
5、DataLoader是Pytorch中用来处理模型输入数据的一个工具类。组合了数据集(dataset) + 采样器(sampler),并在数据集上提供单线程或多线程(num_workers )的可迭代对象。在DataLoader中有多个参数,这些参数中重要的几个参数的含义说明如下:

 1. epoch:所有的训练样本输入到模型中称为一个epoch; 
 2. iteration:一批样本输入到模型中,成为一个Iteration;
 3. batchszie:批大小,决定一个epoch有多少个Iteration;
 4. 迭代次数(iteration)=样本总数(epoch)/批尺寸(batchszie)
 5. dataset (Dataset) – 决定数据从哪读取或者从何读取;
 6. batch_size (python:int, optional) – 批尺寸(每次训练样本个数,默认为1)
 7. shuffle (bool, optional) –每一个 epoch是否为乱序 (default: False);
 8. num_workers (python:int, optional) – 是否多进程读取数据(默认为0);
 9. drop_last (bool, optional) – 当样本数不能被batchsize整除时,最后一批数据是否舍弃(default: False)
 10. pin_memory(bool, optional) - 如果为True会将数据放置到GPU上去(默认为false) 

在这里插入图片描述

课上代码

import torch 
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DiabetesDataset (Dataset):
  def __init__(self):
    xy = np.loadtxt('diabetes.csv', delimiter=',', dtype = np.float32)
    self.len = xy.shape[0]
    self.x_data = torch.from_numpy (xy[:,:-1])
    self.y_data = torch.from_numpy (xy[:,[-1]])
  
  def __getitem__(self, index):
    return self.x_data[index], self.y_data[index]
  
  def __len__(self):
    return self.len

dataset = DiabetesDataset()
train_loader = DataLoader (dataset=dataset, batch_size=32, shuffle=True, num_workers=0)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6) # 输入数据x的特征是8维,x有8个特征
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid() # 将其看作是网络的一层,而不是简单的函数使用
 
    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x)) # y hat
        return x
 
 
model = Model()

criterion = torch.nn.BCELoss(reduction='mean')  
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range (100):
  for i, data in enumerate (train_loader, 0):
    inputs, labels = data
    y_pred = model (inputs)
    loss = criterion (y_pred, labels)
    print (epoch, i, loss.item())

    optimizer.zero_grad()
    loss.backward ()
    optimizer.step()

torchvision中数据集的加载

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/412490.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【密码学】ElGamal加密算法原理 以及 例题讲解

目录前言1. 原理2. 例题2.1 例题一2.2 例题二前言 具体的性质: 非对称加密算法应用于一些技术标准中,如数字签名标准(DSS)、S/MIME 电子邮件标准算法定义在任何循环群 G 上,安全性取决于 G 上的离散对数难题 1. 原理…

元宇宙地产暴跌,林俊杰亏麻了

文/章鱼哥出品/陀螺财经随着元宇宙的兴起,元宇宙地产曾一度被寄予厚望,成为各大投资者追捧的对象。然而,最近的一次元宇宙地产价值暴跌再次提醒我们,高收益背后可能伴随着高风险。根据元宇宙分析平台WeMeta的数据显示,…

400以内的蓝牙耳机哪款好?400以内蓝牙耳机排行榜

谈起TWS,无论是传统的音频厂商还是手机厂商,都是其不可或缺的重要产品线,现在很多许多蓝牙耳机都不是千篇一律得形状,市场也鲜有商家在外观上下功夫,下面分享几款400元以内,内外兼具的耳机品牌。 一、南卡…

Spring boot+Vue3博客平台:修改密码与找回密码的设计与实现

修改密码与找回密码功能的设计与实现涉及到前后端的配合。本文将详细介绍如何通过设计思路、技术实现和代码示例实现这两个功能。 一、修改密码功能 设计思路 在设计修改密码功能时,需要注意以下几点: 用户输入的当前密码需要正确新密码需要满足一定的…

查询优化器:选择最优的查询路径

当我们通过解析器理解了SQL语句要干什么之后,接着会找查询优化器(Optimizer)来选择一个最优的查询路径。 可能有同学这里就不太理解什么是最优的查询路径了,这个看起来确实很抽象,当然,这个查询优化器的工…

总结819

学习目标: 4月(复习完高数18讲内容,背诵21篇短文,熟词僻义300词基础词) 第二周: 学习内容: 暴力英语:早上背诵《think different》记150词,默写了两篇文章&#xff0c…

BUUCTF-warmup_csaw_2016

1.checksec/file 64位的linux文件 2.ida 找到主函数 发现致命函数 get() 因为get可以无限输入 看看有没有什么函数我们可以返回的 双击进入sub_40060d 直接发现这个函数是取flag的 所以我们开始看这个函数的地址 所以函数地址是 0x40060d 我们看看get什么时候开始的 发现g…

Stable Diffusion复现——基于 Amazon SageMaker 搭建文本生成图像模型

众所周知,Stable Diffusion扩散模型的训练和推理非常消耗显卡资源,我之前也是因为资源原因一直没有复现成功。而最近我在网上搜索发现,亚马逊云科技最近推出了一个【云上探索实验室】刚好有复现Stable Diffusion的活动,其使用亚马…

超详细从入门到精通,pytest自动化测试框架实战-fixture高级进阶(十)

目录:导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜)前言 1、fixture的autous…

jQuery+AJAX技术(简单的用户注册功能)

目录1、jQuery是什么?2、AJAX是什么?3、jQuery与AJAX的关系?使用jQuery实现AJAX示例:4、jQueryAJAX技术实现用户注册验证功能。1、jQuery是什么? jQuery 是一个快速,小型且功能丰富的JavaScript库。它使 诸…

gradle编译项目报错Execution failed for task ‘:bootJar‘,‘:mainClass‘,‘:compileJava‘.

目录1.问题2.问题查找3.更多1.问题 idea导入Gradle管理的SpeingBoot多模块项目,依赖下载不下来,执行编译报错 报错信息: 2.问题查找 首先怀疑是不是idea的版本与gradle版本冲突,我用的是idea2022.3.3,gradle是7.5.…

做一个内心强大的人

想想类似如下这种心灵鸡汤,本不太愿意在这个平台发布,但是偶尔喝点又何尝不可! 语录摘抄/分享: 1、你开始炫耀自己,往往都是灾难的开始,就像老子在《道德经》里写到: 光而不耀,静水流深。 2、…

汽车电子相关术语介绍

一、相关术语介绍 1、汽车OTA 全称“Over-The-Air technology ”,即空中下载技术,通过移动通信的接口实现对软件进行远程管理,传统的做法到4S店通过整车OBD对相应的ECU进行软件升级。OTA技术最早2000年出现在日本,目前通过OTA方式…

HashMap源码解析超详细

HashMap源码详解1、概述2、源码解析1.HashMap底层存储结构问题一: 为什么直接就用数组呢?问题二:什么是红黑树呢?问题三:为什么不一下子把整个链表变为红黑树呢?2.HashMap的重要成员变量3.构造方法解析4.Put方法解析取…

渗透测试工具库-收藏版

1.前言 浩二一开始做渗透测试的时候收集超多的资料和工具,一直在文档里吃灰。今天全部放出来分享给大家,需要的自己收藏。 2.漏洞练习平台 WebGoat漏洞练习平台: https://github.com/WebGoat/WebGoat webgoat-legacy漏洞练习平台: http…

好程序员:Java书籍推荐,程序员必看的5本Java书籍,赶紧收藏!

今天好程序员给大家推荐5本Java书籍,各大高校都在使用(具体名单如下),所有学习Java的程序员都不应该错过! 第一本Java书籍《Java EE(SSM框架)企业应用实战》 本书全面介绍了JavaEE中MyBatis、Sp…

ChatGLM-6B论文代码笔记

ChatGLM-6B 文章目录ChatGLM-6B前言一、原理1.1 优势1.2 实验1.3 特点:1.4 相关知识点二、实验2.1 环境基础2.2 构建环境2.3 安装依赖2.4 运行2.5 数据2.6 构建前端页面3 总结前言 Github:https://github.com/THUDM/ChatGLM-6B 参考链接: ht…

“QT 快速上手指南“ 之 计算器(二)

文章目录前言一、QT 基本组件用法介绍:1. QLabel :2. QPushButton :3. QLineEdit:二、坐标系统三、窗口部件的大小设置1. setSize( ) 函数:2. resize( )函数:3. setFixedSize( )函数:4. setFixedWidth( ) 和 setFixedHeight( )函数…

大一被忽悠进了培训班

大家好,我是帅地。 最近我的知识星球开始营业,不少大一大二的小伙伴也是纷纷加入了星球,并且咨询的问题也是五花八门,反正就是,各种迷茫,其中有一个学弟,才大一,就报考培训班&#…

命令注入概述

概述命令注入即 Command Injection。是指在开发需求中,需要调用一些系统的命令来完成某些特定的功能。当未对用户输入的参数进行严格的过滤时,则有可能发生命令注入。攻击者可以通过提交恶意构造的参数破坏命令语句结构,从而达到执行恶意命令…