《深度学习进阶 自然语言处理》第六章:LSTM介绍

news2025/9/24 10:24:22

文章目录

      • 6.1 RNN的问题
        • 6.1.1 RNN的复习
        • 6.1.2 梯度消失和梯度爆炸
        • 6.1.4 梯度爆炸的对策
      • 6.2 梯度消失和LSTM
        • 6.2.1 LSTM的接口
        • 6.2.2 LSTM层的结构
        • 6.2.3 输出门
        • 6.2.4 遗忘门
        • 6.2.5 新的记忆单元
        • 6.2.6 输入门
        • 6.2.7 LSTM的梯度的流动
      • 6.3 使用LSTM的语言模型
        • 6.3.1 LSTM层的多层化
        • 6.3.2 基于Dropout抑制过拟合
        • 6.3.3 权重共享
      • 6.4 总结


之前文章链接:

开篇介绍:《深度学习进阶 自然语言处理》书籍介绍
第一章:《深度学习进阶 自然语言处理》第一章:神经网络的复习
第二章:《深度学习进阶 自然语言处理》第二章:自然语言和单词的分布式表示
第三章:《深度学习进阶 自然语言处理》第三章:word2vec
第四章:《深度学习进阶 自然语言处理》第四章:Embedding层和负采样介绍
第五章:《深度学习进阶 自然语言处理》第五章:RNN通俗介绍

上一章我们介绍了结构比较简单的RNN,存在环路的RNN可以记忆过去的信息,只是效果不好,因为很多情况下它都无法很好地学习到时序数据的长期依赖关系。为了学习到时序数据的长期依赖关系,我们增加了一种名为“门”的结构,具有代表性的有LSTM和GRU等“Gated RNN”。本章将指出上一章的RNN问题,重点介绍LSTM的结构,并揭示它实现“长期记忆”的机制。

6.1 RNN的问题

上一章的RNN不擅长学习时序数据的长期依赖关系,是因为BPTT会发生梯消失和梯度爆炸的问题。本节先回顾一下上一章介绍的RNN层,并通过实例来说明为什么不擅长长期记忆。

6.1.1 RNN的复习

RNN层存在环路,展开来看它是一个在水平方向上延伸的网络,如下图

image-20221118173445887

x t x_{t} xt是输入的时序数据, h t h_{t} ht是输出的隐藏状态,记录过去信息。RNN正是使用了上一刻的隐藏状态所以才可以继承过去的信息。RNN层的处理用计算图来表示的话,如下图:

image-20221118173540476

如上图所示,RNN层的正向传播进行的计算包含矩阵乘积、矩阵加法和基于激活函数tanh的变换。这是我们上一章介绍的RNN层,下面我们看一下这其中存在的问题,主要是关于长期记忆的问题。

6.1.2 梯度消失和梯度爆炸

语言模型的任务是根据已经出现的单词预测下一个将要出现的单词。RNN层通过向过去传递“有意义的梯度”,理论上是能够学习时间方向上的依赖关系。但是随着时间的回溯,这个梯度在中途不可避免的变小(梯度消失)或者变大(梯度爆炸),权重参数不能正常更新,RNN层也就无法学习长期的依赖关系。

根据上一节的第二张图展示的RNN层正向传播的计算,可知反向传播的梯度依次流经tanh、矩阵加和(“+”)和MatMul(矩阵乘积)运算。"+"的反向传播梯度值不变。当y = tanh(x)时,它的导数是 d y d x = 1 − y 2 \frac{dy}{dx} = 1- y^{2} dxdy=1y2,分别画在图上如下图:

image-20221118173607702

图中的虚线可以看出,tanh(x)导数的值小于1.0,随着x远离0,它的值变小。这意味着,反向传播的梯度经过tanh节点时,它的值会越来越小。

接着我们看一下MatMul(矩阵乘积)节点,简单起见,我忽略tanh节点,这样RNN层的反向传播的梯度仅取决于矩阵乘积运算。通过下面代码来观察梯度大小的变化。

import numpy as np
import matplotlib.pyplot as plt

N = 2  # mini-batch的大小
H = 3  # 隐藏状态向量的维数
T = 20  # 时序数据的长度

dh = np.ones((N, H))
# np.random.seed(6)  # 为了复现,固定随机数种子
Wh = np.random.randn(H, H)

norm_list = []
for t in range(T):
    dh = np.dot(dh, Wh.T)
    norm = np.sqrt(np.sum(dh**2)) / N
    norm_list.append(norm)
print(norm_list)

多运行几次上面的代码可以发现,梯度的大小(norm值)随时间步长呈指数级增加/减小。为什么会出现这样的指数级变化呢?因为矩阵Wh被反复乘了T次。如果Wh是标量,当Wh大于1或小于1时,梯度呈指数级变化。但是Wh是矩阵,此时矩阵的奇异值作为指标。如果奇异值的最大值大于1,梯度可能呈指数级增加(必要非充分条件),最大值小于1,可以判断梯度会呈指数级减小。

6.1.4 梯度爆炸的对策

解决梯度爆炸有既定的方法,称为梯度裁剪(gradients clipping)。伪代码如下:
i f ∣ ∣ g ^ ∣ ∣ ≥ t h r e s h o l d : g ^ = t h r e s h o l d ∣ ∣ g ^ ∣ ∣ g ^ if \quad ||\hat{g}|| \geq threshold:\\ \hat{g} = \frac{threshold}{||\hat{g}||}\hat{g} ifg^threshold:g^=g^thresholdg^
这里假设将神经网络用到的所有参数的梯度整合成一个,并用符号 g ^ \hat{g} g^表示,阈值设为threshold。当梯度的L2范数 ∣ ∣ g ^ ∣ ∣ ||\hat{g}|| g^大于或等于阈值,就按照上面的方法修正梯度,这就是梯度裁剪。

6.2 梯度消失和LSTM

为了解决RNN的梯度消失问题,人们提出了诸多Gated RNN框架,具有代表性的有LSTM和GRU,本节我们只关注LSTM,从结构入手,阐明它为何不会引起梯度消失。

6.2.1 LSTM的接口

我们先以简略图比较一下LSTM与RNN的接口:

image-20221118173655283

LSTM比RNN的接口多了一个路径c,这个c称为记忆单元,是LSTM专用的记忆部门,其特点是仅在LSTM层内部接收和传递数据,不向其他层输出。

6.2.2 LSTM层的结构

c t c_{t} ct存储了时刻t时LSTM的记忆,可以认为其中保存了过去到时刻t的所有必要信息,基于 c t c_{t} ct向外部输出隐藏状态 h t h_{t} ht。二者的关系是 h t h_{t} ht c t c_{t} ct经过tanh函数变换后的记忆单元, c t c_{t} ct h t h_{t} ht是按元素应用tanh函数,所以它们的元素个数相同。

进入下一项之前,我们需要特别说明的是,LSTM的门控并非只能“开/合“,还能通过”开合程度“来控制信息的流量。sigmoid函数的输出范围是0.0 ~ 1.0,所以用于求门的开合程度。

6.2.3 输出门

刚才介绍 h t h_{t} ht c t c_{t} ct关系时只提到应用了tanh函数,这里其实还考虑了对 t a n h ( c t ) tanh(c_{t}) tanh(ct)施加了“门控”,调整了 t a n h ( c t ) tanh(c_{t}) tanh(ct)的各个元素的重要程度。由于这里的门管理了 h t h_{t} ht的输出,所以称为输出门(output gate)。公式如下:
o = σ ( x t W x ( o ) + h t − 1 W h ( o ) + b ( o ) ) o = \sigma(x_{t}W_{x}^{(o)}+h_{t-1}W_{h}^{(o)}+b^{(o)}) o=σ(xtWx(o)+ht1Wh(o)+b(o))
o o o t a n h ( c t ) tanh(c_{t}) tanh(ct)的对应元素乘积,得到 h t h_{t} ht

tanh的输出是 -1.0 ~ 1.0的实数,可以认为是表示某种被编码信息的强弱程度。sigmoid的输出是 0.0 ~ 1.0的实数,表示数据流出的比例。因此,多数情况下,门使用sigmoid函数作为激活函数,包含实质信息的数据使用tanh函数作为激活函数。

6.2.4 遗忘门

在记忆单元 c t − 1 c_{t-1} ct1上添加一个忘记不必要信息的门,称为遗忘门(forget gate)。公式如下:
f = σ ( x t W x ( f ) + h t − 1 W h ( f ) + b ( f ) ) f = \sigma(x_{t}W_{x}^{(f)}+h_{t-1}W_h^{(f)}+b^{(f)}) f=σ(xtWx(f)+ht1Wh(f)+b(f))
f f f c t − 1 c_{t-1} ct1对应元素的乘积求得传递到当前时刻的记忆信息(删除了应该忘记的信息)。

6.2.5 新的记忆单元

除了删除上一时刻的记忆单元中应该忘记的信息,我们还需要向这个记忆单元添加应当记住的新信息,为此添加新的tanh节点,公式如下:
g = t a n h ( x t W x ( g ) + h t − 1 W h ( g ) + b ( g ) ) g=tanh(x_tWx^{(g)}+h_{t-1}W_h^{(g)}+b^{(g)}) g=tanh(xtWx(g)+ht1Wh(g)+b(g))
g g g表示新的信息。

6.2.6 输入门

对上面的新的记忆单元 g g g添加门,控制信息流量(对新信息添加权重从而对信息进行取舍),称为输入门(input gate)。公式如下:
i = σ ( x t W x ( i ) + h t − 1 W h ( i ) + b ( i ) ) i = \sigma(x_tW_x^{(i)}+h_{t-1}W_h^{(i)}+b^{(i)}) i=σ(xtWx(i)+ht1Wh(i)+b(i))
然后,将 i i i g g g的对应元素的乘积作为最终的新记忆添加到记忆单元中。

LSTM的整个内部处理如下图所示:

image-20221118173741677

6.2.7 LSTM的梯度的流动

上面介绍完LSTM的结构,观察记忆单元 c c c的反向传播我们就会知道它为什么不会引起梯度消失。

记忆单元 c c c的反向传播仅流过"➕"和“✖️”节点。“➕”节点将上游传来的梯度原样流出,所以梯度没有变化。“✖️”节点进行的是对应元素的乘积计算(RNN是相同权重矩阵重复多次的矩阵乘积),而且每次乘积的门值不同,门值调节着梯度的大小,所以LSTM的记忆单元不会(难以)梯度消失。

6.3 使用LSTM的语言模型

这里实现的语言模型和上一章几乎一样,唯一区别是这次用Time LSTM层替换了Time RNN层,如下图:

image-20221118173825969

当前的使用LSTM层的语言模型有3点需要改进的地方,分别是:LSTM层的多层化、基于Dropout抑制过拟合和权重共享。

6.3.1 LSTM层的多层化

叠加多个LSTM层可以提高语言模型的精度,前一个LSTM层的隐藏状态是后面一个LSTM层的输入。具体要叠加几层,需要根据问题的复杂程度和数据规模来确定。

image-20221118173857385

6.3.2 基于Dropout抑制过拟合

通过加深层,可以提高模型的表现力,但是这样往往会导致过拟合。抑制过拟合的常用方法有增加训练数据、降低模型的复杂度,以及对模型复杂度给予正则化的惩罚。Dropout也是一种正则化。考虑到噪声的积累,最好在垂直方向上插入Dropout层,这样无论沿时间方向(水平方向)前进多少,信息都不会丢失。常规的Dropout不适合用在时间方向上,但是"变分Dropout"可以,它的机制是同一层的Dropout使用相同的mask(决定是否传递数据的随机布尔值),因为mask被固定,信息的损失方式也被固定,所以可以避免常规Dropout发生的指数级信息损失。

image-20221118173936374

6.3.3 权重共享

Embedding层和Affine层的权重共享,在不影响精度的提高的同时,可以大大减少学习的参数的数量,而且还能抑制过拟合。

6.4 总结

本章的主题是Gated RNN,我们先指出上一章的简单RNN中存在的梯度消失/爆炸问题,说明了作为替代层的Grated RNN(LSTM、GRU等)的有效性。介绍了使用LSTM层创建的语言模型,以及模型的优化。

下一章我们将介绍如何使用语言模型生成文本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/15791.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

乐华娱乐欲重返上市:毛利率走低,上半年利润下滑,韩庚为股东

撰稿|汤汤 来源|贝多财经 宣布暂缓香港IPO计划不到3个月,乐华娱乐(HK:02306)欲再度回归。 11月17日,港交所披露的信息显示,乐华娱乐集团(YH Entertainment Group,简称“乐华娱乐”&#xff0…

谈谈主数据管理的概念、原则、标准和指南

1主数据的定义和关键概念 1.1什么是主数据 主数据是不同业务领域的公共信息,并在多个业务流程中使用。主数据通常描述参与事务或事件的事物。示例包括有关课程、学生或雇员的信息。 参考数据通常用于对其他数据(如状态代码)进行分类,或与组织边界以外的信息(如国家列表)相…

ctfhub -afr -1 2 3

afr-1 打开题目链接 默认的传参数据为 ?phello 更改一下试试看 ?p111 无回显 ?pflag 回显了 no no no 想到了 php任意文件读取 ?pphp://filter/readconvert.base64-encode/resourceflag 回显出数据 应该是base64 拿去解码 得到flag afr-2 打开题目链接 查看源代码…

【第一阶段:java基础】第8章:面向对象编程高级-1(P333-P393)static、main、代码块、单例设计模式

本系列博客是韩顺平老师java基础课的课程笔记,B站:课程链接,吐血推荐的一套全网最细java教程,获益匪浅! 韩顺平P333-P393类变量和类方法类变量/静态变量类方法/静态方法理解main方法语法代码块注意事项好处细节单例设计…

Java接口

什么是接口? 在Java中,接口可以看成是多个类的公共规范,是一种引用数据类型。 使用关键字interface来定义接口 interface IRunning {void run(); }在创建接口时,接口的命名一般以大写字母I开头,接口的命名一般使用形…

【数据结构】栈和队列

文章目录栈和队列栈栈的概念及结构栈的实现初始化栈入栈出栈获取栈顶元素获取栈中有效元素个数判断栈是否为空销毁栈括号匹配问题队列队列的概念及结构队列的实现初始化队列队尾入队列对头出队列获取队头元素获取队尾元素销毁队列判断队列是否为空栈和队列 栈 栈的概念及结构…

带你初识微服务

博客主页:踏风彡的博客 博主介绍:一枚在学习的大学生,希望在这里和各位一起学习。 所属专栏:SpringCloud 文章创作不易,期待各位朋友的互动,有什么学习问题都可在评论区留言或者私信我,我会尽我…

同花顺_代码解析_技术指标_L

本文通过对同花顺中现成代码进行解析,用以了解同花顺相关策略设计的思想 目录 LH_LYDG LH_猎鹰歼狐 LHBLX LHSJ LHTZ LHXJ LH猎狐雷达 LOF净值 LWR LH_LYDG 猎鹰渡关(检测大盘顶部) 指标用法: (1&#xff…

进程间的信号

目录 一.信号入门 1.1概念 1.2信号发送与记录 1.3信号的处理方式 二.产生信号的方式 2.1通过终端按键产生 2.2通过系统函数向进程发信号 2.3由软件条件产生信号 2.4由硬件异常产生信号 三.阻塞信号 3.1信号相关概念 3.2信号在内核的表示 3.3sigset_t: 3.4…

工具及方法 - 使用Total Commander来查找重名文件

我只是一个Total Commander的轻度使用者,主要使用的是打开多个窗口,可以方便的以标签形式切换。 还有,这个软件是免费的,只是免费版打开时多一步,要输入个数字验证。 今天在使用一个SDK时,要包含进很多头文…

论文阅读笔记《Locality Preserving Matching》

核心思想 该文提出一种基于局部保持的特征匹配方法(LPM)。其核心思想是对于一个正确匹配点,其邻域范围内的其他匹配点与对应目标点之间的变换关系,应该和正确的匹配点保持一致,而错误匹配点,则应该有较大的…

第一章《初学者问题大集合》第6节:IntelliJ IDEA的下载与安装

当完成了Java开发环境之后,各位读者就可以开始编写第一个Java程序了。可是应该在哪里写程序呢?早期的开发者们都是用纯文本编辑工具编写Java程序,并且在命令行窗口中编译和运行Java程序。时至今日,我们早已远离了那个程序开发的“…

CentOs程序环境准备

1. MySQL的安装启动 选择指定操作系统指定版本的mysql进行下载 MySQL :: Download MySQL Community Serverhttps://dev.mysql.com/downloads/mysql/5.7.html#downloads 选择复制下载链接 回到终端,执行此命令下载 wget https://dev.mysql.com/get/Downloads/MyS…

举个栗子~Tableau 技巧(244):用和弦图(Chord diagram)呈现数据关系

关于和弦图 和弦图(Chord diagram)常用来表示数据之间的相互关系。数据点沿着圆圈分布,通过点和点之间相互连接的弧线来呈现相互之间的关系。和弦图从视觉上来说比较美观,数据呈现又很直观,所以深受数据粉喜爱。 之前…

【Linux初阶】Linux调试器-gdb使用 | gdb的 l/b/info/d/r/n/s/bt/finish/p/(un)display/q

🌟hello,各位读者大大们你们好呀🌟 🍭🍭系列专栏:【Linux初阶】 ✒️✒️本篇内容:gdb使用相关背景知识,gdb的使用(打断点、查断点、消断点、调试运行、查看对应变量&…

【python拼图游戏】图片自选,来挑战一下自己的极限吧~

嗨害大家好鸭!我是小熊猫❤ 拼图的画面多以自然风光、建筑物以及一些为人所熟识的图案的为题材。 城堡和山峦是两类传统的主题, 不过任何图画和影像都可以用做拼图的素材。 有一些公司还提供将私人摄影作品制成拼图的服务。 今天我小熊猫就给带来py…

FFmpeg5.1 解码rtsp 并用OpenCV 播放

RTSP 连接过程如下图 看下实际过程中FFmpeg 的日志情况: [tcp 0000014CC3256D40] No default whitelist set [tcp 0000014CC3256D40] Original list of addresses: [tcp 0000014CC3256D40] Address ::1 port 8554 [tcp 0000014CC3256D40] Address 127.0.0.1 po…

使用BP神经网络、RBF神经网络以及PSO优化的RBF神经网络对数据进行预测(Matlab代码实现)

🍒🍒🍒欢迎关注🌈🌈🌈 📝个人主页:我爱Matlab 👍点赞➕评论➕收藏 养成习惯(一键三连)🌻🌻🌻 🍌希…

Keysight是德科技E5061B网络分析仪-安泰测试

E5061B ENA系列网络分析仪从5Hz 至3GHz提供了广泛的频率范围。它不仅支持一般的射频应用(例如滤波器或放大器测量等),还支持低频应用(例如直流至直流转换器环路增益测量)。因此,它是所有实验台上进行网络分析的最重要工具。 拥有E5061B,您就…

Selenium4 新特性

一、Selenium4 简介 Selenium是一个综合性项目,包含一系列的工具和库,支持Web浏览器的各种自动化操作: 软件测试爬虫领域RPA领域优点: 开源:https://github.com/SeleniumHQ兼容性: Chrome、FireFox、Edeg、IE、Opera、Safari支持多种编程语言:Java、Python、C#、Ruby、…