ResNet18在MNIST手写数字数据库上的深度学习网络识别及Matlab仿真实验研究
ResNet18深度学习网络的mnist手写数字数据库识别matlab仿真MNIST手写数字识别算是深度学习界的Hello World了不过这次咱们用ResNet18来整点不一样的。别看ResNet本来是给ImageNet设计的拿来折腾下28x28的小图片还挺有意思。先说说数据准备这块Matlab处理起来比Python其实更省心digitDatasetPath fullfile(matlabroot,toolbox,nnet,nndemos,nndatasets,DigitDataset); imds imageDatastore(digitDatasetPath,... IncludeSubfolders,true,LabelSource,foldernames); [imdsTrain,imdsTest] splitEachLabel(imds,0.8,randomized);这里要注意个坑原始ResNet输入是224x224的RGB图。咱们得给灰度图加个戏——用augmentedImageDatastore强行拉伸尺寸虽然有点暴力但效果还行inputSize [224 224 3]; augImdsTrain augmentedImageDatastore(inputSize,imdsTrain,ColorPreprocessing,rgb); augImdsTest augmentedImageDatastore(inputSize,imdsTest,ColorPreprocessing,rgb);接下来构建网络骨架。Matlab自带的resnet18其实可以直接魔改但为了展示原理咱们手搓一个残差块function lgraph addBasicBlock(lgraph, blockName, numFilters, stride, inputLayerName) conv1_name [blockName _conv1]; bn1_name [blockName _bn1]; conv2_name [blockName _conv2]; bn2_name [blockName _bn2]; add_name [blockName _add]; % 残差路径 lgraph addLayers(lgraph, [ convolution2dLayer(3,numFilters,Stride,stride,Padding,same,Name,conv1_name) batchNormalizationLayer(Name,bn1_name) reluLayer(Name,[blockName _relu1]) convolution2dLayer(3,numFilters,Padding,same,Name,conv2_name) batchNormalizationLayer(Name,bn2_name) ]); % shortcut连接 if stride ~ 1 shortcut [ convolution2dLayer(1,numFilters,Stride,stride,Name,[blockName _shortcut_conv]) batchNormalizationLayer(Name,[blockName _shortcut_bn]) ]; lgraph addLayers(lgraph, shortcut); lgraph connectLayers(lgraph, inputLayerName, [blockName _shortcut_conv]); else lgraph connectLayers(lgraph, inputLayerName, add_name/in2); end % 合并残差 lgraph addLayers(lgraph, additionLayer(2,Name,add_name)); lgraph connectLayers(lgraph, bn2_name, [add_name /in1]); end这个残差块实现有几个精妙之处当stride不为1时需要1x1卷积调整维度否则直接相加。注意Matlab的加法层要处理两个输入源的连接这里用connectLayers手动指定连接关系比自动构建更靠谱。ResNet18深度学习网络的mnist手写数字数据库识别matlab仿真训练配置这块别照搬ImageNet那套学习率得调小点options trainingOptions(sgdm,... InitialLearnRate,0.1,... LearnRateSchedule,piecewise,... LearnRateDropPeriod,5,... MaxEpochs,15,... Shuffle,every-epoch,... Plots,training-progress,... ValidationData,augImdsTest);跑完15个epoch基本能到99.2%左右的准确率。测试时有个小技巧用classify函数直接输出预测结果[YPred,probs] classify(net,augImdsTest); YTest imdsTest.Labels; accuracy sum(YPred YTest)/numel(YTest)最后画混淆矩阵的时候建议用自定义颜色更直观cm confusionchart(YTest, YPred); cm.Title ResNet18在MNIST上的混淆矩阵; cm.ColumnSummary column-normalized; cm.RowSummary row-normalized; cm.FontSize 12;整个过程跑下来发现虽然用ResNet18处理MNIST有点杀鸡用牛刀但残差连接确实能加速训练收敛。有意思的是把图片强行拉伸到224x224后网络前几层的特征图会保留更多细节这对识别边缘尖锐的手写数字反而有帮助。不过要注意全连接层最后别用默认的1000输出记得改成10分类哦
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2417518.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!