【重新定义matlab强大系列十七】Matlab深入浅出长短期记忆神经网络LSTM

news2025/6/8 0:28:31

🔗 运行环境:Matlab

🚩 撰写作者:左手の明天

🥇 精选专栏:《python》

🔥  推荐专栏:《算法研究》

#### 防伪水印——左手の明天 ####

💗 大家好🤗🤗🤗,我是左手の明天!好久不见💗

💗今天更新系列——重新定义matlab强大系列💗

📆  最近更新:2024 年 03 月 09 日,左手の明天的第 316 篇原创博客

📚 更新于专栏:matlab

#### 防伪水印——左手の明天 ####


 本文主要说明如何使用长短期记忆 (LSTM) 神经网络处理分类和回归任务的序列和时间序列数据。有关如何使用 LSTM 神经网络对序列数据进行分类的示例,请参阅:

【深度学习】详解利用Matlab和Python中 LSTM 网络实现序列分类

【Python深度学习】详解Python深度学习进行时间序列预测

【Python机器学习】详解Python机器学习进行时间序列预测

【Matlab深度学习】详解matlab深度学习进行时间序列预测


一、LSTM 神经网络架构

LSTM 神经网络是一种循环神经网络 (RNN),可以学习序列数据的时间步之间的长期依存关系。LSTM 神经网络的核心组件是序列输入层和 LSTM 层。序列输入层将序列或时间序列数据输入神经网络中。LSTM 层学习序列数据的时间步之间的长期相关性。

下图说明用于分类的简单 LSTM 网络的架构。该神经网络从一个序列输入层开始,后跟一个 LSTM 层。为了预测类标签,该神经网络的末尾是一个全连接层、一个 softmax 层和一个分类输出层。

下图说明用于回归的简单 LSTM 神经网络的架构。该神经网络从一个序列输入层开始,后跟一个 LSTM 层。该神经网络的末尾是一个全连接层和一个回归输出层。

下图说明用于视频分类的神经网络的架构。要将图像序列输入到神经网络,请使用序列输入层。要使用卷积层来提取特征,也就是说,要将卷积运算独立地应用于视频的每帧,请使用一个序列折叠层,后跟一个卷积层,然后是一个序列展开层。要使用 LSTM 层从向量序列中学习,请使用一个扁平化层,后跟 LSTM 层和输出层。


二、LSTM 层架构

以下是数据流通过 LSTM 层的示意图,其中输入为 x、输出为 y、时间步数为 T。在图中,ht 表示在时间步 t 的输出(也称为隐藏状态),ct 表示在该时间步的单元状态

如果层输出完整序列,则它输出 y1、…、yT,等效于 h1、…、hT。如果该层仅输出最后一个时间步,则该层输出 yT,等效于 hT。输出中的通道数量与 LSTM 层的隐藏单元数量匹配。

第一个 LSTM 运算使用 RNN 的初始状态和序列的第一个时间步来计算第一个输出和更新后的单元状态。在时间步 t 上,该运算使用 RNN 的当前状态 (ct−1,ht−1) 和序列的下一个时间步来计算输出和更新后的单元状态 ct。

该层的状态由隐藏状态(也称为输出状态)和单元状态组成。时间步 t 处的隐藏状态包含该时间步的 LSTM 层的输出。单元状态包含从前面的时间步中获得的信息。在每个时间步,该层都会在单元状态中添加或删除信息。该层使用不同的控制这些更新。

以下组件控制层的单元状态和隐藏状态。

组件目的
输入门 (i)控制单元状态更新的级别
遗忘门 (f)控制单元状态重置(遗忘)的级别
候选单元 (g)向单元状态添加信息
输出门 (o)控制添加到隐藏状态的单元状态的级别

下图说明在时间步 t 上的数据流。此图显示门如何遗忘、更新和输出单元状态和隐藏状态。

LSTM 层的可学习权重包括输入权重 W (InputWeights)、循环权重 R (RecurrentWeights) 以及偏置 b (Bias)。矩阵 W、R 和 b 分别是输入权重、循环权重和每个分量的偏置的串联。该层根据以下方程串联矩阵:

其中 i、f、g、o 分别表示输入门、遗忘门、候选单元和输出门。

时间步 t 处的单元状态由下式给出:

其中 ⊙ 表示哈达玛乘积(向量的按元素乘法)。

时间步 t 处的隐藏状态由下式给出:

其中 σc 表示状态激活函数。默认情况下,lstmLayer 函数使用双曲正切函数 (tanh) 计算状态激活函数。

以下公式说明时间步 t 处的组件。

这些计算中,σg 表示门激活函数。默认情况下,lstmLayer 函数使用 

给出的 sigmoid 函数来计算门激活函数。


三、LSTM网络种类

3.1 分类 LSTM 网络

要创建针对“序列到标签”分类的 LSTM 网络,请创建一个层数组,其中包含一个序列输入层、一个 LSTM 层、一个全连接层、一个 softmax 层和一个分类输出层。

将序列输入层的大小设置为输入数据的特征数量。将全连接层的大小设置为类的数量。您不需要指定序列长度。

对于 LSTM 层,指定隐含单元的数量和输出模式 'last'

numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

要针对“序列到序列”分类创建一个 LSTM 网络,请使用与“序列到标签”分类相同的架构,但将 LSTM 层的输出模式设置为 'sequence'

numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

3.2 回归 LSTM 网络

要针对“序列到单个”回归创建一个 LSTM 网络,请创建一个层数组,其中包含一个序列输入层、一个 LSTM 层、一个全连接层和一个回归输出层。

将序列输入层的大小设置为输入数据的特征数量。将全连接层的大小设置为响应的数量。不需要指定序列长度。

对于 LSTM 层,指定隐含单元的数量和输出模式 'last'

numFeatures = 12;
numHiddenUnits = 125;
numResponses = 1;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numResponses)
    regressionLayer];

要针对“序列到序列”回归创建一个 LSTM 网络,请使用与“序列到单个”回归相同的架构,但将 LSTM 层的输出模式设置为 'sequence'

numFeatures = 12;
numHiddenUnits = 125;
numResponses = 1;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(numResponses)
    regressionLayer];

3.3 视频分类网络

要针对包含图像序列的数据(如视频数据和医学图像)创建一个深度学习网络,请使用序列输入层指定图像序列输入。

要使用卷积层来提取特征,也就是说,要将卷积运算独立地应用于视频的每帧,请使用一个序列折叠层,后跟一个卷积层,然后是一个序列展开层。要使用 LSTM 层从向量序列中学习,请使用一个扁平化层,后跟 LSTM 层和输出层。

inputSize = [28 28 1];
filterSize = 5;
numFilters = 20;
numHiddenUnits = 200;
numClasses = 10;

layers = [ ...
    sequenceInputLayer(inputSize,'Name','input')
    
    sequenceFoldingLayer('Name','fold')
    
    convolution2dLayer(filterSize,numFilters,'Name','conv')
    batchNormalizationLayer('Name','bn')
    reluLayer('Name','relu')
    
    sequenceUnfoldingLayer('Name','unfold')
    flattenLayer('Name','flatten')
    
    lstmLayer(numHiddenUnits,'OutputMode','last','Name','lstm')
    
    fullyConnectedLayer(numClasses, 'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classification')];

将这些层转换为一个层图,并将序列折叠层的 miniBatchSize 输出连接到序列展开层的对应输入。

lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');

3.4 更深的 LSTM 网络

可以通过在 LSTM 层之前插入具有输出模式 'sequence' 的额外 LSTM 层来加大 LSTM 网络的深度。为了防止过拟合,可以在 LSTM 层后插入丢弃层。

对于“序列到标签”分类网络,最后一个 LSTM 层的输出模式必须为 'last'

numFeatures = 12;
numHiddenUnits1 = 125;
numHiddenUnits2 = 100;
numClasses = 9;
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits1,'OutputMode','sequence')
    dropoutLayer(0.2)
    lstmLayer(numHiddenUnits2,'OutputMode','last')
    dropoutLayer(0.2)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

对于“序列到序列”分类网络,最后一个 LSTM 层的输出模式必须为 'sequence'

numFeatures = 12;
numHiddenUnits1 = 125;
numHiddenUnits2 = 100;
numClasses = 9;
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits1,'OutputMode','sequence')
    dropoutLayer(0.2)
    lstmLayer(numHiddenUnits2,'OutputMode','sequence')
    dropoutLayer(0.2)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

四、网络层

描述
sequenceInputLayer序列输入层向神经网络输入序列数据并应用数据归一化。

lstmLayer

LSTM 层是一个 RNN 层,该层学习时间序列和序列数据中时间步之间的长期相关性。

bilstmLayer

双向 LSTM (BiLSTM) 层是一个 RNN 层,该层学习时间序列或序列数据的时间步之间的双向长期相关性。当您希望 RNN 在每个时间步从完整时间序列中学习时,这些相关性会很有用。
gruLayerGRU 层是一个 RNN 层,它学习时间序列和序列数据中时间步之间的相关性。
convolution1dLayer一维卷积层将滑动卷积滤波器应用于一维输入。
maxPooling1dLayer一维最大池化层通过将输入划分为一维池化区域并计算每个区域的最大值来执行下采样。
averagePooling1dLayer一维平均池化层通过将输入划分为若干一维池化区域,然后计算每个区域的平均值来执行下采样。
globalMaxPooling1dLayer一维全局最大池化层通过输出输入的时间或空间维度的最大值来执行下采样。
sequenceFoldingLayer序列折叠层将一批图像序列转换为一批图像。使用序列折叠层独立地对图像序列的时间步执行卷积运算。
sequenceUnfoldingLayer序列展开层在序列折叠后还原输入数据的序列结构。
flattenLayer扁平化层将输入的空间维度折叠成通道维度。

wordEmbeddingLayer(Text Analytics Toolbox)

单词嵌入层将单词索引映射到向量。

五、序列处理

5.1 按长度对序列排序

要在填充或截断序列时减少填充或丢弃的数据量,请尝试按序列长度对数据进行排序。要按序列长度对数据进行排序,首先使用 cellfun 对每个序列应用 size(X,2) 来获得每个序列的列数。然后使用 sort 对序列长度进行排序,并使用第二个输出对原始序列重新排序。

sequenceLengths = cellfun(@(X) size(X,2), XTrain);
[sequenceLengthsSorted,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);

5.2 填充序列

如果指定序列长度 'longest',则软件会填充序列,使小批量中的所有序列具有与小批量中的最长序列相同的长度。此选项是默认选项。

5.3 截断序列

如果指定序列长度 'shortest',则软件会截断序列,使小批量中的所有序列具有与该小批量中的最短序列相同的长度。序列中的其余数据被丢弃。

5.4 拆分序列

如果将序列长度设置为整数值,则软件会将小批量中的所有序列填充到小批量中最长序列的长度。然后,软件将每个序列拆分为指定长度的较小序列。如果发生拆分,则软件会创建额外的小批量。如果指定的序列长度没有均分数据的序列长度,则对于包含这些序列的最终时间步的小批量,其长度短于指定的序列长度。

如果将序列长度指定为正整数,则软件会在连续的迭代中处理较小的序列。神经网络会在拆分的序列之间更新 RNN 状态。


六、解决长期依赖问题的常用架构

除了LSTM(长短期记忆)之外,还有其他神经网络架构也可以用于解决长期依赖问题。以下是一些常见的架构:

  • 门控循环单元(GRU):GRU是一种类似于LSTM的循环神经网络架构,旨在解决长期依赖问题。它通过门控机制来控制信息的流动,使得模型能够选择性地保留或遗忘之前时间步的信息。GRU比LSTM有更简单的结构,因此训练速度更快,且参数更少,有助于减少过拟合的风险。
  • 循环递归神经网络(RRNN):RRNN是另一种解决长期依赖问题的神经网络架构。它通过递归连接的方式,使得神经网络能够捕获序列数据中的长期依赖关系。RRNN的递归连接使得每个时间步的输出都依赖于前一个时间步的输出,从而能够处理长期依赖问题。

这些架构在处理长期依赖问题时表现出色,并且在许多应用中取得了显著的成功,如语音识别、机器翻译、图像识别等。选择合适的架构取决于任务的具体需求和数据的特点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1504563.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

音视频按照时长分类小工具

应某用户的需求,编写了这款根据音视频时长分类小工具。 实际效果如下: 显示的是时分秒: 核心代码: MediaInfo MI; if (MI.Open(strPathInput.c_str()) 0){return -1;}_tstring stDuration MI.Get(stream_t::Stream_Audio,0,_T…

【Flink】Flink 的八种分区策略(源码解读)

Flink 的八种分区策略(源码解读) 1.继承关系图1.1 接口:ChannelSelector1.2 抽象类:StreamPartitioner1.3 继承关系图 2.分区策略2.1 GlobalPartitioner2.2 ShufflePartitioner2.3 BroadcastPartitioner2.4 RebalancePartitioner2…

手机APP测试——如何进行安装、卸载、运行?

手机APP测试——主要针对的是安卓( Android )和苹果IOS两大主流操作系统,主要考虑的就是功能性、兼容性、稳定性、易用性、性能等测试,今天先来讲讲如何进行安装、卸载、运行的内容。 一、App安装 1、点击运行APP安装包,检测安装包是否正常; . 2、进入[安装向导]…

Java17 --- SpringCloud之OpenFeign

目录 一、OpenFeign实现服务调用 1.1、创建openfeign微服务 二、Openfeign超时控制 2.1、全局默认配置 2.2、单个微服务配置 三、重试机制 四、替换openfeign默认的HttpClient 五、请求响应压缩 六、日志打印 一、OpenFeign实现服务调用 1.1、创建openfeign微服…

LLM长上下文外推方法

现在的LLM都集中在卷上下文长度了,最新的Claude3已经支持200K的上下文,见:cost-context。下面是一些提升LLM长度外推能力的方法总结: 数据工程 符尧大佬的最新工作:Data Engineering for Scaling Language Models to …

[虚拟机保护逆向] [HGAME 2023 week4]vm

[虚拟机保护逆向] [HGAME 2023 week4]vm 虚拟机逆向的注意点:具体每个函数的功能,和其对应的硬件编码的*长度* 和 *含义*,都分析出来后就可以编写脚本将题目的opcode转化位vm实际执行的指令 :分析完成函数功能后就可以编写脚本输出…

c++ primer plus 笔记 第十六章 string类和标准模板库

string类 string自动调整大小的功能: string字符串是怎么占用内存空间的? 前景: 如果只给string字符串分配string字符串大小的空间,当一个string字符串附加到另一个string字符串上,这个string字符串是以占用…

Spring web开发(入门)

1、我们在执行程序时,运行的需要是这个界面 2、简单的web接口(127.0.0.1表示本机IP) package com.example.demo;import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestCont…

代码学习记录15

随想录日记part15 t i m e : time: time: 2024.03.09 主要内容:今天的主要内容是二叉树的第四部分,主要涉及平衡二叉树的建立;二叉树的路径查找;左叶子之和;找树左下角的值&#xff…

考研复习C语言初阶(4)+标记和BFS展开的扫雷游戏

目录 1. 一维数组的创建和初始化。 1.1 数组的创建 1.2 数组的初始化 1.3 一维数组的使用 1.4 一维数组在内存中的存储 2. 二维数组的创建和初始化 2.1 二维数组的创建 2.2 二维数组的初始化 2.3 二维数组的使用 2.4 二维数组在内存中的存储 3. 数组越界 4. 冒泡…

3.DOM-事件进阶(事件对象、事件委托)

环境对象this 环境对象本质上是一个关键字 this this所在的代码区域不同,代表的含义不同 全局作用域中的this 全局作用域中this代表window对象 局部作用域中的this 在局部作用域中(函数中)this代表window对象 原因是函数调用的时候简写了,函数完整写…

Go语言数据结构(二)堆/优先队列

文章目录 1. container中定义的heap2. heap的使用示例3. 刷lc应用堆的示例 更多内容以及其他Go常用数据结构的实现在这里,感谢Star:https://github.com/acezsq/Data_Structure_Golang 1. container中定义的heap 在golang中的"container/heap"…

[数据集][目标检测]变电站缺陷检测数据集VOC+YOLO格式8307张17类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):8307 标注数量(xml文件个数):8307 标注数量(txt文件个数):8307 标注…

Java8 CompletableFuture异步编程-进阶篇

🏷️个人主页:牵着猫散步的鼠鼠 🏷️系列专栏:Java全栈-专栏 🏷️个人学习笔记,若有缺误,欢迎评论区指正 前言 我们在前面文章讲解了CompletableFuture这个异步编程类的基本用法,…

【操作系统概念】第11章:文件系统实现

文章目录 0.前言11.1 文件系统结构11.2 文件系统实现11.2.1 虚拟文件系统 11.3 分配方法11.3.1 连续分配11.3.2 链接分配11.3. 3 索引分配 11.5 空闲空间管理11.5.1 位图/位向量11.5.2 链表11.5.3 组 0.前言 正如第10章所述,文件系统提供了机制,以在线存…

【数据分享】2000-2022年全国1km分辨率的逐年PM2.5栅格数据(免费获取)

PM2.5作为最主要的空气质量指标,在我们日常研究中非常常用!之前我们给大家分享了2013-2022年全国范围逐日的PM2.5栅格数据(可查看之前的文章获悉详情)! 本次我们给大家带来的是2000-2022年全国范围的逐年的PM2.5栅格数…

树莓派4B Ubuntu20.04 Python3.9安装ROS踩坑记录

问题描述 在使用sudo apt-get update命令更新时发现无法引入apt-pkg,使用python3 -c "import apt_pkg"发现无法引入,应该是因为:20.04的系统默认python是3.8,但是我换成了3.9所以没有编译文件,于是使用sudo update-alte…

K8S - 在任意node里执行kubectl 命令

当我们初步安装玩k8s (master 带 2 nodes) 时 正常来讲kubectl 只能在master node 里运行 当我们尝试在某个 node 节点来执行时, 通常会遇到下面错误 看起来像是访问某个服务器的8080 端口失败了。 原因 原因很简单 , 因为k8s的各个组建&…

思科网络中如何配置标准ACL协议

一、什么是标准ACL协议?有什么作用及配置方法? (1)标准ACL(Access Control List)协议是一种用于控制网络设备上数据流进出的协议。标准ACL基于源IP地址来过滤数据流,可以允许或拒绝特定IP地址范…

微信私信短剧机器人源码

本源码仅提供参考,有能力的继续开发 接口为api调用 云端同步 https://ys.110t.cn/api/ajax.php?actyingshilist 影视搜索 https://ys.110t.cn/api/ajax.php?actsearch&name剧名 每日更新 https://ys.110t.cn/api/ajax.php?actDaily 反馈接口 https://ys.11…