Python Day37

news2025/5/31 20:46:25

Task:
1.过拟合的判断:测试集和训练集同步打印指标
2.模型的保存和加载
a.仅保存权重
b.保存权重和模型
c.保存全部信息checkpoint,还包含训练状态
3.早停策略

在这里插入图片描述
在这里插入图片描述

1. 过拟合的判断:测试集和训练集同步打印指标

过拟合是指模型在训练数据上表现良好,但在未见过的新数据(测试数据)上表现较差。 判断过拟合的关键在于比较训练集和测试集的性能。

  • 训练集指标: 准确率、损失等。 训练集指标通常会随着训练的进行而持续提高(准确率)或降低(损失)。
  • 测试集指标: 准确率、损失等。 测试集指标一开始可能会随着训练的进行而提高,但当模型开始过拟合时,测试集指标会停止提高甚至开始下降。

判断标准:

  • 训练集性能远好于测试集性能: 这是过拟合的典型迹象。
  • 训练集损失持续下降,但测试集损失开始上升: 这表明模型正在记住训练数据,而不是学习泛化能力。
  • 训练集准确率接近 100%,但测试集准确率较低: 也是过拟合的信号。

代码示例 (PyTorch):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# 1. 创建一些模拟数据
np.random.seed(42)
X = np.random.rand(100, 10)  # 100个样本,每个样本10个特征
y = np.random.randint(0, 2, 100)  # 二分类问题

# 将数据分成训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# 转换为 PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

# 创建 DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


# 2. 定义一个简单的模型
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


input_size = 10
hidden_size = 20
num_classes = 2
model = SimpleNN(input_size, hidden_size, num_classes)

# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)


# 4. 训练循环
num_epochs = 100
for epoch in range(num_epochs):
    # 训练
    model.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    train_accuracy = 100 * correct_train / total_train
    train_loss /= len(train_loader)

    # 评估
    model.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_accuracy = 100 * correct_test / total_test
    test_loss /= len(test_loader)

    # 打印指标
    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
          f'Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')

    # 观察训练集和测试集的损失和准确率,判断是否过拟合。

关键点:

  • 在每个 epoch 结束后,同时打印训练集和测试集的损失和准确率。
  • 观察这些指标的变化趋势。 如果训练集损失持续下降,但测试集损失开始上升,或者训练集准确率远高于测试集准确率,则表明模型正在过拟合。

2. 模型的保存和加载

在 PyTorch 中,有几种保存和加载模型的方法:

a. 仅保存权重 (推荐):

这是最常见和推荐的方法。 它只保存模型的参数(权重和偏置),而不保存模型的结构。 这种方法更轻量级,更灵活,因为你可以将权重加载到任何具有相同结构的模型的实例中。

  • 保存: torch.save(model.state_dict(), 'model_weights.pth')
  • 加载:
    1. 首先,创建一个与保存时具有相同结构的模型的实例。
    2. 然后,使用 model.load_state_dict(torch.load('model_weights.pth')) 加载权重。
# 保存权重
torch.save(model.state_dict(), 'model_weights.pth')

# 加载权重
model = SimpleNN(input_size, hidden_size, num_classes)  # 创建模型实例
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()  # 设置为评估模式

b. 保存权重和模型结构:

这种方法保存整个模型对象,包括模型的结构和权重。 它更简单,但不如仅保存权重灵活。

  • 保存: torch.save(model, 'model.pth')
  • 加载: model = torch.load('model.pth')
# 保存整个模型
torch.save(model, 'model.pth')

# 加载整个模型
model = torch.load('model.pth')
model.eval()  # 设置为评估模式

c. 保存全部信息 (Checkpoint):

这种方法保存模型的全部信息,包括模型的结构、权重、优化器的状态、epoch 数等。 这对于恢复训练非常有用。

  • 保存:
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': train_loss,
}
torch.save(checkpoint, 'checkpoint.pth')
  • 加载:
checkpoint = torch.load('checkpoint.pth')
model = SimpleNN(input_size, hidden_size, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval() # or model.train() depending on your needs
# - or -
model.train()

3. 早停策略 (Early Stopping)

早停是一种正则化技术,用于防止过拟合。 它的基本思想是:在训练过程中,如果测试集上的性能在一定时间内没有提高,就提前停止训练。

实现步骤:

  1. 监控指标: 选择一个在测试集上监控的指标(例如,测试集损失或准确率)。
  2. 耐心 (Patience): 定义一个“耐心”值,表示在测试集指标没有提高的情况下,允许训练继续进行的 epoch 数。
  3. 最佳模型: 在训练过程中,保存测试集上性能最佳的模型。
  4. 停止训练: 如果测试集指标在“耐心”个 epoch 内没有提高,则停止训练,并加载保存的最佳模型。

代码示例 (PyTorch):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# 1. 创建一些模拟数据
np.random.seed(42)
X = np.random.rand(100, 10)  # 100个样本,每个样本10个特征
y = np.random.randint(0, 2, 100)  # 二分类问题

# 将数据分成训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# 转换为 PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

# 创建 DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


# 2. 定义一个简单的模型
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


input_size = 10
hidden_size = 20
num_classes = 2
model = SimpleNN(input_size, hidden_size, num_classes)

# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 4. 早停策略
patience = 10  # 如果10个epoch测试集损失没有下降,就停止训练
best_loss = np.inf  # 初始最佳损失设置为无穷大
counter = 0  # 计数器,记录测试集损失没有下降的epoch数
best_model_state = None # 保存最佳模型的状态

# 5. 训练循环
num_epochs = 100
for epoch in range(num_epochs):
    # 训练
    model.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    train_accuracy = 100 * correct_train / total_train
    train_loss /= len(train_loader)

    # 评估
    model.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_accuracy = 100 * correct_test / total_test
    test_loss /= len(test_loader)

    # 打印指标
    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
          f'Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')

    # 早停逻辑
    if test_loss < best_loss:
        best_loss = test_loss
        counter = 0
        best_model_state = model.state_dict() # 保存最佳模型的状态
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping!")
            model.load_state_dict(best_model_state)  # 加载最佳模型
            break

print(f"Best Test Loss: {best_loss:.4f}")

关键点:

  • patience:控制早停的敏感度。 较小的 patience 值会导致更早的停止,可能导致欠拟合。 较大的 patience 值允许训练继续进行更长时间,可能导致过拟合。
  • best_loss:跟踪测试集上的最佳损失。
  • counter:跟踪测试集损失没有下降的 epoch 数。
  • 在早停后,加载保存的最佳模型。

总结:

  • 通过同步打印训练集和测试集的指标,可以有效地判断过拟合。
  • 仅保存权重是保存模型的推荐方法,因为它更轻量级和灵活。
  • 早停是一种有效的正则化技术,可以防止过拟合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2391853.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RabbitMQ集群与负载均衡实战指南

文章目录 集群架构概述仲裁队列的使用1. 使用Spring框架代码创建2. 使用amqp-client创建3. 使用管理平台创建 负载均衡引入HAProxy 负载均衡&#xff1a;使用方法1. 修改配置文件2. 声明队列 test_cluster3. 发送消息 集群架构 概述 RabbitMQ支持部署多个结点&#xff0c;每个…

linux安装ffmpeg7.0.2全过程

​编辑 白眉大叔 发布于 2025年4月16日 评论关闭 阅读(341) centos 编译安装 ffmpeg 7.0.2 &#xff1a;连接https://www.baimeidashu.com/19668.html 下载 FFmpeg 源代码 在文章最后 一、在CentOS上编译安装FFmpeg 以常见的CentOS为例&#xff0c;FFmpeg的编译说明页面为h…

Java中的设计模式实战:单例、工厂、策略模式的最佳实践

Java中的设计模式实战&#xff1a;单例、工厂、策略模式的最佳实践 在Java开发中&#xff0c;设计模式是构建高效、可维护、可扩展应用程序的关键。本文将深入探讨三种常见且实用的设计模式&#xff1a;单例模式、工厂模式和策略模式&#xff0c;并通过详细代码实例&#xff0…

2025.05.28【Parallel】Parallel绘图:拟时序分析专用图

Improve general appearance Add title, use a theme, change color palette, control variable orders and more Highlight a group Highlight a group of interest to help people understand your story 文章目录 Improve general appearanceHighlight a group探索Paralle…

自动驾驶与智能交通:构建未来出行的智能引擎

随着人工智能、物联网、5G和大数据等前沿技术的发展&#xff0c;自动驾驶汽车和智能交通系统正以前所未有的速度改变人类的出行方式。这一变革不仅是技术的融合创新&#xff0c;更是推动城市可持续发展的关键支撑。 一、自动驾驶与智能交通的定义 1. 自动驾驶&#xff08;Auto…

ELectron 中 BrowserView 如何进行实时定位和尺寸调整

背景 BrowserView 是继 Webview 后推出来的高性能多视图管理工具&#xff0c;与 Webview 最大的区别是&#xff0c;Webview 是一个 DOM 节点&#xff0c;依附于主渲染进程的附属进程&#xff0c;Webview 节点的崩溃会导致主渲染进程的连锁反应&#xff0c;会引起软件的崩溃。 …

深兰科技董事长陈海波率队考察南京,加速AI大模型区域落地应用

近日&#xff0c;深兰科技创始人、董事长陈海波受邀率队赴南京市&#xff0c;先后考察了南京高新技术产业开发区与鼓楼区&#xff0c;就推进深兰AI医诊大模型在南京的落地应用&#xff0c;与当地政府及相关部门进行了深入交流与合作探讨。 此次考察聚焦于深兰科技自主研发的AI医…

《深度关系-从建立关系到彼此信任》

陈海贤老师推荐的书&#xff0c;花了几个小时&#xff0c;感觉现在的人与人之间特别缺乏这种深度的关系&#xff0c;但是与一个人建立深度的关系并没有那么简单&#xff0c;反正至今为止&#xff0c;自己好像没有与任何一个人建立了这种深度的关系&#xff0c;那种双方高度同频…

IT选型指南:电信行业需要怎样的服务器?

从第一条电报发出的 那一刻起 电信技术便踏上了飞速发展的征程 百余年间 将世界编织成一个紧密相连的整体 而在今年 我们迎来了第25届世界电信日 同时也是国际电联成立的第160周年 本届世界电信日的主题为:“弥合性别数字鸿沟,为所有人创造机遇”,但在新兴技术浪潮汹涌…

【ConvLSTM第二期】模拟视频帧的时序建模(Python代码实现)

目录 1 准备工作&#xff1a;python库包安装1.1 安装必要库 案例说明&#xff1a;模拟视频帧的时序建模ConvLSTM概述损失函数说明&#xff08;python全代码&#xff09; 参考 ConvLSTM的原理说明可参见另一博客-【ConvLSTM第一期】ConvLSTM原理。 1 准备工作&#xff1a;pytho…

【论文解读】DETR: 用Transformer实现真正的End2End目标检测

1st authors: About me - Nicolas Carion‪Francisco Massa‬ - ‪Google Scholar‬ paper: [2005.12872] End-to-End Object Detection with Transformers ECCV 2020 code: facebookresearch/detr: End-to-End Object Detection with Transformers 1. 背景 目标检测&#…

ElasticSearch简介及常用操作指南

一. ElasticSearch简介 ElasticSearch 是一个基于 Lucene 构建的开源、分布式、RESTful 风格的搜索和分析引擎。 1. 核心功能 强大的搜索能力 它能够提供全文检索功能。例如&#xff0c;在海量的文档数据中&#xff0c;可以快速准确地查找到包含特定关键词的文档。这在处理诸如…

纤维组织效应偏斜如何影响您的高速设计

随着比特率继续飙升&#xff0c;光纤编织效应 &#xff08;FWE&#xff09; 偏移&#xff0c;也称为玻璃编织偏移 &#xff08;GWS&#xff09;&#xff0c;正变得越来越成为一个问题。今天的 56GB/s 是高速路由器中最先进的&#xff0c;而 112 GB/s 指日可待。而用于个人计算机…

Rust使用Cargo构建项目

文章目录 你好&#xff0c;Cargo&#xff01;验证Cargo安装使用Cargo创建项目新建项目配置文件解析默认代码结构 Cargo工作流常用命令速查表详细使用说明1. 编译项目2. 运行程序3.快速检查4. 发布版本构建 Cargo的设计哲学约定优于配置工程化优势 开发建议1. 新项目初始化​2. …

Python训练营打卡Day39

DAY 39 图像数据与显存 知识点回顾 1.图像数据的格式&#xff1a;灰度和彩色数据 2.模型的定义 3.显存占用的4种地方 a.模型参数梯度参数 b.优化器参数 c.数据批量所占显存 d.神经元输出中间状态 4.batchisize和训练的关系 作业&#xff1a;今日代码较少&#xff0c;理解内容…

UE5蓝图中播放背景音乐和使用代码播放声音

UE5蓝图中播放背景音乐 1.创建背景音乐Cube 2.勾选looping 循环播放背景音乐 3.在关卡蓝图中 Event BeginPlay-PlaySound2D Sound选择自己创建的Bgm_Cube 蓝图播放声音方法二&#xff1a; 使用代码播放声音方法一 .h文件中 头文件引用 #include "Kismet/GameplayS…

AI 赋能数据可视化:漏斗图制作的创新攻略

在数据可视化的广阔天地里&#xff0c;漏斗图以其独特的形状和强大的功能&#xff0c;成为展示流程转化、分析数据变化的得力助手。传统绘制漏斗图的方式往往需要耗费大量时间和精力&#xff0c;对使用者的绘图技能和软件操作熟练度要求颇高。但随着技术的蓬勃发展&#xff0c;…

用 Python 模拟下雨效果

用 Python 模拟下雨效果 雨天别有一番浪漫情怀&#xff1a;淅淅沥沥的雨滴、湿润的空气、朦胧的光影……在屏幕上也能感受下雨的美妙。本文将带你用一份简单的 Python 脚本&#xff0c;手把手实现「下雨效果」动画。文章深入浅出&#xff0c;零基础也能快速上手&#xff0c;完…

C#对象集合去重的一种方式

前言 现在AI越来越强大了&#xff0c;有很多问题其实不需要在去各个网站上查了&#xff0c;直接问AI就好了&#xff0c;但是呢&#xff0c;AI给的代码可能能用&#xff0c;也可能需要调整&#xff0c;但是自己肯定是要会的&#xff0c;所以还是总结一下吧。 问题 如果有一个…

在ROS2(humble)+Gazebo+rqt下,实时显示仿真无人机的相机图像

文章目录 前言一、版本检查检查ROS2版本 二、步骤1.下载对应版本的PX4(1)检查PX4版本(2)修改文件名(3)下载正确的PX4版本 2.下载对应版本的Gazebo(1)检查Gazebo版本(2)卸载不正确的Gazebo版本(3)下载正确的Gazebo版本 3.安装bridge包4.启动 总结 前言 在ROS2的环境下&#xff…