从认识AI开始-----解密LSTM:RNN的进化之路

news2025/6/5 3:47:16

前言

我在上一篇文章中介绍了 RNN,它是一个隐变量模型,主要通过隐藏状态连接时间序列,实现了序列信息的记忆与建模。然而,RNN在实践中面临严重的“梯度消失”与“长期依赖建模困难”问题:

  • 难以捕捉相隔很远的时间步之间的关系
  • 隐状态在不断更新中容易遗忘早期信息。

为了解决这些问题,LSTM(Long Short-Term Memory) 网络于 1997 年被 Hochreiter等人提出,该模型是对RNN的一次重大改进。


一、LSTM相比RNN的核心改进

接下来,我们通过对比RNN、LSTM,来看一下具体的改进:

模型特点优势缺点
RNN单一隐藏转态,时间步传递结构简答容易造成梯度消失/爆炸,对长期依赖差
LSTM多门控机制 + 单独的“记忆单元”解决长距离依赖问题,保留长期信息结构复杂,计算开销大

通过对比,我们可以发现,其实LSTM的核心思想是:引入了一个专门的“记忆单元”,在通过门控机制对信息进行有选择的保留、遗忘与更新


二、LSTM的核心结构

LSTM的核心结构如下图所示:

 如图可以轻松的看出,LSTM主要由门控机制和候选记忆单元组成,对于每个时间步,LSTM都会进行以下操作:

1. 忘记门

忘记门F_t主要的作用是:控制保留多少之前的记忆:

F_t=\sigma(X_t@W_{xf}+H_{t-1}@W_{hf}+b_f)

2. 输入门

输入门I_t主要的作用是:决定当前输入中哪些信息信息被写入记忆:

I_t=\sigma(X_t@W_{xi}+H_{t-1}@W_{hi}+b_i)

3. 候选记忆单元

\tilde C_t=tanh(X_t@W_{xc}+H_{t-1}@W_{hc}+b_c)

4. 输出门

输出门O_t的作用是:决定是是否使用隐状态:

O_t=\sigma(X_t@W_{xo}+H_{t-1}@W_{ho}+b_o)

5. 真正记忆单元

记忆单元( C_t )用于长期存储信息,解决RNN容易遗忘的问题:

C_t=F_t*C_{t-1}+I_t*\tilde C_{t}

7. 隐藏转态

H_t=O_t*tanh(C_t)

LSTM引入了专门的记忆单元 C_t  ,长期存储信息,解决了传统RNN容易遗忘的问题。


三、手写LSTM

通过上面的介绍,我们现在已经知道了LSTM的实现原理,现在,我们试着手写一个LSTM核心层:

首先,初始化需要训练的参数:

import torch
import torch.nn as nn
import torch.nn.functional as F

def params(input_size, output_size, hidden_size):
    
    W_xi, W_hi, b_i = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)
    W_xf, W_hf, b_f = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)
    W_xo, W_ho, b_o = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)
    W_xc, W_hc, b_c = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)
    W_hq = torch.randn(hidden_size, output_size) * 0.1
    b_q = torch.zeros(output_size)
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]
    for param in params:
        param.requires_grad = True
    return params

接着,我们需要初始化0时刻的隐藏转态:

import torch

def init_state(batch_size, hidden_size):
    return (torch.zeros((batch_size, hidden_size)), torch.zeros((batch_size, hidden_size)))

然后, 就是LSTM的核心操作:

import torch
import torch.nn as nn
def lstm(X, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for x in X:
        I = torch.sigmoid(torch.mm(x, W_xi) + torch.mm(H, W_hi) + b_i)
        F = torch.sigmoid(torch.mm(x, W_xf) + torch.mm(H, W_hf) + b_f)
        O = torch.sigmoid(torch.mm(x, W_xo) + torch.mm(H, W_ho) + b_o)
        C_tilde = torch.tanh(torch.mm(x, W_xc) + torch.mm(H, W_hc) + b_c)
        C = F * C + I * C_tilde
        H = O * torch.tanh(C)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=1), (H, C)
         

四、使用Pytroch实现简单的LSTM

在Pytroch中,已经内置了lstm函数,我们只需要调用就可以实现上述操作:

import torch
import torch.nn as nn

class mylstm(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(mylstm, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    def forward(self, x, h0, c0):
        out, (hn, cn) = self.lstm(x, h0, c0)
        out = self.fc(out)
        return out, (hn, cn)

# 示例
input_size = 10
hidden_size = 20
output_size = 10
batch_size = 1
seq_len = 5
num_layer = 1 # lstm堆叠层数

h0 = torch.zeros(num_layer, batch_size, hidden_size)
c0 = torch.randn(num_layer, batch_size, hidden_size)
x = torch.randn(batch_size, seq_len, hidden_size)

model = mylstm(input_size=input_size, hidden_size=hidden_size, output_size=output_size)

out, _ = model(x, (h0, c0))
print(out.shape)

总结

在现实中,LSTM的实际应用场景很多,比如语言模型、文本生成、时间序列预测、情感分析等长序列任务重,这是因为相比于RNN而言,LSTM能够更高地捕捉长期依赖,而且也更好的缓解了梯度消失问题;但是由于LSTM引入了三个门控机制,导致参数量比RNN要多,训练慢。

总的来说,LSTM是对传统RNN的一次革命性升级,引入门控机制和记忆单元,使模型能够选择性地记忆与遗忘,从而有效地捕捉长距离依赖。尽管LSTM近年来Transformer所取代,但LSTM依然是理解深度学习序列模型不可绕开的一环,有时在其他任务上甚至优于Transformer。


如果小伙伴们觉得本文对各位有帮助,欢迎:👍点赞 | ⭐ 收藏 |  🔔 关注。我将持续在专栏《人工智能》中更新人工智能知识,帮助各位小伙伴们打好扎实的理论与操作基础,欢迎🔔订阅本专栏,向AI工程师进阶!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2396842.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode0513. 找树左下角的值-meidum

1 题目:找树左下角的值 官方标定难度:中 给定一个二叉树的 根节点 root,请找出该二叉树的 最底层 最左边 节点的值。 假设二叉树中至少有一个节点。 示例 1: 输入: root [2,1,3] 输出: 1 示例 2: 输入: [1,2,3,4,null,5,6,null,null,7]…

命令行式本地与服务器互传文件

文章目录 1. 背景2. 传输方式2.1 SCP 协议传输2.2 SFTP 协议传输 3. 注意 命令行式本地与服务器互传文件 1. 背景 多设备协同工作中,因操作系统的不同,我们经常需要将另外一个系统中的文件传输到本地PC进行浏览、编译。多设备文件互传,在嵌入…

LabelImg: 开源图像标注工具指南

LabelImg: 开源图像标注工具指南 1. 简介 LabelImg 是一个图形化的图像标注工具,使用 Python 和 Qt 开发。它是目标检测任务中最常用的标注工具之一,支持 PASCAL VOC 和 YOLO 格式的标注输出。该工具开源、免费,并且跨平台支持 Windows、Lin…

计算机网络 TCP篇常见面试题总结

目录 TCP 的三次握手与四次挥手详解 1. 三次握手(Three-Way Handshake) 2. 四次挥手(Four-Way Handshake) TCP 为什么可靠? 1. 序列号与确认应答(ACK) 2. 超时重传(Retransmis…

树欲静而风不止,子欲养而亲不待

2025年6月2日,13~26℃,一般 待办: 物理2 、物理 学生重修 职称材料的最后检查 教学技能大赛PPT 遇见:使用通义创作了一副照片,很好看!都有想用来创作自己的头像了! 提示词如下: A b…

Kotlin中的::操作符详解

Kotlin提供了::操作符,用于创建对类或对象的成员(函数、属性)的引用。这种机制叫做成员引用(Member Reference)。这是Kotlin高阶函数和函数式编程的重要组成部分。 简化函数传递 在Java中,我们这样传方法: list.forEach(item -> System.…

深入详解编译与链接:翻译环境和运行环境,翻译环境:预编译+编译+汇编+链接,运行环境

目录 一、翻译环境和运行环境 二、翻译环境:预编译编译汇编链接 (一)预处理(预编译) (二)编译 1、词法分析 2、语法分析 3、语义分析 (三)汇编 (四&…

定时任务:springboot集成xxl-job-core(二)

定时任务实现方式: 存在的问题: xxl-job的原理: 可以根据服务器的个数进行动态分片,每台服务器分到的处理数据是不一样的。 1. 多台机器动态注册 多台机器同时配置了调度器xxl-job-admin之后,执行器那里会有多个注…

DeviceNET转EtherCAT网关:医院药房自动化的智能升级神经中枢

在现代医院药房自动化系统中,高效、精准、可靠的设备通信是保障患者用药安全与效率的核心。当面临既有支持DeviceNET协议的传感器、执行器(如药盒状态传感器、机械臂限位开关)需接入先进EtherCAT高速实时网络时,JH-DVN-ECT疆鸿智能…

一:UML类图

一、类的设计 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 学习设计模式的第一步是看懂UML类图,类图能直观的表达类、对象之间的关系,这将有助于后续对代码的编写。 类图在软件设计及应用框架前期设计中是不可缺少的一部分,它的主要成分包括:类名、…

Java 中 MySQL 索引深度解析:面试核心知识点与实战

🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Java 中 MySQL 索引深度解析:面试…

设计模式之结构型:装饰器模式

装饰器模式(Decorator Pattern) 定义 装饰器模式是一种​​结构型设计模式​​,允许​​动态地为对象添加新功能​​,而无需修改其原始类。它通过将对象包装在装饰器类中,以​​组合代替继承​​,实现功能的灵活扩展(如 Java I/O …

MySQL安装及启用详细教程(Windows版)

MySQL安装及启用详细教程(Windows版) 📋 概述 本文档将详细介绍MySQL数据库在Windows系统下的下载、安装、配置和启用过程。 📥 MySQL下载 官方下载地址 官方网站: https://dev.mysql.com/downloads/社区版本: https://dev.my…

【HarmonyOS Next之旅】DevEco Studio使用指南(二十九) -> 开发云数据库

目录 1 -> 开发流程 2 -> 创建对象类型 3 -> 添加数据条目 3.1 -> 手动创建数据条目文件 3.2 -> 自动生成数据条目文件 4 -> 部署云数据库 1 -> 开发流程 云数据库是一款端云协同的数据库产品,提供端云数据的协同管理、统一的数据模型和…

批量导出CAD属性块信息生成到excel——CAD C#二次开发(插件实现)

本插件可实现批量导出文件夹内大量dwg文件的指定块名的属性信息到excel,效果如下: 插件界面: dll插件如下: 使用方法: 1、获取此dll插件。 2、cad命令行输入netload ,加载此dll(要求AutoCAD&…

Goreplay最新版本的安装和简单使用

一:概述 Gor 是一个开源工具,用于捕获实时 HTTP 流量并将其重放到测试环境中,以便使用真实数据持续测试您的系统。它可用于提高对代码部署、配置更改和基础设施更改的信心。简单易用。 项目地址:buger/goreplay: GoReplay is an …

Android Studio 解决报错 not support JCEF 记录

问题:Android Studio 安装Markdown插件后,报错not support JCEF不能预览markdown文件。 原因:Android Studio不是新装,之前没留意IDE自带的版本是不支持JCEF的。 解决办法: 在菜单栏选中Help→Find Action&#xff…

sigmastar实现SD卡升级

参考文章:http://wx.comake.online/doc/DD22dk2f3zx-SSD21X-SSD22X/customer/development/software/Px/zh/sys/P3/usb%20&%20sd%20update.html#21-sd 1、构建SD卡升级包 在project下make image完成后使用make_sd_upgrade_sigmastar.sh脚本打包SD卡升级包。 ./make_sd_up…

kafka学习笔记(三、消费者Consumer使用教程——配置参数大全及性能调优)

本章主要介绍kafka consumer的配置参数及性能调优的点,其kafka的从零开始的安装到生产者,消费者的详解介绍、源码及分析及原理解析请到博主kafka专栏 。 1.消费者Consumer配置参数 配置参数默认值含义bootstrap.servers无(必填)…

【论文笔记】Transcoders Find Interpretable LLM Feature Circuits

Abstract 机制可解释性(mechanistic interpretability)的核心目标是路径分析(circuit analysis):在模型中找出与特定行为或能力对应的稀疏子图。 然而,MLP 子层使得在基于 Transformer 的语言模型中进行细粒度的路径分析变得困难。具体而言,…