Python训练营---Day44

news2025/6/7 6:13:03

DAY 44 预训练模型

知识点回顾:

  1. 预训练的概念
  2. 常见的分类预训练模型
  3. 图像预训练模型的发展史
  4. 预训练的策略
  5. 预训练代码实战:resnet18

作业:

  1. 尝试在cifar10对比如下其他的预训练模型,观察差异,尽可能和他人选择的不同
  2. 尝试通过ctrl进入resnet的内部,观察残差究竟是什么

选用 DenseNet121预训练模型,注意DenseNet121 模型的最后分类层名为classifier,而不是 ResNet 中的fc

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.models import resnet18, densenet121, vgg16

# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(
    root='./cifar_data',
    train=True,
    download=True,
    transform=train_transform
)

test_dataset = datasets.CIFAR10(
    root='./cifar_data',
    train=False,
    transform=test_transform
)

# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 4. 定义DenseNet121模型
def create_densenet121(pretrained=True, num_classes=10):
    model = models.densenet121(pretrained=pretrained)

    # 修改最后一层全连接层
    in_features = model.classifier.in_features
    model.classifier = nn.Linear(in_features, num_classes) # DenseNet121 的最后一层分类器名称是classifier

    return model.to(device)

# 5. 冻结/解冻模型层的函数
# 这种设计允许我们在迁移学习中保留预训练模型的特征提取部分(卷积层),只训练新添加的分类层(全连接层)。
def freeze_model(model, freeze=True):
    """冻结或解冻模型的卷积层参数"""
    # 冻结/解冻除fc层外的所有参数
    for name, param in model.named_parameters():
        if 'classifier' not in name:    #排除名称中包含 "fc" 的参数,这些通常是全连接层的参数
            param.requires_grad = not freeze    #param.requires_grad是 PyTorch 中控制参数是否参与反向传播和梯度更新的标志
    
    # 打印冻结状态
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)   #统计所有requires_grad=False的参数数量
    total_params = sum(p.numel() for p in model.parameters())
    
    if freeze:
        print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")
    else:
        print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")
    
    return model

# 6. 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):
    """
    前freeze_epochs轮冻结卷积层,之后解冻所有层进行训练
    """
    train_loss_history = []
    test_loss_history = []
    train_acc_history = []
    test_acc_history = []
    all_iter_losses = []
    iter_indices = []
    
    # 初始冻结卷积层
    if freeze_epochs > 0:
        model = freeze_model(model, freeze=True)
    
    for epoch in range(epochs):
        # 解冻控制:在指定轮次后解冻所有层
        if epoch == freeze_epochs:
            model = freeze_model(model, freeze=False)
            # 解冻后调整优化器(可选)
            optimizer.param_groups[0]['lr'] = 1e-4  # 降低学习率防止过拟合
        
        model.train()  # 设置为训练模式
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            # 记录Iteration损失
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
            
            # 统计训练指标
            running_loss += iter_loss
            _, predicted = output.max(1)
            total_train += target.size(0)
            correct_train += predicted.eq(target).sum().item()
            
            # 每100批次打印进度
            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "
                      f"| 单Batch损失: {iter_loss:.4f}")
        
        # 计算 epoch 级指标
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct_train / total_train
        
        # 测试阶段
        model.eval()
        correct_test = 0
        total_test = 0
        test_loss = 0.0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
        
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        
        # 记录历史数据
        train_loss_history.append(epoch_train_loss)
        test_loss_history.append(epoch_test_loss)
        train_acc_history.append(epoch_train_acc)
        test_acc_history.append(epoch_test_acc)
        
        # 更新学习率调度器
        if scheduler is not None:
            scheduler.step(epoch_test_loss)
        
        # 打印 epoch 结果
        print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "
              f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")
    
    # 绘制损失和准确率曲线
    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    
    return epoch_test_acc  # 返回最终测试准确率

# 7. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7)
    plt.xlabel('Iteration(Batch序号)')
    plt.ylabel('损失值')
    plt.title('训练过程中的Iteration损失变化')
    plt.grid(True)
    plt.show()

# 8. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
    epochs = range(1, len(train_acc) + 1)
    
    plt.figure(figsize=(12, 5))
    
    # 准确率曲线
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_acc, 'b-', label='训练准确率')
    plt.plot(epochs, test_acc, 'r-', label='测试准确率')
    plt.xlabel('Epoch')
    plt.ylabel('准确率 (%)')
    plt.title('准确率随Epoch变化')
    plt.legend()
    plt.grid(True)
    
    # 损失曲线
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_loss, 'b-', label='训练损失')
    plt.plot(epochs, test_loss, 'r-', label='测试损失')
    plt.xlabel('Epoch')
    plt.ylabel('损失值')
    plt.title('损失值随Epoch变化')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# 主函数:训练模型
def main():
    # 参数设置
    epochs = 40  # 总训练轮次
    freeze_epochs = 5  # 冻结卷积层的轮次
    learning_rate = 1e-3  # 初始学习率
    weight_decay = 1e-4  # 权重衰减
    
    # 创建DenseNet121模型(加载预训练权重)
    model = create_densenet121(pretrained=True, num_classes=10)
    
    # 定义优化器和损失函数
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    
    # 定义学习率调度器
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )
    
    # 开始训练(前5轮冻结卷积层,之后解冻)
    final_accuracy = train_with_freeze_schedule(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=epochs,
        freeze_epochs=freeze_epochs
    )
    
    print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")
    
    # # 保存模型
    # torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')
    # print("模型已保存至: resnet18_cifar10_finetuned.pth")

if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2402550.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端项目初始化

​​​​​​ 目录 1. 安装 nvm 2. 配置 nvm 并切换到 Node.js 16.15.0 3. 安装 LightProxy 代理 4. GIT安装 1. 配置用户名和邮箱(这些信息将用于您在提交代码时的标识): 2. 生成SSH密钥(用于将本地代码仓库与远程存储库连…

USB扩展器与USB服务器的2个主要区别

在现代办公和IT环境中,连接和管理USB设备是常见需求。USB扩展器(常称USB集线器)与USB服务器(如朝天椒USB服务器)是两类功能定位截然不同的解决方案。前者主要解决物理接口数量不足的“近身”连接扩展问题,而…

第46节:多模态分类(图像+文本)

一、多模态分类概述 多模态分类是指利用来自不同模态(如图像、文本、音频等)的数据进行联合分析和分类的任务。 在当今大数据时代,信息往往以多种形式存在,例如社交媒体上的图片配文字、视频附带字幕、医疗检查中的影像与报告等。单一模态的数据往往只能提供有限的信息,…

springMVC-10验证及国际化

验证 概述 ● 概述 1. 对输入的数据(比如表单数据),进行必要的验证,并给出相应的提示信息。 2. 对于验证表单数据,springMVC提供了很多实用的注解, 这些注解由JSR303 验证框架提供. ●JSR 303 验证框架 1. JSR 303 的含义 JSR&#xff0…

LRU 和 DiskLRU实现相册缓存器

我是写Linux后端的(golang、c、py),后端缓存算法通常是指的是内存里面的lru、或diskqueue,都是独立使用。 很少有用内存lru与disklru结合的场景需求。近段时间研究android开发,里面有一些设计思想值得后端学习。 写这…

figma MCP + cursor如何将设计稿生成前端页面

一、准备工作 figma MCP需要通过figma key来获取设计稿权限,key的生成步骤如下 1. 打开figma网页版/APP,进入账户设定 2. 点击生成token 3. 填写内容生成token(一定要确认复制了,不然关闭弹窗后就不会显示了) 二、配置MCP 4. 进入到cursor…

如何理解OSI七层模型和TCP/IP四层模型?HTTP作为如何保存用户状态?多服务器节点下 Session方案怎么做

本篇概览: OSI 七层模型是什么?每一层的作用是什么?TCP/IP四层模型和OSI七层模型的区别是什么? HTTP 本身是无状态协议,HTTP如何保存用户状态? 能不能具体说一下Cookie的工作原理、生命周期、作用域?使用…

【Elasticsearch】映射:Nested 类型

映射:Nested 类型 1.为什么需要 Nested 类型2.如何定义 Nested 类型3.相关操作3.1 索引包含 Nested 数据的文档3.2 查询 Nested 数据3.3 聚合 Nested 数据3.4 排序 Nested 数据3.5 更新 Nested 文档中的特定元素 4.Nested 类型的高级操作4.1 内嵌 inner hits4.2 多级…

Vue3 + UniApp 蓝牙连接与数据发送(稳定版)

本教程适用于使用 uni-app Vue3 (script setup) 开发的跨平台 App(支持微信小程序、H5、Android/iOS 等) 🎯 功能目标 ✅ 获取蓝牙权限✅ 扫描周围蓝牙设备✅ 连接指定蓝牙设备✅ 获取服务和特征值✅ 向设备发送数据包(ArrayBu…

三种读写传统xls格式文件开源库libxls、xlslib、BasicExcel的比较

最近准备读写传统xls格式文件,而不是较新的xlsx,询问DeepSeek有哪些开源库,他给出了如下的简介和建议,还给出了相应链接,不过有的链接已失效。最后还不忘提醒,现在该用xlsx格式了。 以下是几个可以处理传统…

Nature子刊同款的宏基因组免疫球蛋白测序怎么做?

免疫球蛋白A(IgA)是人体肠道黏膜分泌的主要抗体,它在塑造肠道微生物群落和维持肠道稳态中起着关键作用,有研究发现缺乏IgA的患者更容易患自身免疫性疾病和感染性疾病。 目前用于研究IgA结合的主要技术是IgA-SEQ,结合了…

2025年牛客网秋招/社招高质量 Java 面试八股文整理

Java 面试 不论是校招还是社招都避免不了各种面试。笔试,如何去准备这些东西就显得格外重要。不论是笔试还是面试都是有章可循的。关键在于理解企业的需求,明确自己的定位,以及掌握一定的应试技巧。 笔试部分,通常是对基础知识、…

ADI的BF609双核DSP怎么做开发,我来说一说(五)LAN口测试

作者的话 ADI的双核DSP,第二颗是Blackfin系列的BF609,这颗DSP我用了很久,比较熟悉,且写过一些给新手的教程。 硬件准备 ADSP-BF609-CORE:ADI BF609开发板 产品链接:https://item.taobao.com/item.htm?…

行业赋能篇-2-能源行业安全运维升级

在能源行业,尤其是风电领域,运维作业往往面临“三高”挑战——高风险环境、高异构数据量)、高合规要求。以海上风电场为例,传统运维依赖卫星电话沟通,数据记录碎片化,故障因信息传递延迟导致损失扩大。如何…

飞云智能波段主图+多空短线决策副图指标,组合操盘技术图文解说

如上图,组合指标:主图-飞云智能波段,红线上红色K线标记,波段做多.副图指标-多空短线决策,跟踪做短线,红柱做多,绿柱短线卖出或做空。 实战操作中,我们在主图红色线支撑上红色K线出现…

【51单片机】1. 基础点灯大师

1. 新建一个项目集一些基本操作 打开Keli软件,然后: 【Project】→【new μVision Project】→选择项目保存位置 建议文件名选一些通用的名字,如【Project】 左下角选择【Atmel】的【AT89C52】 弹出的【是否添加启动文件到文件夹下】&…

PC端直接打印功能(包括两张图片合并功能)

一、 效果图 二、demo代码 <template><div class"box"><divref"printContent"class"print-content"><div class"print-title">打印图片</div><imgclass"preview-image":src"merged…

Vue前端篇——项目目录结构介绍

&#x1f4d8; 前言 在正式开始学习 Vue 3 开发之前&#xff0c;了解并熟悉其项目目录结构是非常关键的第一步。一个清晰、规范的目录结构不仅有助于开发者高效地组织代码&#xff0c;还能显著提升项目的可读性和可维护性。 Vue 3 作为现代前端开发中广泛使用的主流框架之一&…

抽象工厂模式深度解析:从原理到与应用实战

作者简介 我是摘星&#xff0c;一名全栈开发者&#xff0c;专注 Java后端开发、AI工程化 与 云计算架构 领域&#xff0c;擅长Python技术栈。热衷于探索前沿技术&#xff0c;包括大模型应用、云原生解决方案及自动化工具开发。日常深耕技术实践&#xff0c;乐于分享实战经验与…

35.成功解决编写关于“江协科技”编写技巧第二期标志位积累的问题

江科大学长又发布了第二期的编写技巧&#xff01; 大家可以看看&#xff1a;https://space.bilibili.com/383400717 最后面给了一个未完成的任务&#xff1a; 这里我已经把这个问题给解决了&#xff01; 总代码放在资源里面&#xff0c;key.c放在文章最后面&#xff01;同时感…