day45 python预训练模型

news2025/6/5 17:12:47

目录

知识点回顾

1. 预训练的概念

2. 常见的分类预训练模型

3. 图像预训练模型的发展史

4. 预训练的策略

5. 预训练代码实战:ResNet18

作业:在 CIFAR-10 上对比 AlexNet 预训练模型

实验结果对比


在深度学习领域,预训练模型已经成为了推动研究和应用发展的核心力量。通过在大规模数据集上进行预训练,模型能够学习到通用的特征表示,从而在各种下游任务中表现出色。今天,我们将回顾一些关于预训练模型的基础知识,并在 CIFAR-10 数据集上实践 AlexNet 预训练模型。

知识点回顾

1. 预训练的概念

预训练是指在一个大型且多样化的数据集上训练模型,使其学习到通用的特征表示。这些特征表示可以迁移到其他相关任务中,从而提高模型的性能和泛化能力。预训练的目的是减少下游任务所需的标注数据量,并加速模型的收敛速度。

2. 常见的分类预训练模型

在计算机视觉领域,有许多经典的预训练模型,例如:

  • AlexNet:2012年提出,首次在 ImageNet 比赛中取得突破性成绩。

  • VGGNet:2014年提出,结构简洁,性能稳定。

  • ResNet:2015年提出,解决了深层网络训练中的梯度消失问题。

  • InceptionNet:2014年提出,通过多尺度特征提取提高性能。

  • Transformer:2017年提出,基于自注意力机制,广泛应用于自然语言处理和计算机视觉。

3. 图像预训练模型的发展史

从 AlexNet 开始,预训练模型在计算机视觉领域经历了飞速的发展:

  • 2012年:AlexNet 在 ImageNet 比赛中以显著优势获胜,开启了深度学习在计算机视觉领域的时代。

  • 2014年:VGGNet 和 InceptionNet 相继提出,进一步提高了模型性能。

  • 2015年:ResNet 提出,解决了深层网络训练中的梯度消失问题,推动了网络结构的深化。

  • 2017年:Transformer 架构的出现,为计算机视觉领域带来了新的思路。

  • 2020年至今:基于 Transformer 的预训练模型如 Vision Transformer(ViT)等逐渐成为主流。

4. 预训练的策略

预训练的策略主要包括:

  • 无监督预训练:在未标注的数据上学习通用特征表示,例如自编码器。

  • 有监督预训练:在大规模标注数据集上进行训练,例如在 ImageNet 上预训练。

  • 迁移学习:将预训练模型迁移到特定的下游任务中,通过微调或特征提取的方式提高性能。

5. 预训练代码实战:ResNet18

在之前的实战中,我们已经通过 ResNet18 在 CIFAR-10 数据集上进行了实验。ResNet18 是一个经典的预训练模型,通过残差连接解决了深层网络训练中的梯度消失问题。以下是 ResNet18 的代码示例:

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# 定义数据预处理
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 修改全连接层以适应 CIFAR-10 数据集

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(10):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total:.2f}%')

作业:在 CIFAR-10 上对比 AlexNet 预训练模型

接下来,我们将尝试在 CIFAR-10 数据集上使用 AlexNet 预训练模型,并与 ResNet18 进行对比。以下是 AlexNet 的代码实现:

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# 定义数据预处理
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 加载预训练的 AlexNet 模型
model = models.alexnet(pretrained=True)
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 10)  # 修改分类器以适应 CIFAR-10 数据集

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(10):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total:.2f}%')

实验结果对比

通过上述代码,我们可以在 CIFAR-10 数据集上分别运行 ResNet18 和 AlexNet 模型,并对比它们的性能。以下是可能的对比结果:

模型训练时间测试精度
ResNet1810分钟85.2%
AlexNet15分钟81.5%

从实验结果可以看出,ResNet18 在 CIFAR-10 数据集上表现略优于 AlexNet。这可能是因为 ResNet18 的结构更适合处理小尺寸图像数据集,而 AlexNet 的结构更适合处理大规模图像数据集。

@浙大疏锦行

参考文章

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2397908.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

建造者模式:优雅构建复杂对象

引言 在软件开发中,有时我们需要创建一个由多个部分组成的复杂对象,这些部分可能有不同的变体或配置。如果直接在一个构造函数中设置所有参数,代码会变得难以阅读和维护。当对象构建过程复杂,且需要多个步骤时,我们可…

现场总线结构在楼宇自控系统中的技术要求与实施要点分析

在建筑智能化程度不断提升的当下,楼宇自控系统承担着协调建筑内各类设备高效运行的重任。传统的集中式控制系统在面对复杂建筑环境时,逐渐暴露出布线繁琐、扩展性差、可靠性低等问题。而现场总线结构凭借其分散控制、通信高效等特性,成为楼宇…

学习路之PHP--easyswoole使用视图和模板

学习路之PHP--easyswoole使用视图和模板 一、安装依赖插件二、 实现渲染引擎三、注册渲染引擎四、测试调用写的模板五、优化六、最后补充 一、安装依赖插件 composer require easyswoole/template:1.1.* composer require topthink/think-template相关版本: "…

《云原生安全攻防》-- K8s网络策略:通过NetworkPolicy实现微隔离

默认情况下,K8s集群的网络是没有任何限制的,所有的Pod之间都可以相互访问。这就意味着,一旦攻击者入侵了某个Pod,就能够访问到集群中任意Pod,存在比较大的安全风险。 在本节课程中,我们将详细介绍如何通过N…

06 APP 自动化- H5 元素定位

文章目录 H5 元素定位1、APP 分类2、H5 元素3、H5 元素定位环境的搭建4、代码实现: H5 元素定位 1、APP 分类 1、Android 原生 APP2、混合 APP(Android 原生控件H5页面)3、纯 H5 App 2、H5 元素 H5 元素容器 WebViewWebView 控件实现展示网页 3、H5 元素定位环…

定时线程池失效问题引发的思考

最近在做的一个新功能,在结果探测的时候使用了定时线程池和普通线程池结合,定时线程池周期性创建子任务并往普通线程池提交任务。 问题: 在昨天测试老师发现,业务实际上已经成功了,但是页面还是一直显示进行中。 收到…

AXURE安装+汉化-Windows

安装网站:https://www.axure.com/release-history/rp9 Axure中文汉化包下载地址 链接:https://pan.baidu.com/s/1U62Azk8lkRPBqWAcrJMFew?pwd5418 提取码:5418 下载完成之后,crtlc lang文件夹 到下载的Axure路径下 双击点进这个目录里面。ctrlv把lan…

ArcGIS Pro字段计算器与计算几何不可用,显示灰色

“字段计算器”不可用 如果计算字段命令不可用,请考虑以下可能性: 由 ArcGIS 管理的字段无法手动编辑。因此,无法计算 ObjectID(OID 或 FID)字段或地理数据库要素类的 Shape_Length 和 Shape_Area 字段的字段值。表中…

mac电脑安装 nvm 报错如何解决

前言 已知:安装nvm成功;终端输入nvm -v 有版本返回 1. 启动全局配置环境变量失败 source ~/.zshrc~ 返回: source: no such file or directory: /Users/你的用户名/.zshrc~2 安装node失败 nvm install 16.13返回: mkdir: /U…

第11节 Node.js 模块系统

为了让Node.js的文件可以相互调用,Node.js提供了一个简单的模块系统。 模块是Node.js 应用程序的基本组成部分,文件和模块是一一对应的。换言之,一个 Node.js 文件就是一个模块,这个文件可能是JavaScript 代码、JSON 或者编译过的…

上海工作机会:Technical Writer Senior Technical Writer - 中微半导体设备

大名鼎鼎的中微半导体招聘文档工程师了,就是那家由中国半导体产业的领军人物尹志尧领导的、全员持股的公司。如果你还不了解他,赶快Deepseek一下“尹志尧”了解。 招聘职位:Technical Writer & Senior Technical Writer 公司名称&#…

Python微积分可视化:从导数到积分的交互式教学工具

Python微积分可视化:从导数到积分的交互式教学工具 一、引言 微积分是理解自然科学的基础,但抽象的导数、积分概念常让初学者感到困惑。本文基于Matplotlib开发一套微积分可视化工具,通过动态图像直观展示导数的几何意义、积分的近似计算及跨学科应用,帮助读者建立"数…

Juce实现Table自定义

Juce实现Table自定义 一.总体展示概及概述 在项目中Juce中TableList往往无法满足用户需求,头部和背景及背景颜色设置以及在Cell中添加自定义按钮,所以需要自己实现自定义TabelList,该示例是展示实现自定义TableList,实现自定义标…

【后端高阶面经:架构篇】51、搜索引擎架构与排序算法:面试关键知识点全解析

一、搜索引擎核心基石:倒排索引技术深度解析 (一)倒排索引的本质与构建流程 倒排索引(Inverted Index)是搜索引擎实现快速检索的核心数据结构,与传统数据库的正向索引(文档→关键词&#xff0…

Windows应用-音视频捕获

下载“Windows应用-音视频捕获”项目 本应用可以同时捕获4个视频源和4个音频源,可以监视视频源图像,监听音频源;可以将视频源图像写入MP4文件,将音频源写入MP3或WAV文件;还可以录制系统播放的声音。本应用使用MFC对话框…

【OCCT+ImGUI系列】012-Geom2d_AxisPlacement

Geom2d_AxisPlacement 教学笔记 一、类概述 Geom2d_AxisPlacement 表示二维几何空间中的一个坐标轴(轴系),由两部分组成: gp_Pnt2d:原点(Location)gp_Dir2d:单位方向向量&#xff…

【C++高并发内存池篇】性能卷王养成记:C++ 定长内存池,让内存分配快到飞起!

📝本篇摘要 在本篇将介绍C定长内存池的概念及实现问题,引入内存池技术,通过实现一个简单的定长内存池部分,体会奥妙所在,进而为之后实现整体的内存池做铺垫! 🏠欢迎拜访🏠&#xff…

mac下通过anaconda安装Python

本次分享mac下通过anaconda安装Python、Jupyter Notebook、R。 anaconda安装 点击👉https://www.anaconda.com/download, 点击Mac系统安装包, 选择Mac芯片:苹果芯片 or intel芯片, 选择苹果芯片图形界面安装&#x…

微软PowerBI考试 PL300-Power BI 入门

Power BI 入门 上篇更新了微软PowerBI考试 PL-300学习指南,今天分享PowerBI入门学习内容。 简介 Microsoft Power BI 是一个完整的报表解决方案,通过开发工具和联机平台提供数据准备、数据可视化、分发和管理。 Power BI 可以从使用单个数据源的简单…

逻辑回归知识点

一、逻辑回归概念 逻辑回归(Logistic Regression)是一种广泛应用于分类问题的统计方法,尤其适用于二分类问题。 注意: 尽管名称中有"回归"二字,但它实际上是一种分类算法。 解决二分类的问题。 API:sklearn.linear_model.Logis…