文章目录
- 一、Pytorch的LeNet-5网络准备
- 二、保存用于导入matlab的model
- 三、导入matlab
- 四、用matlab训练这个导入的网络
这里演示从pytorch的LeNet-5网络导入到matlab中进行训练用。
一、Pytorch的LeNet-5网络准备
根据LeNet-5的结构图,我们可以写如下结构
import torch
import torch.nn as nn
class LeNet5(nn.Module):
def __init__(self, num_classes=10):
super(LeNet5, self).__init__()
self.feature_extractor = nn.Sequential(
# C1: Conv(1→6), 输出 28x28 → 6x28x28
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
nn.BatchNorm2d(6),
nn.ReLU(inplace=True),
# S2: MaxPool 2x2, 输出 6x14x14
nn.MaxPool2d(kernel_size=2, stride=2),
# C3: Conv(6→16), 输出 16x10x10
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True),
# S4: MaxPool 2x2, 输出 16x5x5
nn.MaxPool2d(kernel_size=2, stride=2),
# C5: Conv(16→120), 输出 120x1x1(接近 flatten)
nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
nn.BatchNorm2d(120),
nn.ReLU(inplace=True)
)
self.classifier = nn.Sequential(
nn.Flatten(), # [batch, 120]
nn.Linear(120, 84),
nn.BatchNorm1d(84),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(84, num_classes)
)
def forward(self, x):
x = self.feature_extractor(x)
x = self.classifier(x)
return x
if __name__ == "__main__":
model = LeNet5()
model.eval()
# 示例输入:MNIST 的图像大小 [1, 1, 28, 28]
example_input = torch.randn(1, 1, 28, 28)
# Tracing
traced_model = torch.jit.trace(model, example_input)
# 保存
traced_model.save("traced_lenet5.pt")
print("✅ traced_lenet5.pt 已成功保存!")
二、保存用于导入matlab的model
在上面的代码中,我们有几行是产生trace model的,即
torch.jit.trace()
是 PyTorch 的一种 静态图(Static Graph)转换方法,它会:
- 运行一次前向传播(forward),记录下所有的张量操作;
- 然后构建一个不可变的计算图(graph),这个图就是所谓的 trace model。
保存这个model后,我们就得到了traced_lenet5.pt这个文件。
三、导入matlab
导入matlab可以通过APPS里的Deep Network Designer,如下图
然后通过From PyTorch这个地方,导入刚才保存的网络结构
点开From PyTorch后, 我们可以复制刚才保存的traced_lenet5.pt这个文件的绝对路径用于导入,如下图
然后,import就会有,如下结果
然后,点击红色方框那部分,进行一下输入尺寸的修改
导入的这个网络框架,我们还要在末尾段加入softmax层,这个层在原pytorch框架里没写
这样,我们就完成了LeNet5从Pytorch里导入到matlab了。接着我们可以通过Analyze按钮分析这个网络,如下图
没有问题后,我们就可以Export这个网络到工作区了,输出的网络自动命名为net_1。
四、用matlab训练这个导入的网络
训练的代码如下
% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...
IncludeSubfolders=true, ... % 指定在加载数据时包含子文件夹中的图像
LabelSource="foldernames"); % 使用子文件夹的名称作为图像的标签(自动分类)
% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels); % 将 imds.Labels
%%
% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");
% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ...
MaxEpochs = 4, ...
ValidationData = imdsValidation, ...
ValidationFrequency = 30, ...
Plots = "training-progress", ...
Metrics = "accuracy", ...
Verbose = false);
% 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);
%%
% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");
%%
% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsTest);
% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YTest = scores2label(scores, classNames);
% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);
% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);
% 创建一个新的图形窗口
figure
tiledlayout("flow") % 使用自动流式布局排列子图(tiled layout)
% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9
nexttile % 在下一个网格位置准备绘图
img = readimage(imdsTest, idx(i)); % 读取第 idx(i) 张图像
imshow(img) % 显示图像
title("Predicted Class: " + string(YTest(idx(i)))) % 设置标题,显示预测类别
end
上面用到的数据集是0-9的数字图片,如下图
训练的详细信息如下
预测结果显示