Python----循环神经网络(BiLSTM:双向长短时记忆网络)

news2025/6/9 17:45:16

一、LSTM 与 BiLSTM对比

1.1、LSTM

        LSTM(长短期记忆网络) 是一种改进的循环神经网络(RNN),专门解决传统RNN难以学习长期依赖的问题。它通过遗忘门、输入门和输出门来控制信息的流动,保留重要信息并丢弃无关内容,从而有效处理长序列数据。LSTM的核心是细胞状态,它像一条传送带,允许信息在不同时间步之间稳定传递,避免梯度消失或爆炸,适用于时间序列预测、语音识别等任务。

1.2、BiLSTM 

        BiLSTM(双向长短期记忆网络) 在LSTM的基础上增加反向处理层,同时捕捉过去和未来的上下文信息。前向LSTM按时间顺序处理序列,后向LSTM逆序处理,最终结合两个方向的输出,增强模型对全局上下文的理解。BiLSTM在自然语言处理任务(如机器翻译、命名实体识别)中表现优异,但计算成本更高。它特别适合需要双向信息交互的场景,如语义理解、情感分析等。 

        BiLSTM结构包含两个方向的LSTM网络:一个正向(forward)LSTM和一 个反向(backward)LSTM。

        这两个方向的LSTM在模型训练过程中分别处理输入序列,最后的隐藏状态 由这两个方向的LSTM拼接而成。这样的结构使得模型能够同时考虑到输入 序列中每个位置的过去和未来信息,更全面地捕捉序列中的上下文信息。

        如下面这个情感分类的例子,正向的LSTM按照从左到右的顺序处理“我”、 “爱”、“你”,反向的LSTM按照从右到左的顺序处理“你”、“爱”、“我”,然后 将两个LSTM的最后一个隐藏层拼接起来再经过softmax等处理得到分类结果。

        举一个例子,如一句话“我今天很开心,因为我考试考了 100 分”要做情感 分类,LSTM只能从左到右的看,因此在看到“很开心”这个关键词时它获得 的只有上文的信息,而BiLSTM是双向的因此也能看到“因为我考试考了 100 分”这一部分,而这一部分对应最终结果是否准确有很大的帮助。 

特征LSTMBiLSTM
方向性单向(仅过去信息)双向(过去和未来信息)
计算复杂度较低较高(约2倍)
典型应用时间序列预测、语言模型文本分类、序列标注、机器翻译
内存需求较少较多

13、优势 

BiLSTM相对于单向LSTM具有以下优势:

        能够捕捉到输入序列中每个位置的过去和未来信息,更全面地捕捉序列 中的上下文信息。

        可以更好地处理长距离的依赖关系。

        在许多自然语言处理任务中都取得了良好的效果。

二、库函数-LSTM

torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)

LSTM — PyTorch 2.7 documentation

参数描述
input_size输入 x 中预期特征的数量
hidden_size处于隐藏状态 h 的特征数量
num_layers循环层数。例如,设置意味着将两个 LSTM 堆叠在一起以形成一个堆叠的 LSTM。 第二个 LSTM 接收第一个 LSTM 的输出,并且 计算最终结果。默认值:1num_layers=2
bias偏置如果 ,则层不使用 b_ih 和 b_hh 的偏差权重。 违约:FalseTrue
batch_first 如果 ,则提供输入和输出张量 作为 (batch, seq, feature) 而不是 (seq, batch, feature)。 请注意,这不适用于隐藏状态或单元格状态。请参阅 Inputs/Outputs 部分了解详细信息。违约:TrueFalse
dropout 如果为非零,则在每个 除最后一层外的 LSTM 层,其 dropout 概率等于 。默认值:0dropout
bidirectional

如果 ,则变为双向 LSTM。违约:TrueFalse

 

  • weight_ih_l[k]_reverse – 类似于 weight_ih_l[k] 的相反方向。 仅当 时存在。bidirectional=True

  • weight_hh_l[k]_reverse – 类似于 weight_hh_l[k] 的相反方向。 仅当 时存在。bidirectional=True

  • bias_ih_l[k]_reverse – 类似于 bias_ih_l[k] 的相反方向。 仅当 时存在。bidirectional=True

  • bias_hh_l[k]_reverse – 类似于 bias_hh_l[k] 的相反方向。 仅当 时存在。bidirectional=True

proj_size

如果 ,将使用 LSTM 和相应大小的投影。默认值:0> 0

import torch
import numpy as np
from torch import nn
 
# 1.字符输入
text = "In Beijing Sarah bought a basket of apples In Guangzhou Sarah bought a basket of bananas"
 
torch.manual_seed(1)
 
# 3.数据集划分
input_seq = [text[:-1]]
output_seq = [text[1:]]
print("input_seq:", input_seq)
# print("output_seq:", output_seq)
 
# 4.数据编码:one-hot
chars = set(text)
chars = sorted(chars)
# print("chars:", chars)
# {" ":0, "a":1 }
char2int = {char: ind for ind, char in enumerate(chars)}
# print("char2int:", char2int)
# {0:" ", 1: "a"}
int2char = dict(enumerate(chars))
 
# 将字符转成数字编码
input_seq = [[char2int[char] for char in seq] for seq in input_seq]
# print("input_seq:", input_seq)
output_seq = [[char2int[char] for char in seq] for seq in output_seq]
 
# one-hot 编码,pytorch的RNN的输入张量的填充
def one_hot_encode(seq, bs, seq_len, size):
    features = np.zeros((bs, seq_len, size), dtype=np.float32)
    for i in range(bs):
        for u in range(seq_len):
            features[i, u, seq[i][u]] = 1.0
    return torch.tensor(features, dtype=torch.float32)
 
input_seq = one_hot_encode(input_seq, 1, len(text)-1, len(chars))
output_seq = torch.tensor(output_seq, dtype=torch.long).view(-1)
print("output_seq:", output_seq)
 
# 5.定义前向模型
class Model(nn.Module):
    def __init__(self, input_size, hidden_size, out_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        self.bilstm1 = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(hidden_size * 2, out_size)
 
    def forward(self, x):
        out, hidden = self.bilstm1(x)
        x = out.contiguous().view(-1, self.hidden_size * 2)
        x = self.fc1(x)
        return x, hidden
 
model = Model(len(chars), 32, len(chars))
 
# 6.定义损失函数和优化器
cri = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
# 7.开始迭代
epochs = 1000
for epoch in range(1, epochs+1):
    output, hidden = model(input_seq)
    loss = cri(output, output_seq)
 
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # 8.显示频率设置
    if epoch == 0 or epoch % 50 == 0:
        print(f"Epoch [{epoch}/{epochs}], Loss {loss:.4f}")
 
# print("input_seq.shape:", input_seq.shape)
# print("hidden.shape:", hidden.shape)
# print("output.shape:", output.shape)
# print("input_w:", model.rnn1.weight_ih_l0.shape)
 
# 预测下面几个字符
input_text = "In Beijing Sarah bought a basket of"  # re
to_be_pre_len = 20
 
for i in range(to_be_pre_len):
    chars = [char for char in input_text]
    # print(chars)
    character = np.array([[char2int[c] for c in chars]])
    character = one_hot_encode(character, 1, character.shape[1], 23)
    character = torch.tensor(character, dtype=torch.float32)
 
    out, hidden = model(character)
    char_index = torch.argmax(out[-1]).item()
    input_text += int2char[char_index]
print("预测到的:", input_text)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2405652.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux系统编程-DAY10(TCP操作)

一、网络模型 1、服务器/客户端模型 (1)C/S:client server (2)B/S:browser server (3)P2P:peer to peer 2、C/S与B/S区别 (1)客户端不同&#…

基于eclipse进行Birt报表开发

Birt报表开发最终实现效果: 简洁版的Birt报表开发实现效果,仅供参考! 可动态获取采购单ID,来打印出报表! 下面开始Birt报表开发教程: 首先:汉化的eclipse及Birt值得拥有:至少感觉上…

GPU虚拟化

引言 现有如下环境(注意相关配置:只有一个k8s节点,且该节点上只有一张GPU卡): // k8s版本 $ kubectl version Client Version: version.Info{Major:"1", Minor:"22", GitVersion:"v1.22.7&…

LabVIEW工业级多任务实时测控系统

采用LabVIEW构建了一套适用于工业自动化领域的多任务实时测控系统。系统采用分布式架构,集成高精度数据采集、实时控制、网络通信及远程监控等功能,通过硬件与软件的深度协同,实现对工业现场多类型信号的精准测控,展现 LabVIEW 在…

破解HTTP无状态:基于Java的Session与Cookie协同工作指南

HTTP协议自身是属于“无状态”协议 无状态是指:默认情况下,HTTP协议的客户端和服务器之间的这次通信,和下次通信之间没有直接的关系 但在实际开发中,我们很多时候是需要知道请求之间的关联关系的 上述图中的令牌,通常就…

JS 事件流机制详解:冒泡、捕获与完整事件流

JS 事件流机制详解:冒泡、捕获与完整事件流 文章目录 JS 事件流机制详解:冒泡、捕获与完整事件流一、DOM 事件流基本概念二、事件捕获 (Event Capturing)特点代码示例 三、事件冒泡 (Event Bubbling)特点代码示例 四、完整事件流示例HTML 结构JavaScript…

算法专题七:分治

快排 1.颜色分类 题目链接:75. 颜色分类 - 力扣(LeetCode) class Solution {public void swap(int[] nums, int i, int j){int t = nums[i];nums[i] = nums[j];nums[j] = t;}public void sortColors(int[] nums) {int left=-1 ,i=0 ,right=nums.length;while(i<right){i…

Vue中虚拟DOM的原理与作用

绪论 首先我们先了解&#xff0c;DOM&#xff08;Document Object Model&#xff0c;文档对象模型&#xff09; 是浏览器对 HTML/XML 文档的结构化表示&#xff0c;它将文档解析为一个由节点&#xff08;Node&#xff09;和对象组成的树形结构&#xff08;称为 DOM 树&#xf…

使用vue3+ts+input封装上传组件,上传文件显示文件图标

效果图&#xff1a; 代码 <template><div class"custom-file-upload"><div class"upload"><!-- 显示已选择的文件 --><div class"file-list"><div v-for"(item, index) in state.filsList" :key&q…

【Linux】Ubuntu 创建应用图标的方式汇总,deb/appimage/通用方法

Ubuntu 创建应用图标的方式汇总&#xff0c;deb/appimage/通用方法 对于标准的 Ubuntu&#xff08;使用 GNOME 桌面&#xff09;&#xff0c;desktop 后缀的桌面图标文件主要保存在以下三个路径&#xff1a; 当前用户的桌面目录&#xff08;这是最常见的位置&#xff09;。所…

LangGraph--Agent工作流

Agent的工作流 下面展示了如何创建一个“计划并执行”风格的代理。 这在很大程度上借鉴了 计划和解决 论文以及Baby-AGI项目。 核心思想是先制定一个多步骤计划&#xff0c;然后逐项执行。完成一项特定任务后&#xff0c;您可以重新审视计划并根据需要进行修改。 般的计算图如…

Spring Boot 常用注解面试题深度解析

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;笑口常开&#x1f7ea;生日快乐⬛早点睡觉 &#x1f4d8;博主相关 &#x1f7e7;博主信息&#x1f7e8;博客首页&#x1f7eb;专栏推荐&#x1f7e5;活动信息 文章目录 Spring Boot 常用注解面试题深度解析一、核心…

Linux系统的CentOS7发行版安装MySQL80

文章目录 前言Linux命令行内的”应用商店”安装CentOS的安装软件的yum命令安装MySQL1. 配置yum仓库2. 使用yum安装MySQL3. 安装完成后&#xff0c;启动MySQL并配置开机自启动4. 检查MySQL的运行状态 MySQL的配置1. 获取MySQL的初始密码2. 登录MySQL数据库系统3. 修改root密码4.…

408第一季 - 数据结构 - 栈与队列

栈 闲聊 栈是一个线性表 栈的特点是后进先出 然后是一个公式 比如123要入栈&#xff0c;一共有5种排列组合的出栈 栈的数组实现 这里有两种情况&#xff0c;&#xff0c;一个是有下标为-1的&#xff0c;一个没有 代码不用看&#xff0c;真题不会考 栈的链式存储结构 L ->…

【RTP】Intra-Refresh模式下的 H.264 输出,RTP打包的方式和普通 H.264 流并没有本质区别

对于 Intra-Refresh 模式下的 H.264 输出,RTP 打包 的方式和普通 H.264 流并没有本质区别:你依然是在对一帧一帧的 NAL 单元进行 RTP 分包,只不过这些 NAL 单元内部有部分宏块是 “帧内编码” 而已。下面分步骤说明: 1. 原理回顾:RFC 6184 H.264 over RTP 按照 RFC 6184 …

Redis实战-消息队列篇

前言&#xff1a; 讲讲做消息队列遇到的问题。 今日所学&#xff1a; 异步优化消息队列基于stream实现异步下单 1. 异步优化 1.1 需求分析 1.1.1 现有下单流程&#xff1a; 1.查询优惠劵 2.判断是否是秒杀时间&#xff0c;库存是否充足 3.实现一人一单 在这个功能中&…

(三)Linux性能优化-CPU-CPU 使用率

CPU使用率 user&#xff08;通常缩写为 us&#xff09;&#xff0c;代表用户态 CPU 时间。注意&#xff0c;它不包括下面的 nice 时间&#xff0c;但包括了 guest 时间。nice&#xff08;通常缩写为 ni&#xff09;&#xff0c;代表低优先级用户态 CPU 时间&#xff0c;也就是进…

佰力博科技与您探讨材料介电性能测试的影响因素

1、频率依赖性 材料的介电性能通常具有显著的频率依赖性。在低频下&#xff0c;偶极子的取向极化占主导&#xff0c;介电常数较高&#xff1b;而在高频下&#xff0c;偶极子的取向极化滞后&#xff0c;导致介电常数下降&#xff0c;同时介电损耗增加。例如&#xff0c;VHB4910…

K8S认证|CKS题库+答案| 4. RBAC - RoleBinding

目录 4. RBAC - RoleBinding 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作&#xff1a; 1&#xff09;、切换集群 2&#xff09;、查看SA和role 3&#xff09;、编辑 role-1 权限 4&#xff09;、检查role 5&#xff09;、创建 role和 rolebinding 6&#xff0…

React 新项目

使用git bash 创建一个新项目 建议一开始就创建TS项目 原因在Webpack中改配置麻烦 编译方法:ts compiler 另一种 bable 最好都配置 $ create-react-app cloundmusic --template typescript 早期react项目 yarn 居多 目前npm包管理居多 目前pnpm不通用 icon 在public文件夹中…