day 44

news2025/6/6 20:26:34

使用DenseNet预训练模型对cifar10数据集进行训练

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)

test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=test_transform
)

# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 4. 定义DenseNet模型
def create_densenet(pretrained=True, num_classes=10):
    # 修改为加载DenseNet121预训练模型
    model = models.densenet121(pretrained=pretrained)
    
    # 修改最后一层全连接层
    in_features = model.classifier.in_features
    model.classifier = nn.Linear(in_features, num_classes)
    
    return model.to(device)

# 5. 冻结/解冻模型层的函数
def freeze_model(model, freeze=True):
    """冻结或解冻模型的卷积层参数"""
    # 冻结/解冻除classifier层外的所有参数
    for name, param in model.named_parameters():
        if 'classifier' not in name:
            param.requires_grad = not freeze
    
    # 打印冻结状态
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    
    if freeze:
        print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")
    else:
        print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")
    
    return model

# 6. 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):
    """
    前freeze_epochs轮冻结卷积层,之后解冻所有层进行训练
    """
    train_loss_history = []
    test_loss_history = []
    train_acc_history = []
    test_acc_history = []
    all_iter_losses = []
    iter_indices = []
    
    # 初始冻结卷积层
    if freeze_epochs > 0:
        model = freeze_model(model, freeze=True)
    
    for epoch in range(epochs):
        # 解冻控制:在指定轮次后解冻所有层
        if epoch == freeze_epochs:
            model = freeze_model(model, freeze=False)
            # 解冻后调整优化器(可选)
            optimizer.param_groups[0]['lr'] = 1e-4  # 降低学习率防止过拟合
        
        model.train()  # 设置为训练模式
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            # 记录Iteration损失
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
            
            # 统计训练指标
            running_loss += iter_loss
            _, predicted = output.max(1)
            total_train += target.size(0)
            correct_train += predicted.eq(target).sum().item()
            
            # 每100批次打印进度
            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "
                      f"| 单Batch损失: {iter_loss:.4f}")
        
        # 计算 epoch 级指标
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct_train / total_train
        
        # 测试阶段
        model.eval()
        correct_test = 0
        total_test = 0
        test_loss = 0.0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
        
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        
        # 记录历史数据
        train_loss_history.append(epoch_train_loss)
        test_loss_history.append(epoch_test_loss)
        train_acc_history.append(epoch_train_acc)
        test_acc_history.append(epoch_test_acc)
        
        # 更新学习率调度器
        if scheduler is not None:
            scheduler.step(epoch_test_loss)
        
        # 打印 epoch 结果
        print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "
              f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")
    
    # 绘制损失和准确率曲线
    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    
    return epoch_test_acc  # 返回最终测试准确率

# 7. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7)
    plt.xlabel('Iteration(Batch序号)')
    plt.ylabel('损失值')
    plt.title('训练过程中的Iteration损失变化')
    plt.grid(True)
    plt.show()

# 8. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
    epochs = range(1, len(train_acc) + 1)
    
    plt.figure(figsize=(12, 5))
    
    # 准确率曲线
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_acc, 'b-', label='训练准确率')
    plt.plot(epochs, test_acc, 'r-', label='测试准确率')
    plt.xlabel('Epoch')
    plt.ylabel('准确率 (%)')
    plt.title('准确率随Epoch变化')
    plt.legend()
    plt.grid(True)
    
    # 损失曲线
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_loss, 'b-', label='训练损失')
    plt.plot(epochs, test_loss, 'r-', label='测试损失')
    plt.xlabel('Epoch')
    plt.ylabel('损失值')
    plt.title('损失值随Epoch变化')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# 主函数:训练模型
def main():
    # 参数设置
    epochs = 40  # 总训练轮次
    freeze_epochs = 5  # 冻结卷积层的轮次
    learning_rate = 1e-3  # 初始学习率
    weight_decay = 1e-4  # 权重衰减
    
    # 创建DenseNet模型(加载预训练权重)
    model = create_densenet(pretrained=True, num_classes=10)
    
    # 定义优化器和损失函数
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    
    # 定义学习率调度器
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )
    
    # 开始训练(前5轮冻结卷积层,之后解冻)
    final_accuracy = train_with_freeze_schedule(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=epochs,
        freeze_epochs=freeze_epochs
    )
    
    print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")
    
    # # 保存模型
    # torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')
    # print("模型已保存至: resnet18_cifar10_finetuned.pth")

if __name__ == "__main__":
    main()

注意:

1.DenseNet 最后一层全连接层的名称是 classifier ,并非 fc 因此在 freeze_model 函数里,要注意改一下名。

2.残差连接

在传统的神经网络中,每一层的输出都是对上一层输入进行一系列非线性变换的结果。而在残差网络(ResNet)中,引入了残差块(Residual Block),残差块通过残差连接将输入直接加到经过非线性变换后的输出上。

假设一个神经网络层的输入为 $x$,期望学习的映射为 $H(x)$,在传统网络中,该层需要直接学习 $H(x)$。而在残差网络中,将该层改为学习残差函数 $F(x) = H(x) - x$,则输出变为 $y = F(x) + x$,这里的 $x$ 到 $y$ 的连接就是残差连接。

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2399284.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NER实践总结,记录一下自己实践遇到的各种问题。

更。 没卡,跑个模型休息好几天,又闲又急。 一开始直接套用了别人的代码进行实体识别,结果很差,原因是他的词表没有我需要的东西,我是用的医学文本。代码直接在github找了改的,用的是BERT的Chinese版本。 然…

微信小程序实现运动能耗计算

微信小程序实现运动能耗计算 近我做了一个挺有意思的微信小程序,能够实现运动能耗的计算。只需要输入性别、年龄、体重、运动时长和运动类型这些信息,就能算出对应的消耗热量。 具体来说,在小程序里,性别不同,身体基…

iTunes 无法备份 iPhone:10 种解决方法

Apple 设备是移动设备市场上最先进的产品之一,但有些人遇到过 iTunes 因出现错误而无法备份 iPhone 的情况。iTunes 拒绝备份 iPhone 时,可能会令人非常沮丧。不过,幸运的是,我们有 10 种有效的方法可以解决这个问题。您可以按照以…

LangChain4J 使用实践

这里写目录标题 大模型应用场景&#xff1a;创建一个测试示例AIService聊天记忆实现简单实现聊天记录记忆MessageWindowChatMemory实现聊天记忆 隔离聊天记忆聊天记忆持久化 添加AI提示词 大模型应用场景&#xff1a; 创建一个测试示例 导入依赖 <dependency><groupI…

【C++】—— 从零开始封装 Map 与 Set:实现与优化

人生的态度是&#xff0c;抱最大的希望&#xff0c;尽最大的努力&#xff0c;做最坏的打算。 —— 柏拉图 《理想国》 目录 1、理论基石——深度剖析 BSTree、AVLTree 与 RBTree 的概念区别 2、迭代器机制——RBTree 迭代器的架构与工程实现 3、高级容器设计——Map 与 Set…

内网穿透之Linux版客户端安装(神卓互联)

选择Linux系统版本 获取安装包 &#xff1a;https://www.shenzhuohl.com/download.html 这里以Ubuntu 18.04为例&#xff0c;其它版本方法类似 登录Ubuntu操作系统&#xff1a; 打开Ubuntu系统终端&#xff0c;更新版本 apt-get update 安装运行环境&#xff1a; 安装C 运…

开疆智能Profinet转Profibus网关连接CMDF5-8ADe分布式IO配置案例

本案例是客户通过开疆智能研发的Profinet转Profibus网关将PLC的Profinet协议数据转换成IO使用的Profibus协议&#xff0c;操作步骤如下。 配置过程&#xff1a; Profinet一侧设置 1. 打开西门子组态软件进行组态&#xff0c;导入网关在Profinet一侧的GSD文件。 2. 新建项目并…

华为云Flexus+DeepSeek征文|Flexus云服务器单机部署+CCE容器高可用部署快速搭建生产级的生成式AI应用

前引&#xff1a; 在AI技术高速演进的浪潮中&#xff0c;如何快速、高效、安全地搭建一个大模型应用平台&#xff0c;成为开发者和企业关注的焦点。近日&#xff0c;华为云推出的Flexus云服务器配合CCE容器引擎和Dify LLM应用开发平台&#xff0c;带来了极具吸引力的解决方案。…

60天python训练计划----day44

DAY 44 预训练模型 知识点回顾&#xff1a; 预训练的概念常见的分类预训练模型图像预训练模型的发展史预训练的策略预训练代码实战&#xff1a;resnet18 一、预训练的概念 我们之前在训练中发现&#xff0c;准确率最开始随着epoch的增加而增加。随着循环的更新&#xff0c;参数…

【JAVA版】意象CRM客户关系管理系统+uniapp全开源

一.介绍 CRM意象客户关系管理系统&#xff0c;是一个综合性的客户管理平台&#xff0c;旨在帮助企业高效地管理客户信息、商机、合同以及员工业绩。系统通过首页、系统管理、工作流程、审批中心、线索管理、客户管理、商机管理、合同管理、CRM系统、数据统计和系统配置等模块&…

API异常信息如何实时发送到钉钉

#背景 对于一些重要的API&#xff0c;开发人员会非常关注API有没有报错&#xff0c;为了方便开发人员第一时间获取错误信息&#xff0c;我们可以使用插件来将API报错实时发送到钉钉群。 接下来我们就来实操如何实现 #准备工作 #创建钉钉群 如果已有钉钉群&#xff0c;可以跳…

Python爬虫(48)基于Scrapy-Redis与深度强化学习的智能分布式爬虫架构设计与实践

目录 一、背景与行业痛点二、核心技术架构设计2.1 分布式爬虫基础架构2.2 深度强化学习模块 三、生产环境实践案例3.1 电商价格监控系统3.2 学术文献采集系统 四、高级优化技术4.1 联邦学习增强4.2 神经架构搜索&#xff08;NAS&#xff09; 五、总结&#x1f308;Python爬虫相…

(1-6-3)Java 多线程

目录 0.知识拓扑 1. 多线程相关概念 1.1 进程 1.2 线程 1.3 java 中的进程 与 线程概述 1.4 CPU、进程 与 线程的关系 2.多线程的创建方式 2.1 继承Thread类 2.2 实现Runnable接口 2.3 实现Callable接口 2.4 三种创建方式对比 3.线程同步 3.1 线程同步机制概述 …

java31

1.网络编程 三要素&#xff1a; 网址实质上就是ip InetAddress: UDP通信程序&#xff1a; 多个接收端的地址都要加入同一个组播地址&#xff0c;这样发送端发信息&#xff0c;全部接收端都能接受到数据 广播的代码差不多&#xff0c;就是地址不一样而已 TCP通信程序&#xf…

界面组件DevExpress WPF中文教程:Grid - 如何识别行和卡片?

DevExpress WPF拥有120个控件和库&#xff0c;将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序&#xff0c;这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…

【HarmonyOS Next之旅】DevEco Studio使用指南(三十)

目录 1 -> 部署云侧工程 2 -> 通过CloudDev面板获取云开发资源支持 3 -> 通用云开发模板 3.1 -> 适用范围 3.2 -> 效果图 4 -> 总结 1 -> 部署云侧工程 可以选择在云函数和云数据库全部开发完成后&#xff0c;将整个云工程资源统一部署到AGC云端。…

AI基础知识(LLM、prompt、rag、embedding、rerank、mcp、agent、多模态)

AI基础知识&#xff08;LLM、prompt、rag、embedding、rerank、mcp、agent、多模态&#xff09; 1、LLM大语言模型 --基于​​深度学习技术​​&#xff0c;通过​​海量文本数据训练​​而成的超大规模人工智能模型&#xff0c;能够理解、生成和推理自然语言文本 --产品&…

[蓝桥杯]高僧斗法

高僧斗法 题目描述 古时丧葬活动中经常请高僧做法事。仪式结束后&#xff0c;有时会有"高僧斗法"的趣味节目&#xff0c;以舒缓压抑的气氛。 节目大略步骤为&#xff1a;先用粮食&#xff08;一般是稻米&#xff09;在地上"画"出若干级台阶&#xff08;…

pycharm F2 修改文件名 修改快捷键

菜单&#xff1a;File-> Setting&#xff0c; Keymap中搜索 Rename&#xff0c; 其中&#xff0c;有 Refactor-> Rename&#xff0c;右键添加快捷键&#xff0c;F2&#xff0c;删除原有快捷键就可以了。

Python Flask中启用AWS Secrets Manager+AWS Parameter Store配置中心

问题 最近需要改造一个Python的Flask项目。需要在这个项目中添加AWS Secrets Manager作为配置中心&#xff0c;主要是数据库相关配置。 前提 得预先在Amazon RDS里面新建好数据库用户和数据库&#xff0c;以AWS Aurora为例子&#xff0c;建库和建用户语句类似如下&#xff1…