长短期记忆(LSTM)网络模型

news2025/6/6 12:03:00

一、概述

  长短期记忆(Long Short-Term Memory,LSTM)网络是一种特殊的循环神经网络(RNN),专门设计用于解决传统 RNN 在处理长序列数据时面临的梯度消失 / 爆炸问题,能够有效捕捉长距离依赖关系。其核心在于引入记忆细胞(Cell State)和门控机制(Gate Mechanism),通过控制信息的流动来实现对长期信息的存储与遗忘。

二、模型原理

  LSTM 由记忆细胞和三个门控单元(遗忘门、输入门、输出门)组成,每个门控单元通过 sigmoid 激活函数输出 0 到 1 之间的数值,表示允许信息通过的程度(0 表示完全禁止,1 表示完全允许)。

1. 记忆细胞状态

  记忆细胞状态就像一条信息传输的 “高速公路”,它贯穿整个 LSTM 网络,负责在不同时间步之间传递信息。信息在记忆细胞状态中传递时,可以相对稳定地保留较长时间,避免了传统 RNN 中信息容易丢失的问题。遗忘门和输入门共同作用于记忆细胞状态,遗忘门决定删除哪些旧信息,输入门决定添加哪些新信息 ,从而实现对记忆细胞状态的更新。

2. 输入门

  输入门负责处理当前时刻的输入信息,决定哪些新的信息会被添加到记忆细胞状态中。它利用 sigmoid 函数输出一个值,用于控制新信息的 “准入程度”。同时,输入内容通过 tanh 函数生成一个候选值向量,这个向量包含了可能要添加到记忆细胞状态中的新信息。最后,将 sigmoid 函数的输出与 tanh 函数生成的候选值向量相乘,得到实际要添加到记忆细胞状态中的信息。

3. 遗忘门

  遗忘门决定了上一时刻记忆细胞状态中哪些信息会被保留到当前时刻。它接收上一时刻的隐藏状态和当前时刻的输入,通过一个 sigmoid 激活函数输出一个 0 到 1 之间的数值。这个数值就像一把 “钥匙”,数值越接近 1,表示上一时刻的该部分信息被保留的程度越高;数值越接近 0,则表示该部分信息被遗忘的程度越高。例如,在处理一段文字序列时,如果之前的内容与当前句子的主题关联不大,遗忘门就会降低这些信息的保留程度。

4. 输出门

  输出门根据当前记忆细胞状态和隐藏状态,决定最终的输出。它首先使用 sigmoid 函数得到一个控制输出的向量。然后,对记忆细胞状态进行 tanh 处理,将处理后的记忆细胞状态与 sigmoid 函数的输出向量相乘,从而得到 LSTM 单元的最终输出。
一个典型LSTM的单元结构为

在这里插入图片描述
  也就是说,对每个LSTM单元,都有四个输入、一个输出,这四个输入也就是对同一组输入数据的线性组合,只是组合了不同参数。具体的计算过程图示为

在这里插入图片描述
  显然,相较于传统的网络结构,LSTM具有四倍的参数量。

三、优势与局限

1. 优势

  LSTM 的门控机制使其在处理长序列数据时,能够有效保留和更新信息,避免梯度消失和梯度爆炸问题,从而学习到长距离的依赖关系,在许多序列数据处理任务中取得了优异的成绩。此外,LSTM 的结构具有较好的通用性,可以适应多种不同类型的序列数据处理任务。

2. 局限

  由于其结构相对复杂,包含多个门和大量参数,训练过程通常需要更多的计算资源和时间,并且容易出现过拟合问题。同时,LSTM 在解释性方面相对较差,难以直观地理解模型是如何做出决策的。

四、应用领域

1. 自然语言处理

  在自然语言处理任务中,LSTM 被广泛应用于文本分类、机器翻译、语音识别、问答系统等。例如,在机器翻译中,LSTM 可以将源语言句子的语义信息编码成固定长度的向量,然后通过解码过程将其转换为目标语言句子;在语音识别中,LSTM 能够处理语音信号中的时间序列信息,将语音转换为文字。

2. 时间序列预测

  LSTM 在时间序列预测领域表现出色,如股票价格预测、天气预测、电力负荷预测等。由于 LSTM 能够有效捕捉时间序列中的长期依赖关系,相比传统方法,它可以更准确地预测未来趋势。例如,在股票价格预测中,LSTM 可以分析历史股价数据中的复杂模式,预测未来股价走势。

3. 其他领域

  此外,LSTM 还在视频分析、生物信息学等领域得到应用。在视频分析中,LSTM 可以处理视频帧序列,实现动作识别、视频内容理解等任务;在生物信息学中,LSTM 可用于基因序列分析,预测基因功能等。

五、Python实现示例

(环境:Python 3.11,PyTorch 2.4.0)

import matplotlib
matplotlib.use('TkAgg')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

# 设置matplotlib的字体
plt.rcParams['font.sans-serif'] = ['SimHei']  # 'SimHei' 是黑体,也可设置 'Microsoft YaHei' 等
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号


# 设置随机种子,保证结果可复现
torch.manual_seed(42)
np.random.seed(42)


# 生成示例数据(正弦波序列)
def generate_sequence(length, freq=0.1):
    """生成正弦波序列作为示例数据"""
    x = np.linspace(0, 2 * np.pi * freq, length)
    return np.sin(x)


def create_sequences(data, seq_length):
    """将数据转换为序列和对应目标值的形式"""
    xs, ys = [], []
    for i in range(len(data) - seq_length):
        x = data[i:i + seq_length]
        y = data[i + seq_length]
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)


# 生成数据
seq_length = 10
data = generate_sequence(1000)
x, y = create_sequences(data, seq_length)

# 转换为PyTorch张量
x_tensor = torch.FloatTensor(x).view(-1, seq_length, 1)
y_tensor = torch.FloatTensor(y).view(-1, 1)

# 创建数据加载器
dataset = TensorDataset(x_tensor, y_tensor)
train_size = int(0.8 * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


# 定义LSTM模型
class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # LSTM层
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # 全连接层
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # 前向传播LSTM
        out, _ = self.lstm(x, (h0, c0))

        # 只取序列中的最后一个时间步的输出
        out = self.fc(out[:, -1, :])
        return out


# 初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMModel(input_size=1, hidden_size=50, num_layers=1, output_size=1).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


# 训练模型
def train_model(model, train_loader, criterion, optimizer, epochs=100):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # 清零梯度
            optimizer.zero_grad()

            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # 反向传播和优化
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}')


# 评估模型
def evaluate_model(model, test_loader):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            predictions.extend(outputs.cpu().numpy())
            actuals.extend(targets.cpu().numpy())
    return np.array(predictions), np.array(actuals)


# 训练模型
train_model(model, train_loader, criterion, optimizer, epochs=50)

# 评估模型
predictions, actuals = evaluate_model(model, test_loader)

# 可视化结果
plt.figure(figsize=(10, 6))
plt.plot(actuals, label='Actual Values')
plt.plot(predictions, label='Predicted Values')
plt.title('LSTM预测结果')
plt.xlabel('样本')
plt.ylabel('值')
plt.legend()
plt.show()


在这里插入图片描述
在这里插入图片描述
  示例实现了一个基本的 LSTM 模型,用于预测正弦波序列的下一个值。主要包括以下几个部分:

  数据生成:创建一个正弦波序列,并将其转换为适合 LSTM 训练的序列格式。
  模型定义:定义了一个包含 LSTM 层和全连接层的模型,用于处理序列数据并输出预测结果。
  训练过程:使用均方误差损失函数和 Adam 优化器训练模型。
  评估和可视化:评估模型性能并可视化预测结果与实际值的对比。

  可以通过修改参数如seq_length(序列长度)、hidden_size(LSTM 隐藏层大小)、num_layers(LSTM 层数)等来调整模型,也可以将此框架应用于其他序列预测任务。



End.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2398689.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CSS3美化页面元素

1. 字体 <span>标签 字体样式⭐ 字体类型&#xff08;font-family&#xff09; 字体大小&#xff08;font-size&#xff09; 字体风格&#xff08;font-style&#xff09; 字体粗细&#xff08;font-weight&#xff09; 字体属性&#xff08;font&#xff09; 2. 文本 文…

WPS 利用 宏 脚本拆分 Excel 多行文本到多行

文章目录 WPS 利用 宏 脚本拆分 Excel 多行文本到多行效果需求背景&#x1f6e0; 操作步骤代码实现代码详解使用场景注意事项总结 WPS 利用 宏 脚本拆分 Excel 多行文本到多行 在 Excel 工作表中&#xff0c;我们经常遇到一列中包含多行文本&#xff08;用换行符分隔&#xff…

AI“实体化”革命:具身智能如何重构体育、工业与未来生活

近年来&#xff0c;人工智能&#xff08;AI&#xff09;技术的飞速发展正在重塑各行各业&#xff0c;而具身智能&#xff08;Embodied AI&#xff09;作为AI领域的重要分支&#xff0c;正逐渐从实验室走向现实应用。具身智能的核心在于让AI系统具备物理实体&#xff0c;能够与环…

R语言基础| 创建数据集

在R语言中&#xff0c;有多种数据类型&#xff0c;用以存储和处理数据。每种数据类型都有其特定的用途和操作函数&#xff0c;使得R语言在处理各种数据分析任务时非常灵活和强大&#xff1a; 向量&#xff08;Vector&#xff09;: 向量是R语言中最基本的数据类型&#xff0c;它…

Centos7搭建zabbix6.0

此方法适用于zabbix6以上版本zabbix6.0前期环境准备&#xff1a;Lamp&#xff08;linux httpd mysql8.0 php&#xff09;mysql官网下载位置&#xff1a;https://dev.mysql.com/downloads/mysql/Zabbix源码包地址&#xff1a;https://www.zabbix.com/cn/download_sourcesZabbix6…

Docker 部署前后端分离项目

1.Docker 1.1 什么是 Docker &#xff1f; Docker 是一种开源的 容器化平台&#xff0c;用于开发、部署和运行应用程序。它通过 容器&#xff08;Container&#xff09; 技术&#xff0c;将应用程序及其依赖项打包在一个轻量级、可移植的环境中&#xff0c;确保应用在不同计算…

云游戏混合架构

云游戏混合架构通过整合本地计算资源与云端能力&#xff0c;形成了灵活且高性能的技术体系&#xff0c;其核心架构及技术特征可概括如下&#xff1a; 一、混合架构的典型模式 分层混合模式‌ 前端应用部署于公有云&#xff08;如渲染流化服务&#xff09;&#xff0c;后端逻辑…

【小红书】API接口,获取笔记核心数据

小红书笔记核心数据API接口详解 - 深圳小于科技提供专业数据服务 深圳小于科技&#xff08;官网&#xff1a;https://www.szlessthan.com&#xff09;推出的小红书笔记核心数据API接口&#xff0c;为开发者提供精准的笔记互动数据分析能力&#xff0c;助力内容运营与商业决策。…

会议室钥匙总丢失?换预约功能的智能门锁更安全

在企业日常运营中&#xff0c;会议室作为重要的沟通与协作场所&#xff0c;其管理效率与安全性直接影响着企业的运作顺畅度。然而&#xff0c;传统会议室管理方式中钥匙丢失、管理不便等问题频发&#xff0c;给企业带来了不少困扰。近期&#xff0c;某企业引入了启辰智慧预约系…

Redis底层数据结构之跳表(SkipList)

SkipList是Redis有序结合ZSet底层的数据结构&#xff0c;也是ZSet的灵魂所在。与之相应的&#xff0c;Redis还有一个无序集合Set&#xff0c;这两个在底层的实现是不一样的。 标准的SkipList&#xff1a; 跳表的本质是一个链表。链表这种结构虽然简单清晰&#xff0c;但是在查…

Ubuntu安装Docker命令清单(以20.04为例)

在你虚拟机上完成Ubuntu的下载后打开终端&#xff01;&#xff01;&#xff01; Ubuntu安装Docker终极命令清单&#xff08;以20.04为例&#xff09; # 1. 卸载旧版本&#xff08;全新系统可跳过&#xff09; sudo apt-get remove docker docker-engine docker.io containerd …

HarmonyOS Next 弹窗系列教程(2)

HarmonyOS Next 弹窗系列教程&#xff08;2&#xff09; 上一章节我们讲了自定义弹出框 (openCustomDialog)&#xff0c;那对于一些简单的业务场景&#xff0c;不一定需要都是自定义&#xff0c;也可以使用 HarmonyOS Next 内置的一些弹窗效果。比如&#xff1a; 名称描述不依…

中小企业搭建网站选择虚拟主机还是云服务器?华为云有话说

这是一个很常见的问题&#xff0c;许多小企业在搭建网站时都会面临这个选择。虚拟主机和云服务器都有各自的优缺点&#xff0c;需要根据自己的需求和预算来决定。 虚拟主机是指将一台物理服务器分割成多个虚拟空间&#xff0c;每个空间都可以运行一个网站。虚拟主机的优点是价格…

使用 HTML + JavaScript 在高德地图上实现物流轨迹跟踪系统

在电商行业蓬勃发展的今天&#xff0c;物流信息查询已成为人们日常生活中的重要需求。本文将详细介绍如何基于高德地图 API 利用 HTML JavaScript 实现物流轨迹跟踪系统的开发。 效果演示 项目概述 本项目主要包含以下核心功能&#xff1a; 地图初始化与展示运单号查询功能…

19-项目部署(Linux)

Linux是一套免费使用和自由传播的操作系统。说到操作系统&#xff0c;大家比较熟知的应该就是Windows和MacOS操作系统&#xff0c;我们今天所学习的Linux也是一款操作系统。 我们作为javaEE开发工程师&#xff0c;将来在企业中开发时会涉及到很多的数据库、中间件等技术&#…

html基础01:前端基础知识学习

html基础01&#xff1a;前端基础知识学习 1.个人建立打造 -- 之前知识的小总结1.1个人简历展示1.2简历信息填写页面 1.个人建立打造 – 之前知识的小总结 1.1个人简历展示 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8&qu…

【RoadRunner】自动驾驶模拟3D场景构建 | 软件简介与视角控制

&#x1f4af; 欢迎光临清流君的博客小天地&#xff0c;这里是我分享技术与心得的温馨角落 &#x1f4af; &#x1f525; 个人主页:【清流君】&#x1f525; &#x1f4da; 系列专栏: 运动控制 | 决策规划 | 机器人数值优化 &#x1f4da; &#x1f31f;始终保持好奇心&…

基于RK3576+FPGA芯片构建的CODESYS软PLC Linux实时系统方案,支持6T AI算力

基于RK3576芯片构建的CODESYS软PLC Linux实时系统方案&#xff0c;结合了异构计算架构与工业实时控制技术&#xff0c;主要特点如下&#xff1a; 一、硬件架构设计 ‌异构多核协同‌ ‌Cortex-A72四核‌&#xff08;2.3GHz&#xff09;&#xff1a;处理运动轨迹规划、AI视觉等…

适配器模式:让不兼容接口协同工作

文章目录 1. 适配器模式概述2. 适配器模式的分类2.1 类适配器2.2 对象适配器 3. 适配器模式的结构4. C#实现适配器模式4.1 对象适配器实现4.2 类适配器实现 5. 适配器模式的实际应用场景5.1 第三方库集成5.2 遗留系统集成5.3 系统重构与升级5.4 跨平台开发 6. 类适配器与对象适…

DDP与FSDP:分布式训练技术全解析

DDP与FSDP:分布式训练技术全解析 DDP(Distributed Data Parallel)和 FSDP(Fully Sharded Data Parallel)均为用于深度学习模型训练的分布式训练技术,二者借助多 GPU 或多节点来提升训练速度。 1. DDP(Distributed Data Parallel) 实现原理 数据并行:把相同的模型复…