【RNN详解】SimpleRNN,LSTM,bi-LSTM的原理及Tensorflow实现

news2025/7/19 4:49:03

目录

  • 1.摘要
  • 2.SimpleRNN
    • 2.1.原理介绍
    • 2.2.代码实现
  • 3.LSTM
    • 3.1.原理介绍
    • 3.2.代码实现
  • 4.LSTM改进—bi_LSTM
  • 5.总结
        • 感觉不错的话,记得点赞收藏啊。

1.摘要

RNN是一种特殊的神经网络结构,它是根据“人的认知是基于过往的经验和记忆”这一观点提出的,它与DNN,CNN不同的是:它不仅考虑前一时刻的输入,而且赋予了网络对前面内容的一种“记忆”功能,对于处理序列数据比较友好。RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。本篇将分别介绍SimpleRNN,LSTM,bi-LSTM。

2.SimpleRNN

2.1.原理介绍

在介绍之前,我们先来认识下最简单的SimpleRNN,如下图,我们先看左边X-U-S-V-O这条路线,这就好比一个全连接神经网络,X为输入数据,U为输入层到隐藏层的权重矩阵,S为隐藏层,V为隐藏层到输出层O的权重矩阵。如果加上W,隐藏层S的值不仅仅取决于输入层X,还取决于上次隐藏层的值与权重矩阵W相乘的值,相信大家都可以理解。按时间先展开,如右侧所示。
在这里插入图片描述
我们可以用如下公式来表示上图的关系:
在这里插入图片描述

2.2.代码实现

接下来,我们使用keras中的SimpleRNN模型对IMDB数据集进行训练,代码如下:

from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence
# 作为特征的单词个数
max_features = 10000 
# 在maxlen个单词后截断文本
maxlen = 500 
batch_size = 32
print('Loading data...')
(input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features)
print(len(input_train), 'train sequences')
print(len(input_test), 'test sequences')
print('Pad sequences (samples x time)')
input_train = sequence.pad_sequences(input_train, maxlen=maxlen)
input_test = sequence.pad_sequences(input_test, maxlen=maxlen)
print('input_train shape:', input_train.shape)
print('input_test shape:', input_test.shape)

在这里插入图片描述

from tensorflow.keras.layers import Dense,Embedding,SimpleRNN
from tensorflow.keras.models import Sequential
# 搭建模型
model = Sequential() 
model.add(Embedding(max_features, 32)) 
model.add(SimpleRNN(32)) 
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
history = model.fit(input_train, y_train,
                    epochs=5, 
                    batch_size=128,
                    validation_split=0.2)

在这里插入图片描述

3.LSTM

3.1.原理介绍

接下来,来讲RNN里面的算法之星LSTM,是一种改进后的循环神经网络,可以解决simpleRNN无法处理长距离的依赖问题以及梯度爆炸的问题,目前比较流行。相比于原始的RNN的隐层(hidden state), LSTM增加了一个记忆细胞(如下图),具有选择性记忆功能,可以选择记忆重要信息,过滤掉噪音信息,减轻记忆负担。
在这里插入图片描述
LSTM模型解读:
(1)遗忘门
在这里插入图片描述

遗忘门的输入为Xt和h(t-1),输出是ft,激活函数为sigmoid函数,激活函数使ft的输出范围为0-1之间,当值为0时,后续与C(t-1)相乘时,结果为0,即“遗忘”,因此叫做遗忘门,结果还是十分简单的。
(2)更新门
在这里插入图片描述
更新门输入也是h(t-1)和Xt,输出分别为it(和遗忘门一样)和C`t(更新的信息),对应的激活函数为sigmoid和tanh,sigmoid输出结果为(0,1),tanh输出结果为(-1,1)。
在这里插入图片描述
如上图所示,C(t-1)和ft相乘,it和C’t相乘,前者代表对过去信息选择性的保留,后者代表更新的信息的添加,最后将两部分合并起来,就是新的状态Ct了。
(3)输出门层
在这里插入图片描述
这就是LSTM模型的输出了,Ot是h(t-1)和Xt输入后经sigmoid后的结果,Ct通过tanh缩放后与Ot相乘,就是输出了。

3.2.代码实现

说了半天原理,接下来使用keras中的LSTM模型对IMDB数据集进行训练,代码如下:

from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence
# 作为特征的单词个数
max_features = 10000 
# 在maxlen个单词后截断文本
maxlen = 500 
batch_size = 32
print('Loading data...')
(input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features)
print(len(input_train), 'train sequences')
print(len(input_test), 'test sequences')
print('Pad sequences (samples x time)')
input_train = sequence.pad_sequences(input_train, maxlen=maxlen)
input_test = sequence.pad_sequences(input_test, maxlen=maxlen)
print('input_train shape:', input_train.shape)
print('input_test shape:', input_test.shape)

在这里插入图片描述

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation,Flatten
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import LSTM
# 搭建模型
model = Sequential()
model.add(Embedding(max_features, 32))
# 加LSTM
model.add(LSTM(32))
model.add(Dense(units=256,activation='relu' ))
model.add(Dropout(0.2))
model.add(Dense(units=1,activation='sigmoid' ))
model.compile(loss='binary_crossentropy', optimizer='adam',metrics=['accuracy'])

train_history =model.fit(input_train, y_train,batch_size=128, 
                         epochs=5,validation_split=0.2)

在这里插入图片描述
由以上结果,可以看出LSTM的准确率高于SimpleRNN。

4.LSTM改进—bi_LSTM

在单向的循环神经网络中,模型实际上只使用到了“上文”的信息,而没有考虑到“下文”的信息。在实际场景中,预测可能需要使用到整个输入序列的信息。因此,目前业内主流使用的都是双向的循环神经网络。顾名思义,双向循环神经网络结合了序列起点移动的一个循环神经网络和令一个从序列末尾向序列起点移动的循环神经网络。而作为循环神经网络的一种拓展,LSTM 自然也可以结合一个逆向的序列,组成双向长短时记忆网络(Bi-directional Long Short-Term Memory, Bi-LSTM)。Bi-LSTM 目前在自然语言处理领域的应用场景很多,并且都取得了不错的效果。具体原理不在多述,直接上代码看效果。
代码实现:

from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence
# 作为特征的单词个数
max_features = 10000 
# 在maxlen个单词后截断文本
maxlen = 500 
batch_size = 32
print('Loading data...')
(input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features)
print(len(input_train), 'train sequences')
print(len(input_test), 'test sequences')
print('Pad sequences (samples x time)')
input_train = sequence.pad_sequences(input_train, maxlen=maxlen)
input_test = sequence.pad_sequences(input_test, maxlen=maxlen)
print('input_train shape:', input_train.shape)
print('input_test shape:', input_test.shape)

在这里插入图片描述

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation,Flatten
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Bidirectional,LSTM

model = Sequential()
model.add(Embedding(max_features, 32))
model.add(Bidirectional(LSTM(32)))
model.add(Dense(units=256,activation='relu' ))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))

# 尝试使用不同的优化器和不同的优化器配置
model.compile('adam', 'binary_crossentropy', metrics=['accuracy']) 
model.fit(input_train, y_train,batch_size=128,
          epochs=5,validation_split=0.2)

在这里插入图片描述

5.总结

本篇主要讲解了RNN循环神经网络,首先讲了最简单的SimpleRNN,讲解了其实现原理,并用其实现了IMDB影评情感分析,接下来,通过SimpleRNN引出了LSTM,分析了二者之间的区别,以及LSTM网络结构上的创新,最后介绍了可以双向传播的bi-LSTM,具有更好的记忆功能,其在实际应用中使用广泛。希望对读者的学习有一定的帮助。

感觉不错的话,记得点赞收藏啊。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/347350.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

node.js+vue图书销售租赁书店运营管理系统mysql vscode

图书销售租赁书店运营管理系 前端技术:nodejsvueelementui,视图层其实质就是vue页面,通过编写vue页面从而展示在浏览器中,编写完成的vue页面要能够和控制器类进行交互,从而使得用户在点击网页进行操作时本图书馆租赁管理系统主要分…

面向新手的git实战教程

嗨!我是团子,大家好久不见呀~ 记得之前在网上学习git相关知识时,看到的文章大部分都是讲解git的基本命令有哪些,用处是什么,但是自己真正上手使用git时,仍然无从下手。 所以今天就想从初始化一个git仓库开始…

Quartz 快速入门案例,看这一篇就够了

前言 Quartz 是基于 Java 实现的任务调度框架,对任务的创建、修改、删除、触发以及监控这些操作直接提供了 api,这意味着开发人员拥有最大的操作权,也带来了更高的灵活性。 什么是任务调度? 任务调度指在将来某个特定的时间、固…

rabbitmq topic模式设置#通配符情况下 消费者队列未接收消息问题排查解决

目录说明说明 生产者配置 Exchange:topic_exchange_shcool Routing key:topic.shcool.# 消费者代码配置 Exchange:topic_exchange_shcool Routing key:topic.shcool.user PostConstructpublic void twoRabbitInit() {//声明交换…

基于微信小程序云开发实现考研题库小程序项目(完整版)

今天手把手的带大家实现一款答题类的题库小程序,如果着急的话,可以直接去看文末源码.下载与项目部署。考研题库小程序云开发实战,完整版提供给大家学习。题库小程序,基于云开发的微信答题小程序,软件架构是微信原生小程…

YOLOv8进行改进并训练自定义的数据集

一.训练数据集准备 YOLOv8的训练数据格式与YOLOv5的训练数据格式一致,这一部分可以进行沿用。之前博文有发布VOC标注格式转YOLO标注格式的脚本,有需要可以查看。 二.项目克隆 YOLOv8项目文件可以直接去github上下载zip文件然后解压,也可以直…

Telerik UI for WPF 2023 R1

Telerik UI for WPF 2023 R1 之 WPF 的 Telerik 用户界面,WPF 控件库开发人员信任,快速构建美观、高性能的 WPF 业务应用程序。现在支持 .NET 6 和 7.0。 概述部分背景图像 主要特征 现代专业主题图标,现代专业主题 通过各种受 Office、Wind…

【机器学习】P0 基础与相关名词

机器学习基础与相关名词什么是机器学习监督学习与非监督学习Classification And RegressionClusteringModel and Function什么是机器学习 机器学习是计算机发展到一定阶段的必然产物。相比于19世纪一个房间大小的计算机,当前的计算机更便携,性能更强&am…

开源项目-旅游信息管理系统

哈喽,大家好,今天给大家带来一个开源系统-旅游信息管理系统 前台 地址:http://ip/index 账号:user 密码:123456 后台 地址:http://ip/login 账号:root 密码:123456 功能模块:旅游路线、旅游景点、旅游酒店、旅游车票、旅游保险、旅游策略、订单管理、留言管理、数据…

一篇解决Linux 中的负载高低和 CPU 开销并不完全对应

负载是查看 Linux 服务器运行状态时很常用的一个性能指标。在观察线上服务器运行状况的时候,我们也是经常把负载找出来看一看。在线上请求压力过大的时候,经常是也伴随着负载的飙高。 但是负载的原理你真的理解了吗?我来列举几个问题&#x…

c/c++开发,无可避免的模板编程实践(篇二)

一、开发者需要对模板参数负责 1.1 为您模板参数提供匹配的操作 在进行模板设计时,函数模板或类模板一般只做模板参数(typename T)无关的操作为主,但是也不见得就不会关联模板参数自身的操作,尤其是在一些自定义的数据…

JVM基础学习

JVM分为两个子系统,两个组件一个子系统是Class loader类装载系统,另一个子系统是Execution Engine执行引擎一个组件是Runtime data area 运行时数据区,Native Interface 本地接口Class loader:根据给定的全限定类名来装载class文件到运行时数…

借助docker, 使用verdaccio搭建npm私服

为何要搭建npm私服 搭建npm私服好处多多,网上随便一篇教程搜出来都罗列了诸多好处,譬如: 公司内部开发环境与外网隔离,内部开发的一些库高度隐私不便外传,内网搭建npm服务保证私密性同属内网,可以确保使用npm下载依赖…

RPC技术选型

前言HTTP1.0 & HTTP1.1 & HTTP2.0 & RPCHTTP1.0无法复用连接HTTP1.0 协议时,HTTP 调用还只能是短链接调用,每次发送请求的时候,都需要进行一次TCP的连接,而TCP的连接释放过程又是比较费事的。这种无连接的特性会使得网…

金三银四跳槽季,JAVA面试撸题就来【笑小枫】微信小程序吧~

JAVA面试撸题就来【笑小枫】微信小程序啦~ 疫情已过,金三银四即将到来,小伙伴们是否有跳槽的打算呢?不管有没有,技术不能丢,让我们一起来撸题吧。 博主最近整理了一批面试题,包括JAVA基础、多线程与锁、Red…

媒体邀约电视台对商业活动选题有什么要求?如何邀请电视台报道

传媒如春雨,润物细无声,大家好随着互联网,移动互联网的快速发展,大众在电视上消磨的时间就越来越短了,但是随着新媒体的出现,传统媒体不断的跟进发展,不断打造自己的媒体矩阵,虽然离…

界面控件DevExpress WinForm——轻松构建类Visual Studio UI(一)

DevExpress WinForm拥有180组件和UI库,能为Windows Forms平台创建具有影响力的业务解决方案。DevExpress WinForm能完美构建流畅、美观且易于使用的应用程序,无论是Office风格的界面,还是分析处理大批量的业务数据,它都能轻松胜任…

图穷了,来搞一搞(内附源码)

本章继续我们的爬虫教程,爬什么呢 ,还是斗图,娱乐性的东西,为什么要爬? 因为我图库空了,发现这个网址的图库还是很丰富的。 「注意:如下文,是封装后拆分的,所以详情参照…

Word处理控件Aspose.Words功能演示:使用 C# 在电子邮件正文中发送 Word 文档

Aspose.Words 是一种高级Word文档处理API,用于执行各种文档管理和操作任务。API支持生成,修改,转换,呈现和打印文档,而无需在跨平台应用程序中直接使用Microsoft Word。此外,API支持所有流行的Word处理文件…

《精通Spring4.x 企业应用开发实战》第12章 Spring 的事务管理难点剖析

目录标题前言一、DAO 和事务管理的牵绊二、应用分层的迷惑三、事务方法嵌套调用的迷茫(事务传播行为)1.Spring 事务传播机制回顾2.相互嵌套的服务方法四、多线程的困惑1. Spring 通过单实例化 Bean 简化多线程问题2.启动独立线程调用事务方法五、联合军种作战的混乱1.Spring 事…