《Pytorch深度学习实践》ch8-多分类

news2025/7/28 6:27:35

                                                        ------B站《刘二大人》

1.Softmax Layer

  • 在多分类问题中,输出的是每类的概率:

  • 计算公式:保证了每类概率大于 0 ,又由保证了概率之和为 1;

  • 举例如下:

2.Cross Entropy

  • 计算损失:

  • y = np.array([1, 0, 0]):是目标标签的 one-hot 编码。假设有 3 个类别,这里表示正确的类别是第一个类别;
import numpy as np
y = np.array([1, 0, 0])
z = np.array([0.2, 0.1, -0.1])
y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()
print(loss) # 0.9729189131256584
  • 交叉熵损失函数: 

  • y 是一个长度为 1 的长整型张量,是标签类别的 索引[0] 表示正确的类别是类别 0;
import torch
y = torch.LongTensor([0])
z = torch.Tensor([[0.2, 0.1, -0.1]])
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(z, y)
print(loss) # tensor(0.9729)
  • Mini - Batch
import torch
criterion = torch.nn.CrossEntropyLoss()
Y = torch.LongTensor([2, 0, 1])

Y_pred1 = torch.Tensor([[0.1, 0.2, 0.9],
                        [1.1, 0.1, 0.2],
                        [0.2, 2.1, 0.1]])
Y_pred2 = torch.Tensor([[0.8, 0.2, 0.3],
                        [0.2, 0.3, 0.5],
                        [0.2, 0.2, 0.5]])

loss1 = criterion(Y_pred1, Y) # Batch Loss1 =  tensor(0.4966)
loss2 = criterion(Y_pred2, Y) # Batch Loss2 =  tensor(1.2389)
print('Batch Loss1 = ', loss1.data, '\nBatch Loss2 = ', loss2.data)

3.MNIST

  • 导包
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
  • 准备数据集
    • ToTensor():将图片转换为PyTorch的张量。
    • Normalize(mean, std):使用指定的均值和标准差对图片进行标准化。

batch_size = 64

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, ))
])

train_dataset = datasets.MNIST('data/MNIST/', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.MNIST('data/MNIST/', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
  • 构造模型
    • 输入层:784个神经元(因为每张图片是28x28,展平后变成784维)。
    • 隐藏层:4个全连接层,神经元数量分别为512、256、128和64。
    • 输出层:10个神经元,分别对应数字0到9。
    • 最后一层不做激活,因为后面调用 torch.nn.CrossEntropyLoss。
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = torch.nn.Linear(784, 512)
        self.linear2 = torch.nn.Linear(512, 256)
        self.linear3 = torch.nn.Linear(256, 128)
        self.linear4 = torch.nn.Linear(128, 64)
        self.linear5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear4(x))
        x = self.linear5(x) # 不用激活函数,因为 torch.nn.CrossEntropyLoss = softmax + nllloss
        return x
    
model = Net()
  • 损失与优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
  • 训练与测试
    • torch.max:返回最大值和对应的下标。
    • dim=1,说明是在行的维度。 0是列,1是行。
# training
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0

# test
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy on test set: %d %%' %(100*correct/total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        if epoch % 10 == 0:
            test()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2401743.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

国产录播一体机:科技赋能智慧教育信息化

在数字化时代,教育正经历着前所未有的变革。国产工控机作为信息化教育的核心载体,正在重新定义学习方式,赋能教师与学生,打造高效、互动、智能的教学环境,让我们一起感受科技与教育的深度融合!高能计算机推…

关于逻辑回归的见解

逻辑回归通过将线性回归的输出映射到 [ 0 , 1 ] \left[0,1\right] [0,1]区间,来表示某个类别的概率。也就是其本质是先通过线性回归的预测值 y \boldsymbol{y} y输入到映射函数,既将线性回归的输出通过映射函数映射到 [ 0 , 1 ] \left[0,1\right] [0,1].常用的映射函数是sigm…

Amazon Augmented AI:人类智慧与AI协作,破解机器学习审核难题

在人工智能日益渗透业务核心的今天,你是否遭遇过这样的困境:自动化AI处理海量数据时,面对模糊、复杂或高风险的场景频频“卡壳”?人工审核团队则被低效、重复的任务压得喘不过气?Amazon Augmented AI (A2I) 的诞生&…

VIN码车辆识别码解析接口如何用C#进行调用?

一、什么是VIN码车辆识别码解析接口 输入17位vin码,获取到车辆的品牌、型号、出厂日期、发动机类型、驱动类型、车型、年份等信息。无论是汽车电商平台、二手车商、维修厂,还是保险公司、金融机构,都能通过接入该API实现信息自动化、决策智能…

Playwright 测试框架 - Java

🚀【Playwright + Java 实战教程】从零到一掌握自动化测试利器! 🔧 本文专为 Java 开发者量身打造,通过详尽示例带你快速掌握 Playwright 自动化测试。涵盖基础操作、表单交互、测试框架集成、高阶功能及常见实战技巧,适用于企业 UI 测试与 CI/CD 场景。 🛠️ 一、环境…

力扣100题之128. 最长连续序列

方法1 使用了hash 方法思路 使用哈希集合:首先将数组中的所有数字存入一个哈希集合中,这样可以在 O(1) 时间内检查某个数字是否存在。 寻找连续序列:遍历数组中的每一个数字,对于每一个数字, 检查它是否是某个连续序列…

算法打卡12天

19.链表相交 (力扣面试题 02.07. 链表相交) 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表没有交点,返回 null 。 图示两个链表在节点 c1 开始相交**:** 题目数据…

蓝桥杯国赛训练 day1 Java大学B组

目录 k倍区间 舞狮 交换瓶子 k倍区间 取模后算组合数就行 import java.util.HashMap; import java.util.Map; import java.util.Scanner;public class Main {static Scanner sc new Scanner(System.in);public static void main(String[] args) {solve();}public static vo…

PyTorch——非线性激活(5)

非线性激活函数的作用是让神经网络能够理解更复杂的模式和规律。如果没有非线性激活函数,神经网络就只能进行简单的加法和乘法运算,没法处理复杂的问题。 非线性变化的目的就是给我们的网络当中引入一些非线性特征 Relu 激活函数 Relu处理图像 # 导入必…

OPenCV CUDA模块目标检测----- HOG 特征提取和目标检测类cv::cuda::HOG

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::cuda::HOG 是 OpenCV 的 CUDA 模块中对 HOG 特征提取和目标检测 提供的 GPU 实现。它与 CPU 版本的 cv::HOGDescriptor 类似,但利…

MATLAB读取文件内容:Excel、CSV和TXT文件解析

MATLAB读取文件内容:Excel、CSV和TXT文件解析 MATLAB 是一款强大的数学与工程计算工具,广泛应用于数据分析、模型构建和图像处理等领域。在处理实际问题时,我们常常需要从文件中读取数据进行分析。本文将介绍如何使用 MATLAB 读取常见的文件…

Spring MVC 之 异常处理

使用Spring MVC可以很灵活地完成数据的绑定和响应,极大的简化了Java Web的开发。但Spring MVC提供的便利不仅仅如此,使用Spring MVC还可以很便捷地完成项目中的异常处理、自定义拦截器以及文件上传和下载等高级功能。本章将对Spring MVC提供的这些高级功…

ubuntu24.04 使用apt指令只下载不安装软件

比如我想下载net-tools工具包及其依赖包可以如下指令 apt --download-only install net-tools 自动下载的软件包在/var/cache/apt/archives/目录下

计算机网络安全问答数据集(1788条) ,AI智能体知识库收集! AI大模型训练数据!

继续收集数据集,话不多说,见下文! 今天分享一个计算机网络安全问答数据集(1788条),适用于AI大模型训练、智能体知识库构建、安全教育系统开发等多种场景! 一、数据特点 结构清晰:共计1788条&…

WinCC学习系列-高阶应用(WinCC REST通信)

WinCC作为一个经典SCADA系统,它是OT与IT数据无缝集成桥梁,自WinCC7.5版本开始,可以直接提供Rest服务用于其它系统数据访问和操作。 WinCC REST 服务允许外部应用程序访问 WinCC 数据。 外部应用程序可以通过 REST 接口读取和写入 WinCC 组态…

使用交叉编译工具提示stubs-32.h:7:11: fatal error: gnu/stubs-soft.h: 没有那个文件或目录的解决办法

0 前言 使用ST官方SDK提供的交叉编译工具、cmake生成Makefile,使用make命令生成可执行文件提示fatal error: gnu/stubs-soft.h: 没有那个文件或目录的解决办法,如下所示: 根据这一错误提示,按照网上的解决方案逐一尝试均以失败告…

macOS 连接 Docker 运行 postgres,使用navicat添加并关联数据库

下载 docker注册一个账号,登录 Docker创建 docke r文件 mkdir -p ~/.docker && touch ~/.docker/daemon.json写入配置(全量替换) {"builder": {"gc": {"defaultKeepStorage": "20GB",&quo…

指针的使用——基本数据类型、数组、结构体

1 引言 对于学习指针要弄清楚如下问题基本可以应付大部分的场景: ① 指针是什么? ② 指针的类型是什么? ③ 指针指向的类型是什么? ④ 指针指向了哪里? 2 如何使用指针 任何东西的学习最好可以总结成一种通用化的…

TK海外抢单源码/指定卡单

​ 抢单源码,有指定派单,打针,这套二改过充值跳转客服 前端vue 后端php 两端分离 可二开 可以指定卡第几单,金额多少, 前后端开源 PHP7.2 MySQL5.6 前端要www.域名,后端要admin.域名 前端直接静态 伪静…

【Linux】Linux 环境变量

参考博客:https://blog.csdn.net/sjsjnsjnn/article/details/125533127 一、环境变量 1.1 基本概念 环境变量(environment variables)一般是指在操作系统中用来指定操作系统运行环境的一些参数如:我们在编写C/C代码的时候,在链接的时候&am…