matlab从pytorch中导入LeNet-5网络框架

news2025/5/13 0:17:12

文章目录

  • 一、Pytorch的LeNet-5网络准备
  • 二、保存用于导入matlab的model
  • 三、导入matlab
  • 四、用matlab训练这个导入的网络

这里演示从pytorch的LeNet-5网络导入到matlab中进行训练用。

一、Pytorch的LeNet-5网络准备

根据LeNet-5的结构图,我们可以写如下结构

import torch
import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()

        self.feature_extractor = nn.Sequential(
            # C1: Conv(1→6), 输出 28x28 → 6x28x28
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            nn.BatchNorm2d(6),
            nn.ReLU(inplace=True),

            # S2: MaxPool 2x2, 输出 6x14x14
            nn.MaxPool2d(kernel_size=2, stride=2),

            # C3: Conv(6→16), 输出 16x10x10
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),

            # S4: MaxPool 2x2, 输出 16x5x5
            nn.MaxPool2d(kernel_size=2, stride=2),

            # C5: Conv(16→120), 输出 120x1x1(接近 flatten)
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
            nn.BatchNorm2d(120),
            nn.ReLU(inplace=True)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),  # [batch, 120]
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(84, num_classes)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x


if __name__ == "__main__":
    model = LeNet5()

    model.eval()
    # 示例输入:MNIST 的图像大小 [1, 1, 28, 28]
    example_input = torch.randn(1, 1, 28, 28)
    # Tracing
    traced_model = torch.jit.trace(model, example_input)

    # 保存
    traced_model.save("traced_lenet5.pt")
    print("✅ traced_lenet5.pt 已成功保存!")

二、保存用于导入matlab的model

在上面的代码中,我们有几行是产生trace model的,即

在这里插入图片描述

torch.jit.trace() 是 PyTorch 的一种 静态图(Static Graph)转换方法,它会:

  • 运行一次前向传播(forward),记录下所有的张量操作;
  • 然后构建一个不可变的计算图(graph),这个图就是所谓的 trace model

保存这个model后,我们就得到了traced_lenet5.pt这个文件。

三、导入matlab

导入matlab可以通过APPS里的Deep Network Designer,如下图

在这里插入图片描述

然后通过From PyTorch这个地方,导入刚才保存的网络结构

在这里插入图片描述

点开From PyTorch后, 我们可以复制刚才保存的traced_lenet5.pt这个文件的绝对路径用于导入,如下图

在这里插入图片描述

然后,import就会有,如下结果

在这里插入图片描述

然后,点击红色方框那部分,进行一下输入尺寸的修改

在这里插入图片描述

导入的这个网络框架,我们还要在末尾段加入softmax层,这个层在原pytorch框架里没写

在这里插入图片描述

这样,我们就完成了LeNet5从Pytorch里导入到matlab了。接着我们可以通过Analyze按钮分析这个网络,如下图

在这里插入图片描述

没有问题后,我们就可以Export这个网络到工作区了,输出的网络自动命名为net_1。

在这里插入图片描述

四、用matlab训练这个导入的网络

训练的代码如下

% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...  % 指定在加载数据时包含子文件夹中的图像
    LabelSource="foldernames");  % 使用子文件夹的名称作为图像的标签(自动分类)

% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels);  % 将 imds.Labels


%%
% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");

% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ...  
    MaxEpochs = 4, ...  
    ValidationData = imdsValidation, ... 
    ValidationFrequency = 30, ...  
    Plots = "training-progress", ...  
    Metrics = "accuracy", ...  
    Verbose = false); 



% 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);

%%
% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");

%%
% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsTest);

% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YTest = scores2label(scores, classNames);


% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);

% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);

% 创建一个新的图形窗口
figure
tiledlayout("flow")  % 使用自动流式布局排列子图(tiled layout)

% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9
    nexttile  % 在下一个网格位置准备绘图
    img = readimage(imdsTest, idx(i));  % 读取第 idx(i) 张图像
    imshow(img)  % 显示图像
    title("Predicted Class: " + string(YTest(idx(i))))  % 设置标题,显示预测类别
end

上面用到的数据集是0-9的数字图片,如下图

在这里插入图片描述

训练的详细信息如下

在这里插入图片描述

预测结果显示

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2328820.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot向Vue发送消息通过WebSocket实现通信

注意:如果后端有contextPath,如/app,那么前端访问的url就是ip:port/app/ws 后端实现步骤 添加Spring Boot WebSocket依赖配置WebSocket端点和消息代理创建控制器,使用SimpMessagingTemplate发送消息 前端实现步骤 安装sockjs-…

网络编程—Socket套接字(UDP)

上篇文章: 网络编程—网络概念https://blog.csdn.net/sniper_fandc/article/details/146923380?fromshareblogdetail&sharetypeblogdetail&sharerId146923380&sharereferPC&sharesourcesniper_fandc&sharefromfrom_link 目录 1 概念 2 Soc…

视频设备轨迹回放平台EasyCVR综合智能化,搭建运动场体育赛事直播方案

一、背景 随着5G技术的发展,体育赛事直播迎来了新的高峰。无论是NBA、西甲、英超、德甲、意甲、中超还是CBA等热门赛事,都是值得记录和回放的精彩瞬间。对于体育迷来说,选择观看的平台众多,但是作为运营者,搭建一套体…

AIGC实战——CycleGAN详解与实现

AIGC实战——CycleGAN详解与实现 0. 前言1. CycleGAN 基本原理2. CycleGAN 模型分析3. 实现 CycleGAN小结系列链接 0. 前言 CycleGAN 是一种用于图像转换的生成对抗网络(Generative Adversarial Network, GAN),可以在不需要配对数据的情况下将一种风格的图像转换成…

VS2022远程调试Linux程序

一、 1、VS2022安装参考 VS Studio2022安装教程(保姆级教程)_visual studio 2022-CSDN博客 注意:勾选的时候,要勾选下方的选项,才能调试Linux环境下运行的程序! 2、VS2022远程调试Linux程序测试 原文参…

345-java人事档案管理系统的设计与实现

345-java人事档案管理系统的设计与实现 项目概述 本项目为基于Java语言的人事档案管理系统,旨在帮助企事业单位高效管理员工档案信息,实现档案的电子化、自动化管理。系统涵盖了员工信息的录入、查询、修改、删除等功能,同时具备权限控制和…

【Linux系统编程】进程概念,进程状态

目录 一,操作系统(Operator System) 1-1概念 1-2设计操作系统的目的 1-3核心功能 1-4系统调用和库函数概念 二,进程(Process) 2-1进程概念与基本操作 2-2task_struct结构体内容 2-3查看进程 2-4通…

优选算法的妙思之流:分治——快排专题

专栏:算法的魔法世界 个人主页:手握风云 目录 一、快速排序 二、例题讲解 2.1. 颜色分类 2.2. 排序数组 2.3. 数组中的第K个最大元素 2.4. 库存管理 III 一、快速排序 分治,简单理解为“分而治之”,将一个大问题划分为若干个…

wx206基于ssm+vue+uniapp的优购电商小程序

开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…

React编程高级主题:错误处理(Error Handling)

文章目录 **5.2 错误处理(Error Handling)概述****5.2.1 onErrorReturn / onErrorResume(错误回退)****1. onErrorReturn:提供默认值****2. onErrorResume:切换备用数据流** **5.2.2 retry / retryWhen&…

ubuntu20.04升级成ubuntu22.04

命令行 sudo do-release-upgrade 我是按提示输入y确认操作,也可以遇到配置文件冲突时建议选择N保留当前配置

SpringCloud(25)——Stream介绍

1.场景描述 当我们的分布式系统建设到一定程度了,或者服务间是通过异步请求来通讯的,那么我们避免不了使用MQ来解决问题。 假如公司内部进行了业务合并或者整合,需要服务A和服务B通过MQ的方式进行消息传递,而服务A用的是RabbitMQ&…

centos8上实现lvs集群负载均衡dr模式

1.前言 个人备忘笔记,欢迎探讨。 centos8上实现lvs集群负载均衡nat模式 centos8上实现lvs集群负载均衡dr模式 之前写过一篇lvs-nat模式。实验起来相对顺利。dr模式最大特点是响应报文不经调度器,而是直接返回客户机。 dr模式分同网段和不同网段。同…

uniapp如何接入星火大模型

写在前面:最近的ai是真的火啊,琢磨了一下,弄个uniappx的版本开发个东西玩一下,想了想不知道放啥内容,突然觉得deepseek可以接入,好家伙,一对接以后发现这是个付费的玩意,我穷&#x…

MySQL vs MSSQL 对比

在企业数据库管理系统中,MySQL 和 Microsoft SQL Server(MSSQL)是最受欢迎的两大选择。MySQL 是一款开源的关系型数据库管理系统(RDBMS),由 MySQL AB 开发,现归属于 Oracle 公司。而 MSSQL 是微…

python基础-10-组织文件

文章目录 【README】【10】组织文件(复制移动删除重命名)【10.1】shutil模块(shell工具)【10.1.1】复制文件和文件夹【10.1.1.1】复制文件夹及其下文件-shutil.copytree 【10.1.2】文件和文件夹的移动与重命名【10.1.3】永久删除文件和文件夹【10.1.4】用…

ORA-09925 No space left on device 问题处理全过程记录

本篇文章关键字:linux、oracle、审计、ORA-09925 一、故障现像 朋友找到我说是他们备份软件上报错。 问题比较明显,ORA-09925,看起来就是空间不足导致的 二、问题分析过程 这里说一下逐步的分析思路,有个意外提前说一下就是我…

多输入多输出 | Matlab实现BO-GRU贝叶斯优化门控循环单元多输入多输出预测

多输入多输出 | Matlab实现BO-GRU贝叶斯优化门控循环单元多输入多输出预测 目录 多输入多输出 | Matlab实现BO-GRU贝叶斯优化门控循环单元多输入多输出预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 Matlab实现BO-GRU贝叶斯优化门控循环单元多输入多输出预测&#…

27信号和槽_自定义信号(2)

自定义信号和槽 绑定信号和槽 如何才能触发出自定义的信号呢?(上诉代码只是将信号和槽绑定在一起,但并没有触发信号) Qt 内置的信号,都不需要咱们手动通过代码来触发 用户在 GUI, 进行某些操作,就会自动触发对应信号.(发射信号的代码已经内置…

人工智能在生物医药领域的应用地图:AIBC2025将于6月在上海召开!

人工智能在生物医药领域的应用地图:AIBC2025将于6月在上海召开! 近年来,人工智能在生物医药行业中的应用受到广泛关注。 2024年10月,2024诺贝尔化学奖被授予“计算蛋白质设计和蛋白质结构预测”,这为行业从业人员带来…