卷积神经网络迁移学习:原理与实践指南

news2025/6/15 15:05:40

引言

        在深度学习领域,卷积神经网络(CNN)已经在计算机视觉任务中取得了巨大成功。然而,从头开始训练一个高性能的CNN模型需要大量标注数据和计算资源。迁移学习(Transfer Learning)技术为我们提供了一种高效解决方案,它能够将预训练模型的知识迁移到新任务中,显著减少训练时间和数据需求。本文将全面介绍CNN迁移学习的原理、优势及实践方法。

1、内容

      迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作

2、步骤

1、选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。

2、冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。

3、在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。

4、微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。

5、评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

Resnet网络:

原理:

       卷积神经网络都是通过卷积层和池化层的叠加组成的。 在实际的试验中发现,随着卷积层和池化层的叠加,学习效果不会逐渐变好,反而出现2个问题:

1、梯度消失和梯度爆炸

梯度消失:若每一层的误差梯度小于1,反向传播时,网络越深,梯度越趋近于0

梯度爆炸:若每一层的误差梯度大于1,反向传播时,网络越深,梯度越来越大

2、退化问题

       为了解决梯度消失或梯度爆炸问题,论文提出通过数据的预处理以及在网络中使用 BN(Batch Normalization)层来解决。

      为了解决深层网络中的退化问题,可以人为地让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。这种神经网络被称为 残差网络 (ResNets)

1、18层resnet结构:

2、BN(Batch Normalization)

实例

1、导入相关的库

import torch
from torch.utils.data import DataLoader,Dataset  #数据包管理工具,打包数据,
from torchvision import transforms
from torch import nn
import torchvision.models as models
from PIL import Image
import numpy as np

2、调取模型并冻结参数

#不再需要自己来搭建模型了。预训练的文件也加载进去了。
# 将resnet18模型迁移到食物分类项目中.#残差网络是固定的网络结构,不需要你自己来类定义了。
resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
#weights=models.ResNet18_Weights.DEFAULT表示使用在 ImageNet 数据集上预先训练好的权重
for param in resnet_model.parameters():
    print(param)
    param.requires_grad=False      #冻结
#模型所有参数(即权重和偏差)的requires_grad属性设置为False,从而冻结所有模型参数,

详细说明:

  1. models.resnet18()加载ResNet18架构

  2. weights=models.ResNet18_Weights.DEFAULT指定使用官方预训练权重

  3. 遍历所有参数并冻结

3、对网络模型进行微调

in_features=resnet_model.fc.in_features    #获取模型原输入的特征个数
resnet_model.fc=nn.Linear(in_features,20)  #创建一个全连接层

4、保存需要训练的参数

params_to_update=[]    #保存需要训练的参数,仅仅包含全连接层的参数
for param in resnet_model.parameters():
    if param.requires_grad==True:
        params_to_update.append(param)

5、数据预处理

data_transforms={
'train':
transforms.Compose([
    transforms.Resize([300, 300]),
    transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选
    transforms.CenterCrop(224),  # 从中心开始裁剪[256,256]
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 选择一个概率概率
    transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转
    # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R=G=B
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'valid':
transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}

数据预处理包括:

  • 训练集使用多种数据增强(随机旋转、水平翻转等)

  • 验证集只进行简单的resize和归一化

  • 归一化参数使用ImageNet的均值和标准差

6、自定义数据集的类

class food_dataset(Dataset):
    def __init__(self,file_path,transform=None):
        self.file_path=file_path
        self.imgs=[]
        self.labels=[]
        self.transform=transform
        with open(self.file_path) as f:
            samples=[x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image=Image.open(self.imgs[idx])
        if self.transform:
            image=self.transform(image)

        label = self.labels[idx]
        label = torch.from_numpy(np.array(label,dtype=np.int64))
        return image,label

这个自定义Dataset类:

  1. 从文本文件读取图像路径和标签

  2. 实现__len____getitem__方法,供DataLoader使用

  3. 应用指定的transform处理图像

7、数据加载器准备

# 创建训练集和测试集实例
training_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])

# 创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

数据加载器提供:

  • 批量加载功能(batch_size=64)

  • 训练数据随机打乱(shuffle=True)

  • 多线程数据预读取

8、训练和测试流程

def train(dataloader,model,loss_fn,optimizer):
    model.train()   #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
#pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
#一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()
    for X,y in dataloader:       #其中batch为每一个数据的编号
        X,y=X.to(device),y.to(device)    #把训练数据集和标签传入cpu或GPU
        pred=model.forward(X)    #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化
        loss=loss_fn(pred,y)     #通过交叉熵损失函数计算损失值loss
        # Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络
        optimizer.zero_grad()    #梯度值清零
        loss.backward()          #反向传播计算得到每个参数的梯度值w
        optimizer.step()         #根据梯度更新网络w参数


best_acc=0
def test(dataloader, model, loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()

    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

9、循环训练

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # Adam优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 学习率调度

# 训练10个epoch
epoch = 10
for i in range(epoch):
    print(f"Epoch {i + 1}")
    train(train_dataloader, model, loss_fn, optimizer)  # 训练
    scheduler.step()  # 更新学习率
    test(test_dataloader, model, loss_fn)  # 测试

print('Best accuracy:', best_acc)  # 打印最佳准确率

主循环流程:

  1. 定义损失函数和优化器

  2. 设置学习率调度器(每5个epoch学习率减半)

  3. 进行10轮训练和测试

  4. 打印最终最佳准确率

结果展示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2343572.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spark与Hadoop之间的联系和对比

(一)Spark概述 Apache Spark 是一个快速、通用、可扩展的大数据处理分析引擎。它最初由加州大学伯克利分校 AMPLab 开发,后成为 Apache 软件基金会的顶级项目。Spark 以其内存计算的特性而闻名,能够在内存中对数据进行快速处理&am…

基于线性LDA算法对鸢尾花数据集进行分类

基于线性LDA算法对鸢尾花数据集进行分类 1、效果 2、流程 1、加载数据集 2、划分训练集、测试集 3、创建模型 4、训练模型 5、使用LDA算法 6、画图3、示例代码 # 基于线性LDA算法对鸢尾花数据集进行分类# 基于线性LDA算法对鸢尾花数据集进行分类 import numpy as np import …

【Deepseek基础篇】--v3基本架构

目录 MOE参数 1.基本架构 1.1. Multi-Head Latent Attention多头潜在注意力 1.2.无辅助损失负载均衡的 DeepSeekMoE 2.多标记预测 2.1. MTP 模块 论文地址:https://arxiv.org/pdf/2412.19437 DeepSeek-V3 是一款采用 Mixture-of-Experts(MoE&…

centos7使用yum快速安装最新版本Jenkins-2.462.3

Jenkins支持多种安装方式:yum安装、war包安装、Docker安装等。 官方下载地址:https://www.jenkins.io/zh/download 本次实验使用yum方式安装Jenkins LTS长期支持版,版本为 2.462.3。 一、Jenkins基础环境的安装与配置 1.1:基本…

【vue】【element-plus】 el-date-picker使用cell-class-name进行标记,type=year不生效解决方法

typedete&#xff0c;自定义cell-class-name打标记效果如下&#xff1a; 相关代码&#xff1a; <el-date-pickerv-model"date":clearable"false":editable"false":cell-class-name"cellClassName"type"date"format&quo…

c++11新特性随笔

1.统一初始化特性 c98中不支持花括号进行初始化&#xff0c;编译时会报错&#xff0c;在11当中初始化可以通过{}括号进行统一初始化。 c98编译报错 c11: #include <iostream> #include <set> #include <string> #include <vector>int main() {std:…

C++23 中 constexpr 的重要改动

文章目录 1. constexpr 函数中使用非字面量变量、标号和 goto (P2242R3)示例代码 2. 允许 constexpr 函数中的常量表达式中使用 static 和 thread_local 变量 (P2647R1)示例代码 3. constexpr 函数的返回类型和形参类型不必为字面类型 (P2448R2)示例代码 4. 不存在满足核心常量…

全面解析React内存泄漏:原因、解决方案与最佳实践

在开发React应用时&#xff0c;内存泄漏是一个常见但容易被忽视的问题。如果处理不当&#xff0c;它会导致应用性能下降、卡顿甚至崩溃。由于React的组件化特性&#xff0c;许多开发者可能没有意识到某些操作&#xff08;如事件监听、异步请求、定时器等&#xff09;在组件卸载…

【FreeRTOS】事件标志组

文章目录 1 简介1.1事件标志1.2事件组 2事件标志组API2.1创建动态创建静态创建 2.2 删除事件标志组2.3 等待事件标志位2.4 设置事件标志位在任务中在中断中 2.5 清除事件标志位在任务中在中断中 2.6 获取事件组中的事件标志位在任务中在中断中 2.7 函数xEventGroupSync 3 事件标…

超级扩音器手机版:随时随地,大声说话

在日常生活中&#xff0c;我们常常会遇到手机音量太小的问题&#xff0c;尤其是在嘈杂的环境中&#xff0c;如KTV、派对或户外活动时&#xff0c;手机自带的音量往往难以满足需求。今天&#xff0c;我们要介绍的 超级扩音器手机版&#xff0c;就是这样一款由上海聚告德业文化发…

【数据可视化-27】全球网络安全威胁数据可视化分析(2015-2024)

&#x1f9d1; 博主简介&#xff1a;曾任某智慧城市类企业算法总监&#xff0c;目前在美国市场的物流公司从事高级算法工程师一职&#xff0c;深耕人工智能领域&#xff0c;精通python数据挖掘、可视化、机器学习等&#xff0c;发表过AI相关的专利并多次在AI类比赛中获奖。CSDN…

【6G 开发】NV NGC

配置 生成密钥 API Keys 生成您自己的 API 密钥&#xff0c;以便通过 Docker 客户端或通过 NGC CLI 使用 Secrets Manager、NGC Catalog 和 Private Registry 的 NGC 服务 以下个人 API 密钥已成功生成&#xff0c;可供此组织使用。这是唯一一次显示您的密钥。 请妥善保管您的…

SIEMENS PLC程序解读 -Serialize(序列化)SCATTER_BLK(数据分散)

1、程序数据 第12个字节 PI 2、程序数据 第16个字节 PI 3、程序数据 第76个字节 PO 4、程序代码 2、程序解读 图中代码为 PLC 梯形图&#xff0c;主要包含以下指令及功能&#xff1a; Serialize&#xff08;序列化&#xff09;&#xff1a; 将 SRC_VARIABLE&#xff…

宁德时代25年时代长安动力电池社招入职测评SHL题库Verify测评语言理解数字推理真题

测试分为语言和数字两部分&#xff0c;测试时间各为17分钟&#xff0c;测试正式开始后不能中断或暂停

【硬核解析:基于Python与SAE J1939-71协议的重型汽车CAN报文解析工具开发实战】

引言&#xff1a;重型汽车CAN总线的数据价值与挑战 随着汽车电子化程度的提升&#xff0c;控制器局域网&#xff08;CAN总线&#xff09;已成为重型汽车的核心通信网络。不同控制单元&#xff08;ECU&#xff09;通过CAN总线实时交互海量报文数据&#xff0c;这些数据隐藏着车…

Uniapp 自定义 Tabbar 实现教程

Uniapp 自定义 Tabbar 实现教程 1. 简介2. 实现步骤2.1 创建自定义 Tabbar 组件2.2 配置 pages.json2.3 在 App.vue 中引入组件 3. 实现过程中的关键点3.1 路由映射3.2 样式设计3.3 图标处理 4. 常见问题及解决方案4.1 页面跳转问题4.2 样式适配问题4.3 性能优化 5. 扩展功能5.…

记录一次使用面向对象的C语言封装步进电机驱动

简介 (2025/4/21) 本库对目前仅针对TB6600驱动下的42步进电机的基础功能进行了一定的封装, 也是我初次尝试以面向对象的思想去编写嵌入式代码, 和直流电机的驱动步骤相似在调用stepmotor_attach()函数和stepmotor_init()函数之后仅通过结构体数组stepm然后指定枚举变量中的id即…

Spark-streaming核心编程

1.导入依赖‌&#xff1a; <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming-kafka-0-10_2.12</artifactId> <version>3.0.0</version> </dependency> 2.编写代码‌&#xff1a; 创建Sp…

vue3+TS+echarts 折线图

需要实现的效果如下 <script setup lang"ts" name"RepsSingleLineChart">import * as echarts from echartsimport { getInitecharts } from /utils/echartimport type { EChartsOption } from echarts// 定义 props 类型interface Props {id: strin…

小火电视桌面TV版下载-小火桌面纯净版下载-官方历史版本安装包

别再费心地寻找小火桌面的官方历史版本安装包啦&#xff0c;试试乐看家桌面吧&#xff0c;它作为纯净版本的第三方桌面&#xff0c;具有诸多优点。 界面简洁纯净&#xff1a;乐看家桌面设计简洁流畅&#xff0c;页面简洁、纯净无广告&#xff0c;为用户打造了一个干净的电视操…