第J2周:ResNet50V2 算法实战与解析

news2025/5/24 10:56:12
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

学习目标

✅ 根据TensorFlow代码,编写出相应的Python代码
✅ 了解ResNetV2和ResNet模型的区别

一、环境配置

在这里插入图片描述

二、数据预处理

在这里插入图片描述

三、创建、划分数据集

在这里插入图片描述

四、 创建数据加载器

在这里插入图片描述

五、加载预训练

在这里插入图片描述

在这里插入图片描述

六、显示训练数据

# 显示一些训练图像示例
def show_images(loader, title="数据示例"):
    plt.figure(figsize=(12, 8))
    plt.suptitle(title, fontsize=16)
    
    try:
        for batch_idx, (images, labels) in enumerate(loader):
            images = images[:12]
            labels = labels[:12]
            break
        
        for i in range(min(12, len(images))):
            ax = plt.subplot(3, 4, i + 1)
            img = images[i].cpu().numpy().transpose((1, 2, 0))
            # 反标准化
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img = std * img + mean
            img = np.clip(img, 0, 1)
            ax.imshow(img)
            ax.set_title(f"{dataset.classes[labels[i]]}", fontsize=10)
            ax.axis("off")
        
        plt.tight_layout()
        plt.show()
        
    except Exception as e:
        print(f"无法显示图像示例: {e}")
        plt.close()

# 显示训练数据示例
print("\n显示训练数据示例...")
show_images(train_loader, "训练数据示例(无数据增强)")

七、编写训练函数、测试函数、设置早停机制

# 训练函数
def train_epoch(model, device, train_loader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0
    
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        try:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            total += labels.size(0)
            
            if batch_idx % 10 == 0:
                print(f'Batch: {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
            
            # 清理GPU内存
            del inputs, labels, outputs, loss
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
        except Exception as e:
            print(f"训练批次 {batch_idx} 出现错误: {e}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue
    
    epoch_loss = running_loss / total if total > 0 else float('inf')
    epoch_acc = running_corrects.double() / total if total > 0 else 0
    
    return epoch_loss, epoch_acc

# 验证函数
def validate(model, device, val_loader, criterion):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            try:
                inputs, labels = inputs.to(device), labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                _, preds = torch.max(outputs, 1)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                total += labels.size(0)
                
                # 清理内存
                del inputs, labels, outputs, loss
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                
            except Exception as e:
                print(f"验证过程中出现错误: {e}")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue
    
    epoch_loss = running_loss / total if total > 0 else float('inf')
    epoch_acc = running_corrects.double() / total if total > 0 else 0
    
    return epoch_loss, epoch_acc
# 早停机制
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0.001, path='best_resnet50v2.pth'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

八、开始训练

# 开始训练
print("\n开始训练...")
num_epochs = 25
early_stopping = EarlyStopping(patience=5, verbose=True)

train_losses = []
train_accs = []
val_losses = []
val_accs = []
lr_history = []

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')
    print('-' * 10)
    
    try:
        # 记录当前学习率
        current_lr = optimizer.param_groups[0]['lr']
        lr_history.append(current_lr)
        
        # 训练阶段
        train_loss, train_acc = train_epoch(model, device, train_loader, optimizer, criterion)
        train_losses.append(train_loss)
        train_accs.append(train_acc.item())
        
        # 验证阶段
        val_loss, val_acc = validate(model, device, val_loader, criterion)
        val_losses.append(val_loss)
        val_accs.append(val_acc.item())
        
        # 更新学习率
        scheduler.step()
        
        # 打印结果
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%')
        print(f'Learning Rate: {current_lr:.6f}')
        
        # 早停检查
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
            
    except Exception as e:
        print(f"Epoch {epoch+1} 出现错误: {e}")
        print("尝试清理内存并继续...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        continue

print("\n训练完成!")
开始训练...

Epoch 1/25
----------
Batch: 0/28, Loss: 1.3802
Batch: 10/28, Loss: 1.0213
Batch: 20/28, Loss: 0.4903
Train Loss: 0.8478, Train Acc: 85.49%
Val Loss: 0.1881, Val Acc: 99.12%
Learning Rate: 0.000100
Validation loss decreased (inf --> 0.188087).  Saving model ...

Epoch 2/25
----------
Batch: 0/28, Loss: 0.2666
Batch: 10/28, Loss: 0.0800
Batch: 20/28, Loss: 0.0263
Train Loss: 0.0922, Train Acc: 99.55%
Val Loss: 0.0231, Val Acc: 100.00%
Learning Rate: 0.000100
Validation loss decreased (0.188087 --> 0.023144).  Saving model ...

Epoch 3/25
----------
Batch: 0/28, Loss: 0.0139
Batch: 10/28, Loss: 0.0292
Batch: 20/28, Loss: 0.0233
Train Loss: 0.0256, Train Acc: 100.00%
Val Loss: 0.0150, Val Acc: 100.00%
Learning Rate: 0.000100
Validation loss decreased (0.023144 --> 0.015029).  Saving model ...

Epoch 4/25
----------
Batch: 0/28, Loss: 0.0086
Batch: 10/28, Loss: 0.0055
Batch: 20/28, Loss: 0.0183
Train Loss: 0.0079, Train Acc: 100.00%
Val Loss: 0.0113, Val Acc: 100.00%
Learning Rate: 0.000100
Validation loss decreased (0.015029 --> 0.011310).  Saving model ...

Epoch 5/25
----------
Batch: 0/28, Loss: 0.0280
Batch: 10/28, Loss: 0.0064
Batch: 20/28, Loss: 0.0134
Train Loss: 0.0059, Train Acc: 100.00%
Val Loss: 0.0104, Val Acc: 100.00%
Learning Rate: 0.000100
EarlyStopping counter: 1 out of 5

Epoch 6/25
----------
Batch: 0/28, Loss: 0.0039
Batch: 10/28, Loss: 0.0042
Batch: 20/28, Loss: 0.0041
Train Loss: 0.0209, Train Acc: 99.78%
Val Loss: 0.0122, Val Acc: 100.00%
Learning Rate: 0.000100
EarlyStopping counter: 2 out of 5

Epoch 7/25
----------
Batch: 0/28, Loss: 0.0015
Batch: 10/28, Loss: 0.0026
Batch: 20/28, Loss: 0.0007
Train Loss: 0.0171, Train Acc: 99.78%
Val Loss: 0.0115, Val Acc: 100.00%
Learning Rate: 0.000100
EarlyStopping counter: 3 out of 5

Epoch 8/25
----------
Batch: 0/28, Loss: 0.0041
Batch: 10/28, Loss: 0.0077
Batch: 20/28, Loss: 0.0032
Train Loss: 0.0071, Train Acc: 100.00%
Val Loss: 0.0112, Val Acc: 100.00%
Learning Rate: 0.000050
EarlyStopping counter: 4 out of 5

Epoch 9/25
----------
Batch: 0/28, Loss: 0.0051
Batch: 10/28, Loss: 0.0014
Batch: 20/28, Loss: 0.0037
Train Loss: 0.0026, Train Acc: 100.00%
Val Loss: 0.0091, Val Acc: 100.00%
Learning Rate: 0.000050
Validation loss decreased (0.011310 --> 0.009129).  Saving model ...

Epoch 10/25
----------
Batch: 0/28, Loss: 0.0164
Batch: 10/28, Loss: 0.0057
Batch: 20/28, Loss: 0.0015
Train Loss: 0.0058, Train Acc: 100.00%
Val Loss: 0.0080, Val Acc: 100.00%
Learning Rate: 0.000050
Validation loss decreased (0.009129 --> 0.008041).  Saving model ...

Epoch 11/25
----------
Batch: 0/28, Loss: 0.0019
Batch: 10/28, Loss: 0.0017
Batch: 20/28, Loss: 0.0009
Train Loss: 0.0197, Train Acc: 99.78%
Val Loss: 0.0116, Val Acc: 100.00%
Learning Rate: 0.000050
EarlyStopping counter: 1 out of 5

Epoch 12/25
----------
Batch: 0/28, Loss: 0.0011
Batch: 10/28, Loss: 0.0018
Batch: 20/28, Loss: 0.0016
Train Loss: 0.0030, Train Acc: 100.00%
Val Loss: 0.0123, Val Acc: 100.00%
Learning Rate: 0.000050
EarlyStopping counter: 2 out of 5

Epoch 13/25
----------
Batch: 0/28, Loss: 0.0019
Batch: 10/28, Loss: 0.0079
Batch: 20/28, Loss: 0.0068
Train Loss: 0.0040, Train Acc: 100.00%
Val Loss: 0.0085, Val Acc: 100.00%
Learning Rate: 0.000050
EarlyStopping counter: 3 out of 5

Epoch 14/25
----------
Batch: 0/28, Loss: 0.0009
Batch: 10/28, Loss: 0.0091
Batch: 20/28, Loss: 0.0007
Train Loss: 0.0052, Train Acc: 99.78%
Val Loss: 0.0106, Val Acc: 100.00%
Learning Rate: 0.000050
EarlyStopping counter: 4 out of 5

Epoch 15/25
----------
Batch: 0/28, Loss: 0.0020
Batch: 10/28, Loss: 0.0004
Batch: 20/28, Loss: 0.0061
Train Loss: 0.0026, Train Acc: 100.00%
Val Loss: 0.0094, Val Acc: 100.00%
Learning Rate: 0.000025
EarlyStopping counter: 5 out of 5
Early stopping

训练完成!

九、评估

# 绘制训练曲线
plt.figure(figsize=(15, 10))

plt.subplot(2, 2, 1)
plt.plot(range(1, len(train_losses)+1), train_losses, label='Train Loss', marker='o')
plt.plot(range(1, len(val_losses)+1), val_losses, label='Val Loss', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)

plt.subplot(2, 2, 2)
plt.plot(range(1, len(train_accs)+1), [acc*100 for acc in train_accs], label='Train Acc', marker='o')
plt.plot(range(1, len(val_accs)+1), [acc*100 for acc in val_accs], label='Val Acc', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.subplot(2, 2, 3)
plt.plot(range(1, len(lr_history)+1), lr_history, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.yscale('log')
plt.grid(True)

plt.subplot(2, 2, 4)

在这里插入图片描述

十、总结

ResNet v2的主要改进点。可能包括:1. 使用预激活结构,把BatchNorm和ReLU放在卷积之前。这样在残差块中,激活函数先于卷积层应用,这样梯度流动更顺畅,更有效地解决梯度消失的问题。2. 可能还调整了残差路径的设计,比如在跳跃连接中避免不必要的操作,保持恒等映射。3. 可能还涉及到了更深的网络结构或者更高效的训练方法。

然后,我需要总结ResNet v2的架构变化。比如,原来的残差块是conv-BN-ReLU的序列,而v2可能调整为BN-ReLU-conv的顺序,也就是预激活。这样的调整如何影响网络的表现?比如训练更稳定,准确率更高。此外,可能还探讨了不同的残差块设计,比如是否在跳跃连接中加入额外的卷积或BN层,而v2可能简化了这部分,保持跳跃连接尽可能简单,以保留恒等映射的特性。

另外,ResNet v2可能还进行了实验验证,比如在CIFAR-10、ImageNet等数据集上的表现,证明其有效性。比如,更深的网络(如1000层)是否能够稳定训练,准确率是否有提升。

需要注意的是,用户可能需要对比ResNet v1和v2的区别,所以总结时应该明确列出这些区别。比如预激活结构、恒等映射的改进、更优的梯度流动等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2384542.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

虚拟机Centos7:Cannot find a valid baseurl for repo: base/7/x86_64问题解决

问题 解决&#xff1a;更新yum仓库源 # 备份现有yum配置文件 sudo cp -r /etc/yum.repos.d /etc/yum.repos.d.backup# 编辑CentOS-Base.repo文件 vi /etc/yum.repos.d/CentOS-Base.repo[base] nameCentOS-$releasever - Base baseurlhttp://mirrors.aliyun.com/centos/$relea…

IP风险度自检,多维度守护网络安全

如今IP地址不再只是网络连接的标识符&#xff0c;更成为评估安全风险的核心维度。IP风险度通过多维度数据建模&#xff0c;量化IP地址在网络环境中的安全威胁等级&#xff0c;已成为企业反欺诈、内容合规、入侵检测的关键工具。据Gartner报告显示&#xff0c;2025年全球78%的企…

NV066NV074美光固态颗粒NV084NV085

NV066NV074美光固态颗粒NV084NV085 在存储技术的快速发展浪潮中&#xff0c;美光科技&#xff08;Micron Technology&#xff09;始终扮演着引领者的角色。其NV系列闪存颗粒凭借创新设计和卓越性能&#xff0c;成为技术爱好者、硬件开发者乃至企业级用户关注的焦点。本文将围绕…

C++ 日志系统实战第六步:性能测试

全是通俗易懂的讲解&#xff0c;如果你本节之前的知识都掌握清楚&#xff0c;那就速速来看我的项目笔记吧~ 本文项目结束&#xff01; 性能测试 下面对日志系统做一个性能测试&#xff0c;测试一下平均每秒能打印多少条日志消息到文件。 主要的测试方法是&#xff1a;每秒能…

Java桌面应用开发详解:自制截图工具从设计到打包的全流程【附源码与演示】

&#x1f525; 本文详细介绍一个Java/JavaFX学习项目——轻量级智能截图工具的开发实践。通过这个项目&#xff0c;你将学习如何使用Java构建桌面应用&#xff0c;掌握JavaFX界面开发、系统托盘集成、全局快捷键注册等实用技能。本文主要关注基础功能实现&#xff0c;适合Java初…

手写一个简单的线程池

手写一个简单的线程池 项目仓库&#xff1a;https://gitee.com/bossDuy/hand-tearing-thread-pool 基于一个b站up的课程&#xff1a;https://www.bilibili.com/video/BV1cJf2YXEw3/?spm_id_from333.788.videopod.sections&vd_source4cda4baec795c32b16ddd661bb9ce865 理…

siparmyknife:SIP协议渗透测试的瑞士军刀!全参数详细教程!Kali Linux教程!

简介 SIP Army Knife 是一个模糊测试器&#xff0c;用于搜索跨站点脚本、SQL 注入、日志注入、格式字符串、缓冲区溢出等。 安装 源码安装 通过以下命令来进行克隆项目源码&#xff0c;建议请先提前挂好代理进行克隆。 git clone https://github.com/foreni-packages/sipa…

【Java高阶面经:微服务篇】4.大促生存法则:微服务降级实战与高可用架构设计

一、降级决策的核心逻辑:资源博弈下的生存选择 1.1 大促场景的资源极限挑战 在电商大促等极端流量场景下,系统面临的资源瓶颈呈现指数级增长: 流量特征: 峰值QPS可达日常的50倍以上(如某电商大促下单QPS从1万突增至50万)流量毛刺持续时间短(通常2-4小时),但对系统稳…

通过上传使大模型读取并分析文件实战

一、技术背景与需求分析 我们日常在使用AI的时候一定都上传过文件&#xff0c;AI会根据用户上传的文件内容结合用户的请求进行分析&#xff0c;给出用户解答。但是这是怎么实现的呢&#xff1f;在我们开发自己的大模型应用时肯定是不可避免的要思考这个问题&#xff0c;今天我会…

VueRouter路由组件的用法介绍

1.1、<router-link>标签 <router-link>标签的作用是实现路由之间的跳转功能&#xff0c;默认情况下&#xff0c;<router-link>标签是采用超链接<a>标签显示的&#xff0c;通过to属性指定需要跳转的路由地址。当然&#xff0c;如果你不想使用默认的<…

数据结构第1章 (竟成)

第 1 章 编程基础 1.1 前言 因为数据结构的代码大多采用 C 语言进行描述。而且&#xff0c;408 考试每年都有一道分值为 13 - 15 的编程题&#xff0c;要求使用 C/C 语言编写代码。所以&#xff0c;本书专门用一章来介绍 408 考试所需的 C/C 基础知识。有基础的考生可以快速浏览…

Terraform创建阿里云基础组件资源

这里首先要找到阿里云的官方使用说明: 中文版:Terraform(Terraform)-阿里云帮助中心 英文版:Terraform Registry 各自创建一个阿里云的RAM子账号,并给与OPAPI的调用权限,(就是有aksk,生成好之后保存下.) 创建路径: 登陆阿里云主账号-->控制台-->右上角企业-->人员…

企业级调度器LVS

访问效果 涉及内容&#xff1a;浏览拆分、 DNS 解析、反向代理、负载均衡、数据库等 1 集群 1.1 集群类型简介 对于⼀个业务项⽬集群来说&#xff0c;根据业务中的特性和特点&#xff0c;它主要有三种分类&#xff1a; 高扩展 (LB) &#xff1a;单个主机负载不足的时候&#xf…

【Web前端】HTML网页编程基础

HTML5简介与基础骨架 HTML5是用来描述网页的一种语言&#xff0c;被称为超文本标记语言。用HTML5编写的文件&#xff0c;后缀以.html结尾 HTML是一种标记语言&#xff0c;标记语言是一套标记标签。标签是由尖括号包围的关键字&#xff0c;例如<html> 标签有两种表现形…

阿里开源 CosyVoice2:打造 TTS 文本转语音实战应用

1、引言 1.1、CosyVoice2 简介 阿里通义实验室推出音频基座大模型 FunAudioLLM,包含 SenseVoice 和 CosyVoice 两大模型。 CosyVoice:模拟音色与提升情感表现力 多语言 支持的语言: 中文、英文、日文、韩文、中文方言(粤语、四川话、上海话、天津话、武汉话等)跨语言及…

RabbitMQ可靠传输——持久性、发送方确认

一、持久性 前面学习消息确认机制时&#xff0c;是为了保证Broker到消费者直接的可靠传输的&#xff0c;但是如果是Broker出现问题&#xff08;如停止服务&#xff09;&#xff0c;如何保证消息可靠性&#xff1f;对此&#xff0c;RabbitMQ提供了持久化功能&#xff1a; 持久…

无人机开启未来配送新篇章

低空物流&#xff08;无人机物流&#xff09;是利用无人机等低空飞行器进行货物运输的物流方式&#xff0c;依托低空空域&#xff08;通常在120-300米&#xff09;实现快速、高效、灵活的配送服务。它是低空经济的重要组成部分&#xff0c;广泛应用于快递配送、医疗物资运输、农…

Qt状态机QStateMachine

QStateMachine QState 提供了一种强大且灵活的方式来表示状态机中的状态&#xff0c;通过与状态机类(QStateMachine)和转换类(QSignalTransition&#xff0c; QEventTransition)结合&#xff0c;可以实现复杂的状态逻辑和用户交互。合理使用嵌套状态机、信号转换、动作与动画、…

Java详解LeetCode 热题 100(20):LeetCode 48. 旋转图像(Rotate Image)详解

文章目录 1. 题目描述2. 理解题目3. 解法一&#xff1a;转置 翻转3.1 思路3.2 Java代码实现3.3 代码详解3.4 复杂度分析3.5 适用场景 4. 解法二&#xff1a;四点旋转法4.1 思路4.2 Java代码实现4.3 代码详解4.4 复杂度分析4.5 适用场景 5. 详细步骤分析与示例跟踪5.1 解法一&a…

CAU人工智能class4 批次归一化

归一化 在对输入数据进行预处理时会用到归一化&#xff0c;将输入数据的范围收缩到0到1之间&#xff0c;这有利于避免纲量对模型训练产生的影响。 但当模型过深时会产生下述问题&#xff1a; 当一个学习系统的输入分布发生变化时&#xff0c;这种现象称之为“内部协变量偏移”…