Python训练营打卡Day45

news2025/12/14 20:42:25

知识点回顾:

  1. tensorboard的发展历史和原理
  2. tensorboard的常见操作
  3. tensorboard在cifar上的实战:MLP和CNN模型

效果展示如下,很适合拿去组会汇报撑页数:

作业:对resnet18在cifar10上采用微调策略下,用tensorboard监控训练过程。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18
from torch.utils.tensorboard import SummaryWriter
import time
import os

# 设置随机种子确保结果可复现
torch.manual_seed(42)

# 数据预处理
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 加载CIFAR-10数据集
trainset = datasets.CIFAR10(root='./data', train=True,
                             download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128,
                          shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False,
                            download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100,
                         shuffle=False, num_workers=2)

# 定义类别名称
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 加载预训练的ResNet18模型
model = resnet18(pretrained=True)

# 修改模型以适应CIFAR-10
# 调整输入层以适应32x32的图像
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
# 修改最后一层以适应10个类别
model.fc = nn.Linear(model.fc.in_features, 10)

# 微调策略:冻结部分层
for param in list(model.parameters())[:-10]:  # 解冻最后10层
    param.requires_grad = False

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 
                      lr=0.001, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# 设置TensorBoard写入器
timestamp = time.strftime('%Y%m%d_%H%M%S')
log_dir = os.path.join('runs', f'resnet18_cifar10_finetune_{timestamp}')
writer = SummaryWriter(log_dir)

# 训练函数
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # 每100个batch记录一次训练状态
        if (batch_idx+1) % 100 == 0:
            step = epoch * len(trainloader) + batch_idx
            writer.add_scalar('Train/Loss', loss.item(), step)
            writer.add_scalar('Train/Accuracy', 100.*correct/total, step)
            print(f'Epoch: {epoch} | Batch: {batch_idx+1}/{len(trainloader)} | Loss: {loss.item():.3f} | Acc: {100.*correct/total:.3f}%')
    
    # 记录每个epoch的平均训练损失和准确率
    avg_loss = train_loss / len(trainloader)
    avg_acc = 100. * correct / total
    writer.add_scalar('Epoch/Train_Loss', avg_loss, epoch)
    writer.add_scalar('Epoch/Train_Accuracy', avg_acc, epoch)
    return avg_loss, avg_acc

# 测试函数
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    # 记录测试损失和准确率
    avg_loss = test_loss / len(testloader)
    avg_acc = 100. * correct / total
    writer.add_scalar('Epoch/Test_Loss', avg_loss, epoch)
    writer.add_scalar('Epoch/Test_Accuracy', avg_acc, epoch)
    
    # 记录学习率
    writer.add_scalar('Epoch/Learning_Rate', optimizer.param_groups[0]['lr'], epoch)
    
    print(f'Test Epoch: {epoch} | Loss: {avg_loss:.3f} | Acc: {avg_acc:.3f}%')
    
    # 保存最佳模型
    global best_acc
    if avg_acc > best_acc:
        print('Saving best model...')
        state = {
            'model': model.state_dict(),
            'acc': avg_acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = avg_acc
    
    # 记录错误分类的图像
    if epoch % 10 == 0:  # 每10个epoch记录一次
        misclassified = []
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                mis_indices = (predicted != targets).nonzero(as_tuple=True)[0]
                
                for idx in mis_indices:
                    if len(misclassified) < 10:  # 只记录10张错误分类的图像
                        misclassified.append({
                            'image': inputs[idx],
                            'predicted': predicted[idx].item(),
                            'actual': targets[idx].item()
                        })
        
        if misclassified:
            for i, item in enumerate(misclassified):
                img = item['image'].cpu()
                writer.add_image(
                    f'Misclassified/{classes[item["actual"]]}_as_{classes[item["predicted"]]}',
                    img, epoch
                )
    
    return avg_loss, avg_acc

# 可视化模型结构
def visualize_model():
    sample_input = torch.rand(1, 3, 32, 32).to(device)
    writer.add_graph(model, sample_input)

# 可视化样本数据
def visualize_samples():
    # 获取一个批次的训练数据
    images, labels = next(iter(trainloader))
    # 创建图像网格
    img_grid = torchvision.utils.make_grid(images)
    # 添加图像网格到TensorBoard
    writer.add_image('CIFAR10_Samples', img_grid)
    # 添加标签到TensorBoard
    class_labels = [classes[i] for i in labels]
    writer.add_embedding(
        images.view(images.size(0), -1),
        metadata=class_labels,
        label_img=images,
        global_step=0
    )

# 可视化特征图
def visualize_feature_maps(inputs, epoch):
    model.eval()
    # 获取第一层卷积的特征图
    first_conv = model.conv1
    feature_maps = first_conv(inputs.to(device))
    
    # 创建特征图网格
    grid = torchvision.utils.make_grid(
        feature_maps[:8].unsqueeze(1),  # 只显示前8个特征图
        nrow=4, normalize=True, scale_each=True
    )
    writer.add_image('FeatureMaps/Conv1', grid, epoch)

# 主训练循环
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    best_acc = 0
    start_epoch = 0
    
    # 可视化模型和样本
    visualize_model()
    visualize_samples()
    
    # 训练模型
    total_epochs = 50
    for epoch in range(start_epoch, total_epochs):
        print(f'Epoch {epoch}/{total_epochs-1}')
        print('-' * 10)
        
        train_loss, train_acc = train(epoch)
        test_loss, test_acc = test(epoch)
        
        # 学习率调度
        scheduler.step()
        
        # 可视化训练过程中的特征图
        if epoch % 5 == 0:  # 每5个epoch可视化一次
            inputs, _ = next(iter(testloader))
            visualize_feature_maps(inputs[:1], epoch)
        
        # 记录训练过程中的权重和梯度直方图
        if epoch % 10 == 0:  # 每10个epoch记录一次
            for name, param in model.named_parameters():
                if param.requires_grad:
                    writer.add_histogram(f'Params/{name}', param, epoch)
                    if param.grad is not None:
                        writer.add_histogram(f'Grads/{name}', param.grad, epoch)
        
        print()
    
    # 关闭TensorBoard写入器
    writer.close()
    print('Training completed!')
    print(f'Best accuracy: {best_acc:.2f}%')    

  1. tensorboard和torch版本存在一定的不兼容性,如果报错请新建环境尝试。启动tensorboard的时候需要先在cmd中进入对应的环境,conda activate xxx,再用cd命令进入环境(如果本来就是正确的则无需操作)。
  2. tensorboard的代码还有有一定的记忆量,实际上深度学习的经典代码都是类似于八股文,看多了就习惯了,难度远远小于考研数学等需要思考的内容
  3. 实际上对目前的ai而言,你只需要先完成最简单的demo,然后让他给你加上tensorboard需要打印的部分即可。---核心是弄懂tensorboard可以打印什么信息,以及如何看可视化后的结果,把ai当成记忆大师用到的时候通过它来调取对应的代码即可。

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2404413.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Xilinx FPGA 重构Multiboot ICAPE2和ICAPE3使用

一、FPGA Multiboot 本文主要介绍基于IPROG命令的FPGA多版本重构&#xff0c;用ICAP原语实现在线多版本切换。需要了解MultiBoot Fallback点击链接。 如下图所示&#xff0c;ICAP原语可实现flash中n1各版本的动态切换&#xff0c;在工作过程中&#xff0c;可以通过IPROG命令切…

Redis专题-基础篇

题记 本文涵盖了Redis的各种数据结构和命令&#xff0c;Redis的各种常见Java客户端的应用和最佳实践 jedis案例github地址&#xff1a;https://github.com/whltaoin/fedis_java_demo SpringbootDataRedis案例github地址&#xff1a;https://github.com/whltaoin/springbootData…

springMVC-11 中文乱码处理

前言 本文介绍了springMVC中文乱码的解决方案&#xff0c;同时也贴出了本人遇到过的其他乱码情况&#xff0c;可以根据自身情况选择合适的解决方案。 其他-jdbc、前端、后端、jsp乱码的解决 Tomcat导致的乱码解决 自定义中文乱码过滤器 老方法&#xff0c;通过javaW…

【iOS安全】iPhone X iOS 16.7.11 (20H360) WinRa1n 越狱教程

前言 越狱iPhone之后&#xff0c;一定记得安装一下用于屏蔽更新的描述文件&#xff08;可使用爱思助手&#xff09; 因为即便关闭了自动更新&#xff0c;iPhone仍会在某些时候自动更新系统&#xff0c;导致越狱失效&#xff1b;更为严重的是&#xff0c;更新后的iOS版本可能是…

智能标志桩图像监测装置如何守护地下电缆安全

在现代城市基础设施建设中&#xff0c;大量电缆、管道被埋设于地下&#xff0c;这虽然美化了城市景观&#xff0c;却也带来了新的安全隐患。施工挖掘时的意外破坏、自然灾害的影响&#xff0c;都可能威胁这些"城市血管"的安全运行。 传统的地下设施标识方式往往只依…

【网站建设】网站 SEO 中 meta 信息修改全攻略 ✅

在做 SEO 优化时,除了前一篇提过的Title之外,meta 信息(通常指 <meta> 标签)也是最基础、最重要的内容之一,主要包括: <meta name="description"> <meta name="keywords"> 搜索引擎重点参考这些信息,决定你网页的展示效果与排名。…

计算机视觉处理----OpenCV(从摄像头采集视频、视频处理与视频录制)

一、采集视频 VideoCapture 用于从视频文件、摄像头或其他视频流设备中读取视频帧。它可以捕捉来自 多种源的视频。 cv2.VideoCapture() 打开摄像头或视频文件。 cap cv2.VideoCapture(0) # 0表示默认摄像头&#xff0c;1是第二个摄像头&#xff0c;传递视频文件路径也可以 …

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- 第一篇:MIPI CSI-2基础入门

第一篇&#xff1a;MIPI CSI-2基础入门 1. 为什么需要CSI-2&#xff1f; 痛点场景对比 &#xff08;用生活案例降低理解门槛&#xff09; 传统并行接口CSI-2接口30根线传输720P图像仅需5根线&#xff08;1对CLK4对DATA&#xff09;线距&#xff1e;5cm时出现重影线缆可长达1…

变幻莫测:CoreData 中 Transformable 类型面面俱到(一)

概述 各位似秃似不秃小码农们都知道&#xff0c;在苹果众多开发平台中 CoreData 无疑是那个最简洁、拥有“官方认证”且最具兼容性的数据库框架。使用它可以让我们非常方便的搭建出 App 所需要的持久存储体系。 不过&#xff0c;大家是否知道在 CoreData 中还存在一个 Transfo…

开源技术驱动下的上市公司财务主数据管理实践

开源技术驱动下的上市公司财务主数据管理实践 —— 以人造板制造业为例 引言&#xff1a;财务主数据的战略价值与行业挑战 在资本市场监管日益严格与企业数字化转型的双重驱动下&#xff0c;财务主数据已成为上市公司财务治理的核心基础设施。对于人造板制造业而言&#xff0…

Java建造者模式(Builder Pattern)详解与实践

一、引言 在软件开发中&#xff0c;我们经常会遇到需要创建复杂对象的场景。例如&#xff0c;构建一个包含多个可选参数的对象时&#xff0c;传统的构造函数或Setter方法可能导致代码臃肿、难以维护。此时&#xff0c;建造者模式&#xff08;Builder Pattern&#xff09;便成为…

win32相关(IAT HOOK)

IAT HOOK 什么是IAT Hook&#xff1f; IAT Hook&#xff08;Import Address Table Hook&#xff0c;导入地址表钩子&#xff09;是一种Windows平台下的API钩取技术&#xff0c;通过修改目标程序的导入地址表(IAT)来拦截和重定向API调用 在我们之前学习pe文件结构的导入表时&am…

零基础玩转物联网-串口转以太网模块如何快速实现与TCP服务器通信

目录 1 前言 2 环境搭建 2.1 硬件准备 2.2 软件准备 2.3 驱动检查 3 TCP服务器通信配置与交互 3.1 硬件连接 3.2 开启TCP服务器 3.3 打开配置工具读取基本信息 3.4 填写连接参数进行连接 3.5 通信测试 4 总结 1 前言 TCP是TCP/IP体系中的传输层协议&#xff0c;全称为Transmiss…

ESP32开发之LED闪烁和呼吸的实现

硬件电路介绍GPIO输出模式GPIO配置过程闪烁灯的源码LED PWM的控制器(LEDC)概述LEDC配置过程及现象整体流程 硬件电路介绍 电路图如下&#xff1a; 只要有硬件基础的应该都知道上图中&#xff0c;当GPIO4的输出电平为高时&#xff0c;LED灯亮&#xff0c;反之则熄灭。如果每间…

【产品业务设计】支付业务设计规范细节记录,含订单记录、支付业务记录、支付流水记录、退款业务记录

【产品业务设计】支付业务设计规范细节记录&#xff0c;含订单记录、支付业务记录、支付流水记录 前言 我为什么要写这个篇文章 总结设计经验生成设计模板方便后期快速搭建 一个几张表 一共5张表&#xff1b; 分别是&#xff1a; 订单主表&#xff1a;jjy_orderMain订单产…

2025软件供应链安全最佳实践︱证券DevSecOps下供应链与开源治理实践

项目背景&#xff1a;近年来&#xff0c;云计算、AI人工智能、大数据等信息技术的不断发展、各行各业的信息电子化的步伐不断加快、信息化的水平不断提高&#xff0c;网络安全的风险不断累积&#xff0c;金融证券行业面临着越来越多的威胁挑战。特别是近年以来&#xff0c;开源…

WebRTC通话原理与入门难度实战指南

波煮的实习公司主要是音视频业务&#xff0c;所以最近在补习WebRTC的相关内容&#xff0c;会不定期给大家分享学习心得和笔记。 文章目录 WebRTC通话原理进行媒体协商&#xff1a;彼此要了解对方支持的媒体格式网络协商&#xff1a;彼此要了解对方的网络情况&#xff0c;这样才…

N元语言模型 —— 一文讲懂!!!

目录 引言 一. 基本知识 二.参数估计 三.数据平滑 一.加1法 二.减值法/折扣法 ​编辑 1.Good-Turing 估计 ​编辑 2.Back-off (后备/后退)方法 3.绝对减值法 ​编辑4.线性减值法 5.比较 三.删除插值法(Deleted interpolation) 四.模型自适应 引言 本章节讲的…

.NET 9中的异常处理性能提升分析:为什么过去慢,未来快

一、为什么要关注.NET异常处理的性能 随着现代云原生、高并发、分布式场景的大量普及&#xff0c;异常处理&#xff08;Exception Handling&#xff09;早已不再只是一个冷僻的代码路径。在高复杂度的微服务、网络服务、异步编程环境下&#xff0c;服务依赖的外部资源往往不可…

Mac 安装git心路历程(心累版)

省流版&#xff1a;直接安装Xcode命令行工具即可&#xff0c;不用安Xcode。 git下载官网 第一部分 上网初步了解后&#xff0c;打算直接安装Binary installer&#xff0c;下载完安装时&#xff0c;苹果还阻止安装&#xff0c;只好在“设置–安全性与隐私”最下面的提示进行安…