基于Pytorch实现图像分类——基于jupyter

news2025/6/2 5:57:20

分类任务

  • 网络基本构建与训练方法,常用函数解
  • torch.nn.functional模块
  • nn.Module模块

MNIST数据集下载

from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

解压数据集

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

查阅数据

from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
from matplotlib import pyplot
import numpy as np
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

在这里插入图片描述

网络模型搭建

在这里插入图片描述

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

在这里插入图片描述

常用函数介绍

import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
bs = 64
bias = torch.zeros(10, requires_grad=True)

print(loss_func(model(xb), yb))

模型搭建

from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out  = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x
net = Mnist_NN()
print(net)

Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)

for name, parameter in net.named_parameters():
    print(name, parameter,parameter.size())

dataset数据接口

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
from torch import optim
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))

def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1592892.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言之typeof用法实例(九十二)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

MySQL-触发器:触发器概述、触发器的创建、查看删除触发器、 触发器的优缺点

触发器 触发器1. 触发器概述2. 触发器的创建2.1 创建触发器语法2.2 代码举例 3. 查看、删除触发器3.1 查看触发器3.2 删除触发器 4. 触发器的优缺点4.1 优点4.2 缺点4.3 注意点 注:此为笔者学习尚硅谷-宋红康MySQL的笔记,其中包含个人的笔记和理解&#…

APIGateway的认证

APIGateway的支持的认证如下: 我们从表格中可以看到,HTTP API 不支持资源策略的功能,另外是通过JWT的方式集成Cognito的。 对于REST API则是没有显示说明支持JWT认证,这个我们可以通过Lambda 自定义的方式来实现。 所以按照这个…

【C Hash Map from Redis】

将Redis源码中的哈希表底层逻辑提取,并进行最小demo级测试将对应文件抽出,通过宏替换等方式保证源码编译通过main.c编写测试demo ,注册哈希函数和值比较函数(必选项) /* Hash Tables Implementation.** This file imp…

ThingsBoard通过服务端获取客户端属性或者共享属性

MQTT基础 客户端 MQTT连接 通过服务端获取属性值 案例 1、首先需要创建整个设备的信息,并复制访问令牌 ​2、通过工具MQTTX连接上对应的Topic 3、测试链接是否成功 4、通过服务端获取属性值 5、在客户端查看对应的客户端属性或者共享属性的key 6、查看整个…

MySQL之索引失效、覆盖、前缀索引及单列、联合索引详细总结

索引失效 最左前缀法则 如果索引了多列(联合索引),要遵守最左前缀法则,最左前缀法则指的是查询从索引的最左列开始,并且不跳过索引中的列。如果跳跃某一列,索引将部分失效(后面的字段索引失效)。 联合索…

Golang(一):基础、数组、map、struct

目录 hello world 变量 常量,iota 函数 init函数和导包过程 指针 defer 数组和动态数组 固定长度数组 遍历数组 动态数组 len 和 cap 截取 切片的追加 map 四种声明方式 遍历map 删除 查看键是否存在 结构体 声明 作为形参 方法 封装 继承…

【PG-1】PostgreSQL体系结构概述

1. PostgreSQL体系结构概述 代码结构 其中,backend是后端核心代码,包括右边的几个dir: access:处理数据访问方法和索引的代码。 bootstrap:数据库初始化相关的代码。 catalog:系统目录(如表和索引的元数据…

富文本回显 p 标签?去不掉怎么办?如何解决?

使用前端框架富文本控件上传的上传的数据&#xff0c;回显到文本框时显示<p></p>标签&#xff0c;并且数据库里面的数据也为带有p标签的数据&#xff0c;如何去掉 解决办法 使用正则表达式来讲HTML的内容进行替换更改&#xff0c;在vue中定义方法 //移除HTML标签…

Linux的学习之路:9、冯诺依曼与进程(1)

摘要 本章主要是说一下冯诺依曼体系结构和进程的一部分东西。 目录 摘要 一、冯诺依曼体系结构 二、操作系统的概念 三、设计OS的目的 四、管理 五、进程的基本概念 六、PCB 七、在Linux环境下查看进程 八、使用代码创建进程 九、思维导图 一、冯诺依曼体系结构 如…

pyside6自定义部件库和软件框架的建设记录

自定义的部件库原则上尽量做到前后端分离&#xff0c;接口方便&#xff0c;复制简单。 单选框部件 # encoding: utf-8 ################################################### # 自定义的单选框 #################################################### 对外接口&…

【MYSQL】MySQL整体结构之系统服务

一、系统服务层 学习了MySQL网络连接层后&#xff0c;接下来看看系统服务层&#xff0c;MySQL大多数核心功能都位于这一层&#xff0c;包括客户端SQL请求解析、语义分析、查询优化、缓存以及所有的内置函数&#xff08;例如&#xff1a;日期、时间、统计、加密函数...&#xff…

【Godot4.2】CanvasItem绘图函数全解析 - 9.绘制表格

概述 之前介绍TextLine和TextParagraph的时候&#xff0c;提到了用制表符和设定列宽形式来绘制简易表格&#xff0c;但是很明显&#xff0c;单纯使用此种方式很难获得对表格的精确控制。 所以对于表格绘制问题&#xff0c;我决定单独开坑&#xff0c;单独深入研究。 目前比较…

Steam平台游戏发行流程

Steam平台游戏发行流程 大家好我是艾西&#xff0c;一个做服务器租用的网络架构师也是游戏热爱者&#xff0c;经常在steam平台玩各种游戏享受快乐生活。去年幻兽帕鲁以及雾锁王国在年底横空出世&#xff0c;幻兽帕鲁更是在短短一星期取得了非常好的成绩&#xff0c;那么作为游戏…

Redis从入门到精通(十七)多级缓存(二)Lua语言入门、OpenResty集群的安装与使用

文章目录 前言6.4 Lua语法入门6.4.1 初识Lua6.4.2 Hello World6.4.3 变量6.4.3.1 Lua的数据类型6.4.3.2 声明变量 6.4.4 循环6.4.5 函数6.4.6 条件控制 6.5 实现多级缓存6.5.1 安装和启动OpenResty6.5.2 实现ajax请求反向代理至OpenResty集群6.5.2.1 反向代理配置6.5.2.2 OpenR…

古月·ROS2入门21讲——学习笔记(二)常用工具部分15-21讲

上篇&#xff1a;古月ROS2入门21讲——学习笔记&#xff08;一&#xff09;核心概念部分1-14讲-CSDN博客 第十五讲&#xff1a;Launch&#xff1a;多节点启动与配置脚本 到目前为止&#xff0c;每当我们运行一个ROS节点&#xff0c;都需要打开一个新的终端运行一个命令。机器人…

使用ioctl扫描wifi信号获取信号属性的实例(二)

使用工具软件扫描 wifi 信号是一件很平常的事情&#xff0c;在知晓 wifi 密码的前提下&#xff0c;通常我们会尽可能地连接信号质量比较好的 wifi 信号&#xff0c;但是如何通过编程来扫描 wifi 信号并获得这些信号的属性(比如信号强度等)&#xff0c;却鲜有文章提及&#xff0…

第16天:信息打点-CDN绕过业务部署漏洞回链接口探针全网扫描反向邮件

第十六天 本课意义 1.CDN服务对安全影响 2.CDN服务绕过识别手法 一、CDN服务-解释差异识别 1.前置知识&#xff1a; 传统访问&#xff1a;用户访问域名–>解析服务器IP–>访问目标主机普通CDN&#xff1a;用户访问域名–>CDN节点–>真实服务器IP–>访问目标…

C++项目 -- 负载均衡OJ(一)comm

C项目 – 负载均衡OJ&#xff08;一&#xff09;comm 文章目录 C项目 -- 负载均衡OJ&#xff08;一&#xff09;comm一、项目宏观结构1.项目功能2.项目结构 二、comm公共模块1.util.hpp2.log.hpp 一、项目宏观结构 1.项目功能 本项目的功能为一个在线的OJ&#xff0c;实现类似…

LLM生成模型在生物单细胞single cell的应用:scGPT

参考&#xff1a; https://github.com/bowang-lab/scGPT https://www.youtube.com/watch?vXhwYlgEeQAs 主要是把单细胞测序出来的基因表达量的拼接起来构建成的序列&#xff0c;这里不是用的基因的ATCG&#xff0c;是直接用的基因名称 训练数据&#xff1a;scGPT全人模型是在3…