day46 python预训练模型补充

news2025/6/8 18:05:57

目录

一、预训练模型的背景知识

二、实验过程

(一)实验环境与数据准备

(二)预训练模型的选择与适配

(三)训练策略

三、实验结果与分析

四、学习总结与展望


一、预训练模型的背景知识

在传统的神经网络训练中,模型的参数是随机初始化的,这可能导致训练初期的不稳定,并且容易陷入局部最优解。而预训练模型的出现,为这一问题提供了有效的解决方案。预训练模型是在大规模数据集(如 ImageNet)上预先训练好的模型,它已经学习到了丰富的通用特征。当我们面临一个新的图像分类任务时,可以直接利用这些预训练好的模型参数来初始化我们的模型,这样模型在初始阶段就具备了一定的特征提取能力,能够更快地收敛,并且在一定程度上避免了局部最优解的问题。

预训练模型的选择至关重要。首先,预训练任务与目标任务的相似性是关键因素。如果两个任务在特征层面具有相似性,那么预训练模型提取的特征将对目标任务更有帮助。其次,预训练数据集的规模也非常重要。大规模的数据集能够支撑模型学习到更通用的特征,从而在不同的任务中具有更好的泛化能力。例如,ImageNet 数据集拥有 1000 个类别,1.2 亿张图像,尺寸为 224x224,是一个非常适合用于预训练的大规模图像数据集。

二、实验过程

(一)实验环境与数据准备

本次实验使用的是 PyTorch 深度学习框架,借助其丰富的预训练模型库和便捷的数据处理工具,能够高效地完成模型的加载、训练和测试。实验中使用的 CIFAR-10 数据集是一个经典的图像分类数据集,包含 10 个类别,共 60000 张 32x32 的彩色图像,其中训练集有 50000 张图像,测试集有 10000 张图像。由于 CIFAR-10 图像的尺寸较小,且类别相对较少,因此直接在该数据集上训练模型可能会面临过拟合等问题,而预训练模型的引入则有望缓解这一问题。

在数据预处理阶段,为了增强模型的泛化能力,对训练集进行了多种数据增强操作,包括随机裁剪、随机水平翻转、颜色抖动以及随机旋转等。这些操作能够在训练过程中为模型提供更多的“干扰”或变形,使模型能够学习到更加鲁棒的特征。具体的数据预处理代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)

test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=test_transform
)

# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

对于测试集,则仅进行了标准化处理,以确保其与训练集在数据分布上具有一致性。

(二)预训练模型的选择与适配

在本次实验中,选择了 ResNet18 作为预训练模型。ResNet18 是一种经典的卷积神经网络架构,具有 18 层深度,并且通过残差连接解决了深层网络训练中的梯度消失问题。它在 ImageNet 数据集上预训练后,能够提取出具有较强表达能力的特征。

由于 ResNet18 在 ImageNet 数据集上预训练时,其输入图像尺寸为 224x224,而 CIFAR-10 图像的尺寸为 32x32,因此需要对模型进行适当的调整以适配 CIFAR-10 数据集。具体来说,需要修改模型的最后一层全连接层,将其输出类别数从 1000 改为 10,以匹配 CIFAR-10 的类别数量。此外,由于输入图像尺寸的变化,还需要调整模型中的一些层的参数,例如池化层的步长等。以下是 ResNet18 模型的适配代码:

from torchvision.models import resnet18

# 定义 ResNet18 模型(支持预训练权重加载)
def create_resnet18(pretrained=True, num_classes=10):
    model = resnet18(pretrained=pretrained)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model.to(device)

# 创建 ResNet18 模型(加载 ImageNet 预训练权重,不进行微调)
model = create_resnet18(pretrained=True, num_classes=10)
model.eval()  # 设置为推理模式

(三)训练策略

在训练过程中,采用了阶段式训练策略。首先,冻结模型的卷积层参数,仅训练全连接层,这样可以在不破坏预训练模型特征提取能力的前提下,快速调整模型的输出层以适应 CIFAR-10 数据集。经过一定轮次的训练后,解冻模型的所有参数,进行整体训练,以进一步提升模型的性能。这种策略能够在训练初期快速降低损失,并在后续训练中充分利用预训练模型的特征提取能力,实现更好的收敛效果。

具体来说,实验中设置了前 5 轮冻结卷积层参数,之后解冻所有参数进行训练。在解冻后,为了防止过拟合,还适当降低了学习率。以下是训练函数的关键代码:

# 冻结/解冻模型层的函数
def freeze_model(model, freeze=True):
    """冻结或解冻模型的卷积层参数"""
    for name, param in model.named_parameters():
        if 'fc' not in name:
            param.requires_grad = not freeze
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    if freeze:
        print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")
    else:
        print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")
    return model

# 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):
    model = freeze_model(model, freeze=True)
    for epoch in range(epochs):
        if epoch == freeze_epochs:
            model = freeze_model(model, freeze=False)
            optimizer.param_groups[0]['lr'] = 1e-4  # 解冻后调整学习率
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = output.max(1)
            total_train += target.size(0)
            correct_train += predicted.eq(target).sum().item()
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct_train / total_train
        model.eval()
        correct_test = 0
        total_test = 0
        test_loss = 0.0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        if scheduler is not None:
            scheduler.step(epoch_test_loss)
        print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")

三、实验结果与分析

经过 40 轮的训练,最终测试准确率达到了 86.30%。从训练过程的输出中可以看出,在解冻卷积层参数后,模型的训练损失迅速下降,训练准确率和测试准确率都得到了显著提升。这充分证明了预训练模型的强大优势,即使在 CIFAR-10 这种相对较小的数据集上,也能够通过微调取得优异的性能。

此外,由于训练集采用了数据增强操作,模型在训练初期可能会出现训练准确率暂时低于测试准确率的情况。这是因为数据增强增加了模型训练的难度,而测试集是标准的、未增强的图像,模型在测试集上预测相对轻松。随着训练的推进,模型逐渐适应了数据增强带来的变化,训练准确率和测试准确率之间的差距逐渐缩小。

以下是完整的训练代码:

# 主函数:训练模型
def main():
    # 参数设置
    epochs = 40  # 总训练轮次
    freeze_epochs = 5  # 冻结卷积层的轮次
    learning_rate = 1e-3  # 初始学习率
    weight_decay = 1e-4  # 权重衰减

    # 创建 ResNet18 模型(加载预训练权重)
    model = create_resnet18(pretrained=True, num_classes=10)

    # 定义优化器和损失函数
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    # 定义学习率调度器
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )

    # 开始训练(前 5 轮冻结卷积层,之后解冻)
    final_accuracy = train_with_freeze_schedule(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=epochs,
        freeze_epochs=freeze_epochs
    )

    print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

if __name__ == "__main__":
    main()

@浙大疏锦行

补充-60日计划day44,pynote中day55

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2404398.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java建造者模式(Builder Pattern)详解与实践

一、引言 在软件开发中,我们经常会遇到需要创建复杂对象的场景。例如,构建一个包含多个可选参数的对象时,传统的构造函数或Setter方法可能导致代码臃肿、难以维护。此时,建造者模式(Builder Pattern)便成为…

win32相关(IAT HOOK)

IAT HOOK 什么是IAT Hook? IAT Hook(Import Address Table Hook,导入地址表钩子)是一种Windows平台下的API钩取技术,通过修改目标程序的导入地址表(IAT)来拦截和重定向API调用 在我们之前学习pe文件结构的导入表时&am…

零基础玩转物联网-串口转以太网模块如何快速实现与TCP服务器通信

目录 1 前言 2 环境搭建 2.1 硬件准备 2.2 软件准备 2.3 驱动检查 3 TCP服务器通信配置与交互 3.1 硬件连接 3.2 开启TCP服务器 3.3 打开配置工具读取基本信息 3.4 填写连接参数进行连接 3.5 通信测试 4 总结 1 前言 TCP是TCP/IP体系中的传输层协议,全称为Transmiss…

ESP32开发之LED闪烁和呼吸的实现

硬件电路介绍GPIO输出模式GPIO配置过程闪烁灯的源码LED PWM的控制器(LEDC)概述LEDC配置过程及现象整体流程 硬件电路介绍 电路图如下: 只要有硬件基础的应该都知道上图中,当GPIO4的输出电平为高时,LED灯亮,反之则熄灭。如果每间…

【产品业务设计】支付业务设计规范细节记录,含订单记录、支付业务记录、支付流水记录、退款业务记录

【产品业务设计】支付业务设计规范细节记录,含订单记录、支付业务记录、支付流水记录 前言 我为什么要写这个篇文章 总结设计经验生成设计模板方便后期快速搭建 一个几张表 一共5张表; 分别是: 订单主表:jjy_orderMain订单产…

2025软件供应链安全最佳实践︱证券DevSecOps下供应链与开源治理实践

项目背景:近年来,云计算、AI人工智能、大数据等信息技术的不断发展、各行各业的信息电子化的步伐不断加快、信息化的水平不断提高,网络安全的风险不断累积,金融证券行业面临着越来越多的威胁挑战。特别是近年以来,开源…

WebRTC通话原理与入门难度实战指南

波煮的实习公司主要是音视频业务,所以最近在补习WebRTC的相关内容,会不定期给大家分享学习心得和笔记。 文章目录 WebRTC通话原理进行媒体协商:彼此要了解对方支持的媒体格式网络协商:彼此要了解对方的网络情况,这样才…

N元语言模型 —— 一文讲懂!!!

目录 引言 一. 基本知识 二.参数估计 三.数据平滑 一.加1法 二.减值法/折扣法 ​编辑 1.Good-Turing 估计 ​编辑 2.Back-off (后备/后退)方法 3.绝对减值法 ​编辑4.线性减值法 5.比较 三.删除插值法(Deleted interpolation) 四.模型自适应 引言 本章节讲的…

.NET 9中的异常处理性能提升分析:为什么过去慢,未来快

一、为什么要关注.NET异常处理的性能 随着现代云原生、高并发、分布式场景的大量普及,异常处理(Exception Handling)早已不再只是一个冷僻的代码路径。在高复杂度的微服务、网络服务、异步编程环境下,服务依赖的外部资源往往不可…

Mac 安装git心路历程(心累版)

省流版:直接安装Xcode命令行工具即可,不用安Xcode。 git下载官网 第一部分 上网初步了解后,打算直接安装Binary installer,下载完安装时,苹果还阻止安装,只好在“设置–安全性与隐私”最下面的提示进行安…

计算机网络第2章(下):物理层传输介质与核心设备全面解析

目录 一、传输介质1.1 传输介质的分类1.2 导向型传输介质1.2.1 双绞线(Twisted Pair)1.2.2 同轴电缆(Coaxial Cable)1.2.3 光纤(Optical Fiber)1.2.4 以太网对有线传输介质的命名规则 1.3 非导向型传输介质…

C# 类和继承(扩展方法)

扩展方法 在迄今为止的内容中,你看到的每个方法都和声明它的类关联。扩展方法特性扩展了这个边 界,允许编写的方法和声明它的类之外的类关联。 想知道如何使用这个特性,请看下面的代码。它包含类MyData,该类存储3个double类型 的…

MySQL复杂SQL(多表联查/子查询)详细讲解

🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 MySQL复杂SQL(多表联查/子查询&a…

STM32使用土壤湿度传感器

1.1 介绍: 土壤湿度传感器是一种传感装置,主要用于检测土壤湿度的大小,并广泛应用于汽车自动刮水系统、智能灯光系统和智能天窗系统等。传感器采用优质FR-04双料,大面积5.0 * 4.0厘米,镀镍处理面。 它具有抗氧化&…

Windows平台RTSP/RTMP播放器C#接入详解

大牛直播SDK在Windows平台下的RTSP、RTMP播放器模块,基于自研高性能内核,具备极高的稳定性与行业领先的超低延迟表现。相比传统基于FFmpeg或VLC的播放器实现,SmartPlayer不仅支持RTSP TCP/UDP自动切换、401鉴权、断网重连等网络复杂场景自适应…

从 JDK 8 到 JDK 17:Swagger 升级迁移指南

点击上方“程序猿技术大咖”,关注并选择“设为星标” 回复“加群”获取入群讨论资格! 随着 Java 生态向 JDK 17 及 Jakarta EE 的演进,许多项目面临从 JDK 8 升级的挑战,其中 Swagger(API 文档工具)的兼容性…

使用 Coze 工作流一键生成抖音书单视频:全流程拆解与技术实现

使用 Coze 工作流一键生成抖音书单视频:全流程拆解与技术实现(提供工作流) 摘要:本文基于一段关于使用 Coze 平台构建抖音爆火书单视频的详细讲解,总结出一套完整的 AI 视频自动化制作流程。内容涵盖从思路拆解、节点配…

【发布实录】云原生+AI,助力企业全球化业务创新

5 月 22 日,在最新一期阿里云「飞天发布时刻」,阿里云云原生应用平台产品负责人李国强重磅揭晓面向 AI 场景的云原生产品体系升级,通过弹性智能的一体化架构、开箱即用的云原生 AI 能力,为中国企业出海提供新一代技术引擎。 发布会…

LabVIEW主轴故障诊断案例

LabVIEW 开发主轴机械状态识别与故障诊断系统,适配工业场景主轴振动监测需求。通过整合品牌硬件与软件算法,实现从信号采集到故障定位的全流程自动化,为设备维护提供数据支撑,提升数控机床运行可靠性。 ​ 面向精密制造企业数控机…

计算机组成与体系结构:补码数制二(Complementary Number Systems)

目录 4位二进制的减法 补码系统 🧠减基补码 名字解释: 减基补码有什么用? 计算方法 ❓为什么这样就能计算减基补码 💡 原理揭示:按位减法,模拟总减法! 那对于二进制呢?&…