小白的进阶之路系列之十二----人工智能从初步到精通pytorch综合运用的讲解第五部分

news2025/6/6 8:56:00

在本笔记本中,我们将针对Fashion-MNIST数据集训练LeNet-5的变体。Fashion-MNIST是一组描绘各种服装的图像瓦片,有十个类别标签表明所描绘的服装类型。

# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms

# Image display
import matplotlib.pyplot as plt
import numpy as np

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

# In case you are using an environment that has TensorFlow installed,
# such as Google Colab, uncomment the following code to avoid
# a bug with saving embeddings to your TensorBoard directory

# import tensorflow as tf
# import tensorboard as tb
# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

在TensorBoard中显示图像

让我们首先将数据集中的样本图像添加到TensorBoard:

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

if __name__ == '__main__':
    
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    # Store separate training and validations splits in ./data
    training_set = torchvision.datasets.FashionMNIST('./data',
        download=True,
        train=True,
        transform=transform)
    validation_set = torchvision.datasets.FashionMNIST('./data',
        download=True,
        train=False,
        transform=transform)

    training_loader = torch.utils.data.DataLoader(training_set,
                                                  batch_size=4,
                                                  shuffle=True,
                                                  num_workers=2)


    validation_loader = torch.utils.data.DataLoader(validation_set,
                                                    batch_size=4,
                                                    shuffle=False,
                                                    num_workers=2)

    # Class labels
    classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')



    # Extract a batch of 4 images
    dataiter = iter(training_loader)
    images, labels = next(dataiter)

    # Create a grid from the images and show them
    img_grid = torchvision.utils.make_grid(images)
    matplotlib_imshow(img_grid, one_channel=True)
    plt.show()

输出为:

在这里插入图片描述

上面,我们使用TorchVision和Matplotlib创建了一个小批量输入数据的视觉网格。下面,我们在SummaryWriter上使用add_image()调用来记录TensorBoard使用的图像,并且我们还调用flush())来确保它立即写入磁盘。

    # Default log_dir argument is "runs" - but it's good to be specific
    # torch.utils.tensorboard.SummaryWriter is imported above
    writer = SummaryWriter('runs/fashion_mnist_experiment_1')

    # Write image data to TensorBoard log dir
    writer.add_image('Four Fashion-MNIST Images', img_grid)
    writer.flush()

    # To view, start TensorBoard on the command line with:
    #   tensorboard --logdir=runs
    # ...and open a browser tab to http://localhost:6006/

如果您在命令行启动TensorBoard并在新的浏览器选项卡中打开它(通常在localhost:6006),您应该在IMAGES选项卡下看到图像网格。

绘制标量以可视化训练

TensorBoard对于跟踪您的训练进度和效果非常有用。下面,我们将运行一个训练循环,跟踪一些指标,并保存数据供TensorBoard使用。

让我们定义一个模型来对图像块进行分类,以及一个用于训练的优化器和损失函数:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 4 * 4, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(sel

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2398645.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2025年06月03日Github流行趋势

项目名称:onlook 项目地址url:https://github.com/onlook-dev/onlook项目语言:TypeScript历史star数:12871今日star数:624项目维护者:Kitenite, drfarrell, spartan-vutrannguyen, apps/devin-ai-integrati…

【数据分析】基于Cox模型的R语言实现生存分析与生物标志物风险评估

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载导入数据数据预处理生存分析画图输出图片其他标记物的分析总结系统信息介绍 分析生存数据与多种生物标志物之间的关系。它通过Cox比例风险模型来评估不同生物标志物…

使用nginx配置反向代理,负载均衡

首先啥叫反向代理 咋配置呢,那当然是在nginx目录下改conf文件了 具体咋改呢,那就新增一个新的server配置,然后在location里新增你想代理的服务器 实际上负载均衡也就是根据反向代理的思路来的,如下所示 配置的话实际上也与上…

从 iPhone 备份照片: 保存iPhone图片的5种方法

随着智能手机越来越融入我们的生活,我们的照片已成为我们设备上最有价值的数据形式之一。然而,iPhone内部存储空间仍然有限,因此我们需要将iPhone中的照片备份到另一个地方,以释放空间并确保珍贵的图像记忆的安全。阅读本指南&…

Spring Ai 从Demo到搭建套壳项目(一)初识与实现与deepseek对话模式

前言 为什么说Java长青,主要是因为其生态圈完善,Spring又做了一款脚手架,把对接各个LLM厂商的sdk做了一遍,形成一系列的spring-ai-starter-** 的依赖。 目前为止版本去到1.0.0.M6,golang跟不上了吧, Make …

快速上手pytest

1. pytest的基础 1.1 什么是pytest pytest 是 python 中的一个测试框架,用于编写简洁、可扩展的测试代码,帮助开发者验证结果是否与预期相符。 python 中有很多的测试框架,那我们为什么要学习 pytest 呢? pytest 的优势&…

设备驱动与文件系统:02 键盘

操作系统中键盘驱动的讲解 在这一讲中,我将为大家讲解键盘相关内容。从上一讲开始,我们进入了操作系统第四个部分的学习,也就是操作系统对设备的驱动与管理。 上一讲我们探讨的是显示器,并且提到,一个终端设备是由显示…

交通违法拍照数据集,可识别接打电话,不系安全带的行为,支持YOLO,COCO JSON,VOC XML格式的标注数据集 最高正确识别率可达88.6%

交通违法拍照数据集 数据集概述 数据来源:交通监控摄像头、执法记录仪、公开数据集数据类型:图像、视频、元数据(时间、地点、车辆信息)违法类型标注:接打电话、未系安全带 数据采集与标注方法 采集设备&#xff1…

Qt OpenGL 3D 编程入门

Qt 提供了强大的 OpenGL 集成功能,使得在 Qt 应用中实现 3D 图形变得更加简单。以下是使用 Qt 进行 OpenGL 3D 编程的基础知识。 1. 环境配置 创建 Qt 项目 新建 Qt Widgets Application 项目 在 .pro 文件中添加 OpenGL 模块: qmake QT co…

性能优化 - 工具篇:基准测试 JMH

文章目录 Pre引言1. JMH 简介2. JMH 执行流程详解3. 关键注解详解3.1 Warmup3.2 Measurement3.3 BenchmarkMode3.4 OutputTimeUnit3.5 Fork3.6 Threads3.7 Group 与 GroupThreads3.8 State3.9 Setup 与 TearDown3.10 Param3.11 CompilerControl 4. 示例代码与分析4.1 关键点解读…

Nginx网站服务:从入门到LNMP架构实战

🏡作者主页:点击! Nginx-从零开始的服务器之旅专栏:点击! 🐧Linux高级管理防护和群集专栏:点击! ⏰️创作时间:2025年5月30日14点22分 前言 说起Web服务器&#xff0c…

Java面试八股--08-数据结构和算法篇

1、怎么理解时间复杂度和空间复杂度 时间复杂度和空间复杂度一般是针对算法而言,是衡量一个算法是否高效的重要标准。先纠正一个误区,时间复杂度并不是算法执行的时间,在纠正一个误区,算法不单单指冒泡排序之类的,一个…

Java面试八股--06-Linux篇

目录 一、Git 1、工作中git开发使用流程(命令版本描述) 2.Reset与Rebase,Pull与Fetch的区别 3、git merge和git rebase的区别 4、git如何解决代码冲突 5、项目开发时git分支情况 二、Linux 1、Linux常用的命令 2、如何查看测试项目的…

dvwa7——SQL Injection

LOW: f12打开hackbar 一:判断注入类型 输入id1报错 闭合单引号 ,页面恢复正常 所以为单引号字符型 二:开始攻击 1.判断列数 ?id1 order by 2-- 到3的时候开始报错,所以一共两列 2.爆回显位置 ?id-1 union s…

sqlite3 命令行工具详细介绍

一、启动与退出 启动数据库连接 sqlite3 [database_file] # 打开/创建数据库文件(如 test.db) sqlite3 # 启动临时内存数据库 (:memory:) sqlite3 :memory: # 显式启动内存数据库文件不存在时自动创建不指定文件名则使用临时内…

ubuntu 22.04 编译安装nignx 报错 openssl 问题

前言 Ubuntu 20.04 中安装 Nginx (通过传包编译的方式)、开启关闭防火墙、开放端口号 在ubuntu 22.04.3 服务器上照着上面的文章 通过传包编译的方式安装nginx-1.18.0 的时候报错,报错内容如下: src/event/ngx_event_openssl.c: In function ‘ngx_ssl…

线程相关面试题

提示:线程相关面试题,持续更新中 文章目录 一、Java线程池1、Java线程池有哪些核心参数,分别有什么的作用?2、线程池有哪些拒绝策略?3、说一说线程池的执行流程?4、线程池核心线程数怎么设置呢?4、Java线程…

pikachu通关教程-目录遍历漏洞(../../)

目录遍历漏洞也可以叫做信息泄露漏洞、非授权文件包含漏洞等. 原理:目录遍历漏洞的原理比较简单,就是程序在实现上没有充分过滤用户输入的../之类的目录跳转符,导致恶意用户可以通过提交目录跳转来遍历服务器上的任意文件。 这里的目录跳转符可以是../…

Maven-生命周期

目录 1.项目对象模型 2.依赖管理模型 3.仓库:用于存储资源,管理各种jar包 4.本地仓库路径 1.项目对象模型 2.依赖管理模型 3.仓库:用于存储资源,管理各种jar包 4.本地仓库路径

Matlab实现LSTM-SVM回归预测,作者:机器学习之心

Matlab实现LSTM-SVM回归预测,作者:机器学习之心 目录 Matlab实现LSTM-SVM回归预测,作者:机器学习之心效果一览基本介绍程序设计参考资料 效果一览 基本介绍 代码主要功能 该代码实现了一个LSTM-SVM回归预测模型,核心流…