长短期记忆神经网络(LSTM)基础学习与实例:预测序列的未来

news2025/5/15 0:16:57

目录

1. 前言

2. LSTM的基本原理

2.1 LSTM基本结构

2.2 LSTM的计算过程

3. LSTM实例:预测序列的未来

3.1 数据准备

3.2 模型构建

3.3 模型训练

3.4 模型预测

3.5 完整程序预测序列的未来 

4. 总结


1. 前言

在深度学习领域,循环神经网络(RNN)是处理序列数据的重要工具。然而,传统的RNN在处理长序列时常常会遇到梯度消失或梯度爆炸的问题,导致模型无法有效学习长期依赖关系。为了解决这一问题,长短期记忆神经网络(LSTM)应运而生。LSTM通过引入特殊的结构设计,能够有效地捕获序列数据中的长期依赖关系,因此在自然语言处理、时间序列预测等领域取得了显著的成果。

LSTM在许多序列数据处理任务中表现出色,包括但不限于:

  • 自然语言处理:文本生成、机器翻译、情感分析等。

  • 时间序列预测:股票价格预测、天气预报等。

  • 语音识别:将语音信号转换为文字。

  • 视频分析:动作识别、场景理解等。

如果没有RNN的基础,可以去看这篇博客,

《循环神经网络(RNN)基础入门与实践学习:电影评论情感分类任务》

2. LSTM的基本原理

2.1 LSTM基本结构

传统的RNN在处理长序列时,梯度会随着序列长度的增加而逐渐消失或爆炸。这种现象使得RNN难以学习到序列中的长期依赖关系。例如,在处理一段较长的文本时,RNN可能无法将开头的信息有效传递到结尾,导致模型性能受限。

LSTM通过引入门控机制(gate mechanism)解决了这一问题。门控机制可以控制信息的流动,决定哪些信息应该被保留,哪些信息应该被遗忘,从而有效地捕获长期依赖关系。

LSTM的核心结构包括三个门:

  1. 遗忘门(Forget Gate):决定哪些信息应该被遗忘。

  2. 输入门(Input Gate):决定哪些新信息应该被存储到单元状态中。

  3. 输出门(Output Gate):决定哪些信息应该被输出。

此外,LSTM还有一个单元状态(Cell State),用于在时间步之间传递和存储信息。

在实际中,其整体结构如下:

其中蓝色小球里面存放的就是门控结构的神经元 。

2.2 LSTM的计算过程

这里讲的是上图中的蓝色小球内部结构。

在每个时间步,LSTM根据输入数据和前一时刻的隐藏状态,计算三个门的值,并更新单元状态和隐藏状态。具体计算过程如下:

  1. 遗忘门

    其中,ft​ 是遗忘门的输出,σ 是sigmoid激活函数,Wf​ 和 bf​ 是权重和偏置,ht−1​ 是前一时刻的隐藏状态,xt​ 是当前时刻的输入。

  2. 输入门

    其中,it​ 是输入门的输出,C~t​ 是候选单元状态。

  3. 单元状态更新

    其中,Ct​ 是更新后的单元状态。

  4. 输出门

    其中,ot​ 是输出门的输出,ht​ 是当前时刻的隐藏状态。

为了更方便理解,其结构图如下:

从左往右依次为遗忘门,输入门,输出门。 

3. LSTM实例:预测序列的未来

3.1 数据准备

我们以一个简单的时间序列预测任务为例,预测未来某个时间点的值。首先生成一些模拟数据:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, Dense

# 生成模拟数据
def generate_data(sequence_length=1000):
    x = np.linspace(0, 50, sequence_length)
    y = np.sin(x) + 0.1 * np.random.randn(sequence_length)
    return y

# 准备训练数据
def prepare_data(data, window_size=50):
    X, y = [], []
    for i in range(len(data) - window_size):
        X.append(data[i:i+window_size])
        y.append(data[i+window_size])
    return np.array(X), np.array(y)

# 生成数据
data = generate_data()
window_size = 50
X, y = prepare_data(data, window_size)

# 划分训练集和测试集
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

# 数据形状调整
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))

3.2 模型构建

使用Keras构建一个简单的LSTM模型:

# 构建LSTM模型
model = Sequential()
model.add(LSTM(50, return_sequences=True, input_shape=(window_size, 1)))
model.add(LSTM(50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')

# 打印模型结构
model.summary()
  • window_size=50:每个样本包含 50 个时间步。

  • 输入数据的形状为 (样本数, 50, 1)

  • 第一个 LSTM 层的输出形状为 (样本数, 50, 50),因为:

    • 每个时间步输出 50 个特征(由 50 个神经元生成)。

    • return_sequences=True,所以输出了所有 50 个时间步的特征。

如果后续还有一个 LSTM 层,则第二个 LSTM 层的输入形状为 (50, 50)(每个时间步有 50 个特征)。

3.3 模型训练

训练模型并记录训练过程:

# 训练模型
history = model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2)

# 绘制训练损失和验证损失
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

3.4 模型预测

使用训练好的模型进行预测,并可视化结果:

# 预测
predictions = model.predict(X_test)

# 绘制真实值和预测值
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='True Values')
plt.plot(predictions, label='Predictions')
plt.title('True Values vs. Predictions')
plt.xlabel('Time Steps')
plt.ylabel('Values')
plt.legend()
plt.show()

3.5 完整程序预测序列的未来 

 完整程序如下方便调试:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, Dense
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# 生成模拟数据
def generate_data(sequence_length=1000):
    x = np.linspace(0, 50, sequence_length)
    y = np.sin(x) + 0.1 * np.random.randn(sequence_length)
    return y

# 准备训练数据
def prepare_data(data, window_size=50):
    X, y = [], []
    for i in range(len(data) - window_size):
        X.append(data[i:i+window_size])
        y.append(data[i+window_size])
    return np.array(X), np.array(y)

# 生成数据
data = generate_data()
window_size = 50
X, y = prepare_data(data, window_size)

# 划分训练集和测试集
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

# 数据形状调整
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))

# 构建LSTM模型
model = Sequential()
model.add(LSTM(60, return_sequences=True, input_shape=(window_size, 1)))
model.add(LSTM(50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')

# 打印模型结构
model.summary()

# 训练模型
history = model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2)

# 绘制训练损失和验证损失
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 预测
predictions = model.predict(X_test)

# 绘制真实值和预测值
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='True Values')
plt.plot(predictions, label='Predictions')
plt.title('True Values vs. Predictions')
plt.xlabel('Time Steps')
plt.ylabel('Values')
plt.legend()
plt.show()

4. 总结

长短期记忆神经网络(LSTM)是一种强大的序列建模工具,能够有效地捕获长期依赖关系。通过引入遗忘门、输入门和输出门,LSTM解决了传统RNN在处理长序列时的梯度消失问题。在本文中,我们详细介绍了LSTM的基本原理和结构,并通过一个时间序列预测的实例展示了如何使用Keras实现LSTM模型。

尽管LSTM在许多任务中表现出色,但它也有一些局限性,例如计算复杂度较高、训练时间较长等。随着深度学习技术的发展,许多改进的变体(如GRU、双向LSTM等)也逐渐被提出。在实际应用中,选择合适的模型需要根据具体任务和数据特点进行权衡。我是橙色小博,关注我,一起在人工智能领域学习进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2327335.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++多继承

可以用多个基类来派生一个类。 格式为: class 类名:类名1,…, 类名n { private: … ; //私有成员说明; public: … ; //公有成员说明; protected: … ; //保护的成员说明; }; class D: public A, protected B, private C { …//派…

【深度学习新浪潮】DeepSeek近期的技术进展及未来动向

一、近期技术进展 模型迭代与性能提升 DeepSeek-V3-0324版本更新:2025年3月24日发布,作为V3的小版本升级,参数规模达6850亿,采用混合专家(MoE)架构,激活参数370亿。其代码能力接近Claude 3.7,数学推理能力显著提升,且在开源社区(如Hugging Face)上线。DeepSeek-R1模…

工业4.0时代下的人工智能新发展

摘要:随着德国工业4.0时代以及中国制造2025的提出,工业智能化的改革的时代正逐渐到来,然而我国整体工业水平仍然处于工业2.0水平。围绕工业4.0中智能工厂、智能生产、智能物流这三大主题,结合国内外研究现状,对人工智能…

监控易一体化运维:高性能与易扩展,赋能运维新高度

在当今数字化时代,云技术、大数据、智慧城市等前沿科技蓬勃发展,企业和城市对 IT 基础设施的依赖程度与日俱增。在这样的大环境下,运维系统的高性能与易扩展性对于保障业务稳定运行和推动发展的关键意义。今天,为大家深入剖析监控…

机器学习stats_linregress

import numpy as np from scipy import stats# r stats.linregress(xs, ys) 是一个用于执行简单线性回归的函数,通常来自 scipy.stats 库。# 具体含义如下:# stats.linregress:执行线性回归分析,拟合一条最佳直线来描述两个变量 …

Linux系统01---指令

目录 学习的方法 Linux 系统介绍 2.1 Unix 操作系统(了解) 2.2 Linux 操作系统(了解) 2.3 Linux 操作系统的主要特性(重点) 2.4 Linux 与 Unix 的区别与联系 2.5 GUN 与 GPL(了解&#…

【蓝桥杯14天冲刺课题单】Day 8

1.题目链接:19714 数字诗意 这道题是一道数学题。 先考虑奇数,已知奇数都可以表示为两个相邻的数字之和,2k1k(k1) ,那么所有的奇数都不会被计入。 那么就需要考虑偶数什么情况需要被统计。根据打表,其实可以发现除了…

DeepSeek 开源的 3FS 如何?

DeepSeek 3FS(Fire-Flyer File System)是一款由深度求索(DeepSeek)于2025年2月28日开源的高性能并行文件系统,专为人工智能训练和推理任务设计。以下从多个维度详细解析其核心特性、技术架构、应用场景及行业影响&…

通过 Docker Swarm 集群探究 Overlay 网络跨主机通信原理

什么是Overlay网络, 用于解决什么问题 ? Overlay网络通过在现有网络之上创建一个虚拟网络层, 解决不同主机的容器之间相互通信的问题 如果没有Overlay网络,实现跨主机的容器通信通常需要以下方法: 端口映射使用宿主机网络模式 这些方法牺牲了容器网络…

HarmonyOS NEXT开发进阶(十四):HarmonyOS应用开发者基础认证试题集汇总及答案解析

文章目录 一、前言二、判断题(134道)三、单选题(210道)四、多选题(123道)五、拓展阅读 一、前言 鸿蒙原生技能学习阶段,通过官方认证的资格十分有必要,在项目实战前掌握基础开发理论…

MSVC编译遇到C2059、C2143、C2059、C2365、C2059等错误的解决方案

MSVC编译时,遇到如下错误: c:\program files (x86)\windows kits\10\include\10.0.18362.0\um\msxml.h(1842): error C2059: 语法错误:“常数” [D:\jenkins_home\workspace\xxx.vcxproj] c:\program files (x86)\windows kits\10\include\10.0.18362.0…

AI重塑云基础设施,亚马逊云科技打造AI定制版IaaS“样板房”

AI正在彻底重塑云基础设施。 IDC最新《2025年IDC MarketScape:全球公有云基础设施即服务(IaaS)报告》显示,AI正在通过多种方式重塑云基础设施,公有云IaaS有望继续保持快速增长,预计2025年全球IaaS的整体规…

Linux系统之systemctl管理服务及编译安装配置文件安装实现systemctl管理服务

目录 一.systemctl 管理服务 1.systemctl管理 2.设置服务卡机自启动或开机不启动 二.编译安装配置文件编写使得可以使用systemctl管理 1、编写配置文件原因 2、添加配置文件实现systemctl管理服务 一.systemctl 管理服务 1.systemctl管理 基本格式: systemc…

【NLP 52、多模态相关知识】

生活应该是美好而温柔的,你也是 —— 25.4.1 一、模态 modalities 常见: 文本、图像、音频、视频、表格数据等 罕见: 3D模型、图数据、气味、神经信号等 二、多模态 1、Input and output are of different modalities (eg: tex…

Element Plus 常用组件

2025/4/1 向全栈工程师迈进!!! 常见Element Plus组件的使用,其文章中“本次我使用到的按钮如下”是我自己做项目时候用到的,记录以加强记忆。阅读时可以跳过。 一、Button按钮 1.1基础按钮 在element plus中提供的按…

2025年优化算法:真菌生长优化算法(Fungal Growth Optimizer,FGO)

真菌生长优化算法(Fungal Growth Optimizer,FGO) 是发表在中科院一区期刊“ARTIFICIAL INTELLIGENCE REVIEW”(IF:6.7)的2025年3月智能优化算法 01.引言 Fungal Growth Optimizer (FGO) 是一种基于真菌生长行为的元启发式优化算法…

阿里通义千问发布全模态开源大模型Qwen2.5-Omni-7B

Qwen2.5-Omni 是一个端到端的多模态模型,旨在感知多种模态,包括文本、图像、音频和视频,同时以流式方式生成文本和自然语音响应。汇聚各领域最先进的机器学习模型,提供模型探索体验、推理、训练、部署和应用的一站式服务。https:/…

论文阅读:基于增强通用深度图像水印的混合篡改定位技术 OmniGuard

一、论文信息 论文名称:OmniGuard: Hybrid Manipulation Localization via Augmented Versatile Deep Image Watermarking作者团队:北京大学发表会议:CVPR2025论文链接:https://arxiv.org/pdf/2412.01615二、动机与贡献 动机: 随着生成式 AI 的快速发展,其在图像编辑领…

深挖 DeepSeek 隐藏玩法·智能炼金术2.0版本

前引:屏幕前的你还在AI智能搜索框这样搜索吗?“这道题怎么写”“苹果为什么红”“怎么不被发现翘课” ,。看到此篇文章的小伙伴们!请准备好你的思维魔杖,开启【霍格沃茨模式】,看我如何更新秘密的【知识炼金…

【新手初学】SQL注入getshell

一、引入 木马介绍: 木马其实就是一段程序,这个程序运行到目标主机上时,主要可以对目标进行远程控制、盗取信息等功能,一般不会破坏目标主机,当然,这也看黑客是否想要搞破坏。 木马类型: 按照功…