RNN lstm

news2025/8/10 7:51:47

文章目录

  • 什么是RNN
    • RNN工作原理图解
    • 多种RNN形态
    • RNN的公式原理
  • pytorch RNN 样例
  • RNN实践
  • lstm 案例
    • 踩坑 module ‘torchtext.data‘ has no attribute ‘Field
    • 踩坑 en_core_web_sm
    • 相关教程

什么是RNN

阅读ytb视频莫烦: 什么是循环神经网络 RNN (深度学习)? What is Recurrent Neural Networks (deep learning)?。

RNN工作原理图解

RNN是怎样工作的?假如在t时刻,神经网络输入x(t),神经网络会计算状态s(t),并输出y(t)。

到t+1时刻,输入为x(t+1),神经网络会根据s(t)和s(t+1)来输出y(t+1)。

多种RNN形态

RNN经过适当组合,有不同的输入和输出形式,从而能解决不同领域的问题。比如输入一张图片,输出描述它的一段话。


或者输入一段中文,输出一段英文。

在这里插入图片描述

RNN的公式原理

传统RNN的实现主要是下图中的红框部分。


用公式表达如下:


其中 o t o_t ot并不是最重要的部分,而输出 s 1 , s 2 , . . . , s t s_1, s_2, ..., s_t s1,s2,...,st是关键。

pytorch RNN 样例

根据pytorch官方文档,torch.nn.RNN.html可知,RNN计算隐藏层的方式如下,相当于分别对上个隐藏层输出 h t − 1 h_{t-1} ht1 x t x_t xt作线性转换,相加后经过激活层tanhrelu

我们结合代码案例,对这个API做简化版的解释

import torch
from torch import nn

rnn = nn.RNN(10, 20)
input = torch.randn(5, 3, 10)
h0 = torch.randn(1, 3, 20)
output, hn = rnn(input, h0)

# torch.Size([5, 3, 20])
# torch.Size([1, 3, 20])
# tensor(True)
print(output.shape)
print(hn.shape)
print(torch.all(output[-1] == hn))

构造函数的参数简化版解释如下:

  • input_size – The number of expected features in the input x
  • hidden_size – The number of features in the hidden state h

所以 nn.RNN(10, 20):的意思是,输入的每个单词长度为10,输出的每个向量长度为20。另外,batch_first 参数默认为False,它会影响输入的维度顺序,当为False时,输入维度是(seq, batch, feature),为True时是(batch, seq, feature)。

输入的参数简化版解释如下:

输入: input, h_0

  • input: batch_first 默认为False时,维度为(seq, batch, feature)
  • h_0: 在本例默认其它参数情况下,维度为(1, batch, feature)

所以,代码块中的inputh0变量分别代表各个时刻t的输入,以及初始的隐藏层状态。

输出的参数简化版解释如下:

输出: output, h_n

  • output: 在其它参数默认时,维度为(sequence, batch, H o u t H_{out} Hout)。它代表每个时刻t的隐藏层输出 h 1 , h 2 , . . . , h T h_1, h_2, ..., h_T h1,h2,...,hT
  • h_n: 在其它参数默认时,维度为(1, batch, H o u t H_{out} Hout),它代表最后时刻T的隐藏层输出 h t h_t ht

所以,代码块中output的维度是[5,3,20],其中batch是3,序列长度为5(有5个单词)。 而hn的维度是[1,3,20],每个batch都取了 h T h_T hT。同时,print(torch.all(output[-1] == hn))输出为True说明hn就是output[-1],hn是最后时刻T的隐藏层输出。

总结而言,我们将一个batch为3,每句话有5个单词,每个单词向量长度为10的tensor输入到rnn。它将输出batch为3,每句话有5个单词,每个单词向量长度为20的变量output,其中hn是和output[-1]等价。在下图中标注了每个变量对应图中的部分。

RNN实践

pytorch教程SEQUENCE MODELS AND LONG SHORT-TERM MEMORY NETWORKS 的模型代码很清晰,架构完整,但是缺乏训练数据集,training_data变量的数据很匮乏。

https://towardsdatascience.com/lstm-text-classification-using-pytorch-2c6c657f8fc0
训练集在 https://www.kaggle.com/datasets/nopdev/real-and-fake-news-dataset

自定义Dataset的用法,可学习知识点pytorch dataset

lstm 案例

参考swarnabha/pytorch-text-classification-torchtext-lstm,讲了一个用LSTM训练kaggle数据集的案例。

第一步,使用scikit-learn的工具方法切分pandas dataframe形式的数据集

# split data into train and validation
train_df, valid_df = train_test_split(train)
print(train_df.head())
print(valid_df.head())

第二步,设置tokenize策略

TEXT = data.Field(tokenize = 'spacy', include_lengths = True)
LABEL = data.LabelField(dtype = torch.float)

利用Field,将Pandas dataframe包装成torchtext dataset

fields = [('text',TEXT), ('label',LABEL)]
train_ds, val_ds = DataFrameDataset.splits(fields, train_df=train_df, val_df=valid_df)

第三步,构建词库,对单词作one-hot编码。

TEXT.build_vocab(train_ds,
                 max_size = MAX_VOCAB_SIZE,
                 vectors = 'glove.6B.200d',
                 unk_init = torch.Tensor.zero_)

LABEL.build_vocab(train_ds)

第四步,切分数据集

train_iterator, valid_iterator = data.BucketIterator.splits(
    (train_ds, val_ds),
    batch_size = BATCH_SIZE,
    sort_within_batch = True,
    device = device)

最后,循环读取批次,将embedding送入lstm网络。

for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_iterator)

踩坑 module ‘torchtext.data‘ has no attribute ‘Field

由于版本兼容性问题,运行代码可能遇到错误AttributeError: module ‘torchtext.data‘ has no attribute ‘Field‘,也可以参考attributeerror-module-torchtext-data-has-no-attribute-field。使用torchtext 0.10(可能会安装旧版的pytorch,所以用conda开个新环境,凑合着用吧),然后from torchtext import data改成from torchtext.legacy import data

阅读torchtext的版本更新与api变迁可以得知APi变迁。

  • 在0.8版本以前,是from torchtext import data
  • 在0.9到0.12版本之间,是from torchtext.legacy import data
  • 在0.12版本之后,该data库已经删除,如果坚持要用,需要参考新版API教程。

总之,要复用教程的API,最好用torchtext 0.9或0.10。pytorch-sentiment-analysis列出了一些基于该版本API的教程,可以参考它的第一个教程运行下。

踩坑 en_core_web_sm

如果运行遇到以下问题,说明需要下载en_core_web_sm并安装。

Can’t find model ‘en_core_web_sm’. It doesn’t seem to be a Python package or a valid path to a data directory.

参考NLP Spacy中en_core_web_sm安装问题,及最新版下载地址,到github的release界面搜索"en_core_web_sm",找最新版的压缩包下载并用pip install <path to .tar>安装。

笔者下载的是3.4.1版本。如果下载2.5.2,可能会出现cannot read config之类的错误。笔者不知道怎么解决。

相关教程

bentrevett/pytorch-sentiment-analysis的教程1提到了Field是如何帮助构建vocab的,以及如何对句子的单词作清洗工作(为了减少要训练的embedding的数,保留最高频率的单词)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/33903.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot 引入 smart-doc 接口文档管理插件,以及统一接口返回

最近在将多个服务端项目的接口进行整合管理&#xff0c;原本使用的是Swagger接口文档管理插件&#xff0c;网上搜了一下类似的插件&#xff0c;发现这个smart-doc插件&#xff0c;似乎挺简约优雅的&#xff0c;而且还可以推送接口文档到Torna&#xff0c;进行统一管理&#xff…

2023-2028年中国硅碳负极材料行业市场预测与投资规划分析报告

本报告由锐观咨询重磅推出&#xff0c;对中国硅碳负极材料行业的发展现状、竞争格局及市场供需形势进行了具体分析&#xff0c;并从行业的政策环境、经济环境、社会环境及技术环境等方面分析行业面临的机遇及挑战。还重点分析了重点企业的经营现状及发展格局&#xff0c;并对未…

kafka学习(七):消息队列与JMS

1、消息队列 我们可以把消息队列比作是一个存放消息的容器&#xff0c;当我们需要使用消息的时候可以取出消息供自己使用。 1.1、消息队列有什么用&#xff1f; 消息队列是分布式系统中重要的组件&#xff0c;使用消息队列主要是为了通过异步处理提高系统性能和削峰、降低系统…

MCE | 神经元为胰腺癌细胞提供营养

胰腺导管腺癌 (PDAC) &#xff0c;最常见的胰腺癌 (Pancreatic cancer) 类型 &#xff0c;是最致命的实体肿瘤之一&#xff0c;具有很高的侵袭性。PDAC 治疗的不良预后与其独特而复杂的微环境和代谢可塑性有关。PDAC 的肿瘤微环境 (TME) 主要成分是细胞外基质 (ECM)、脉管系统、…

tensorflow2 MobileNet

简介 深度学习的发展伴随着模型参数的暴涨&#xff0c;导致对运行模型的设备有很大的限制&#xff0c;普通的卷积神经网络模型难以运用到移动或嵌入式设备中&#xff0c;主要是这些设备的内存有限&#xff0c;其次这些设备的算力不能满足足够的响应速度&#xff0c;即实时性差…

[附源码]java毕业设计疫情期间回乡人员管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Swin Transformer目标检测实验——环境配置的步骤和避坑

Swin Transformer1. 网上基础教程&#xff08;带视频讲解&#xff09;2. 配置虚拟环境时遇到的一些问题&#xff08;按操作顺序排列&#xff09;1. 网上基础教程&#xff08;带视频讲解&#xff09; 大家是不是都从b站来的呀&#xff0c;先给你们基础环境的配置和搭配的视频教…

【SQLite】三、SQLite 的常用语法

作者主页&#xff1a;Designer 小郑 作者简介&#xff1a;软件工程师一枚&#xff0c;来自浙江宁波&#xff0c;负责开发管理公司OA项目&#xff0c;专注软件前后端开发&#xff08;Vue、SpringBoot和微信小程序&#xff09;、系统定制、远程技术指导。CSDN学院、蓝桥云课认证讲…

[论文阅读笔记18] DiffusionDet论文笔记与代码解读

扩散模型近期在图像生成领域很火, 没想到很快就被用在了检测上. 打算对这篇论文做一个笔记. 论文地址: 论文 代码: 代码 0. 扩散模型简述 首先介绍什么是扩散模型. 我们考虑生成任务, 即encoder-decoder形式的模型, encoder提取输入的抽象信息, 并尝试在decoder中恢复出来. 扩…

亚马逊鲲鹏系统:批量注册亚马逊买家号软件

之前我们有谈到过&#xff0c;想要注册亚马逊买家号&#xff0c;需要邮箱、ip、信用卡、收货地址和手机号。自己手动注册一个一个的太麻烦&#xff0c;还会花费大量的时间&#xff0c;那么有没有可以节约时间的自动化操作软件呢&#xff1f;想要自动化操作软件&#xff0c;来试…

金属带宽度测量方案

一、硬件部分 1.相机 像素&#xff1a;4864*3232 相机选择 1600 万像素即 4864*3232&#xff0c;即检测视场长宽比为 3&#xff1a;2 工件最大的直径为 320mm&#xff0c;假设检测的视场范围为 510*340 因 此 每 个 像 素 大 小 为 340mm/32800.104mm &#xff0c; 即 检 测 精…

数组形式的整数加法

1 问题 整数的 数组形式 num 是按照从左到右的顺序表示其数字的数组。 例如&#xff0c;对于 num 1321 &#xff0c;数组形式是 [1,3,2,1] 。 给定 num &#xff0c;整数的 数组形式 &#xff0c;和整数 k &#xff0c;返回 整数 num k 的 数组形式 。 2 方法 根据问题的描述…

ROS中使用protoBuf通信

ROS自身话题也挺好的&#xff0c;不过暂时还不知道如何判断网络&#xff0c;因此&#xff0c;还是想换回tcp/udp通信。 但是发现通信时数据比较多&#xff0c;调查一下&#xff0c;发现ROS支持google的protoBuf。 先建立一个ROS的项目&#xff0c;方便后面我们测试。 然后建立…

疫情物资储藏库建设规划问题,使用matlab+cplex+yalmib求解

疫情物资储藏库建设规划问题&#xff0c;使用matlabcplexyalmib求解一、Cplex安装及配置二、yalmib安装及配置三、案例分析一、Cplex安装及配置 一、安装Cplex Cplex 是一款商业化的规划问题求解器&#xff0c;支持python和matlab等常用语言&#xff0c;功能非常强大。可以根据…

19 0A-检索服务器支持的所有DTC的状态

诊断协议那些事儿 诊断协议那些事儿专栏系列文章&#xff0c;19服务作为UDS中子功能最多的服务&#xff0c;一共有28种子功能&#xff0c;本文将介绍常用的19 0A服务&#xff1a;检索服务器支持的所有DTC的状态。此子功能不论DTC是否发生、状态如何&#xff0c;都让ECU返回所有…

1533_AURIX_TriCore内核架构_指令集信息

全部学习汇总&#xff1a; GreyZhang/g_tricore_architecture: some learning note about tricore architecture. (github.com) 学习的顺序有一点调整&#xff0c;切换到了内核的第二卷。先了解一下指令集的基本信息。第二卷主要就是讲指令集以及扩展的&#xff0c;而且基本上只…

RSA加密算法Python实现

RSA加密算法Python实现1.RSA算法简介2.RSA算法涉及的数学知识2.1互素2.2 欧拉定理2.3求模逆元2.4 取模运算2.5 最大公因数2.6 最小公倍数2.7 欧几里得算法2.8 扩展欧几里得算法3.RSA算法数学实现3.1理论3.2实践4.RSA算法代码实现4.1RSA算法代码实现14.1RSA算法代码实现21.RSA算…

STP、RSTP、MSTP

STP、RSTP、MSTP的配置 本篇介绍STP、RSTP、MSTP的配置和常用的管理命令。 STP/RSTP/MSTP简介 以太网中为了进行链路备份&#xff0c;提高网络可靠性&#xff0c;通常会使用冗余链路&#xff0c;但是这也带来了网络环路的问题。网络环路会引发广播风暴和MAC地址表振荡等问题…

连续仨月霸占牛客榜首,京东 T8 呕心沥血神作:700 页 JVM 虚拟机实战手册

虚拟机是一种抽象化的计算机&#xff0c;通过在实际的计算机上仿真模拟各种计算机功能来实现的。Java 虚拟机有自己完善的硬体架构&#xff0c;如处理器、堆栈、寄存器等&#xff0c;还具有相应的指令系统。JVM 屏蔽了与具体操作系统平台相关的信息&#xff0c;使得 Java 程序只…

基于结构应力方法的焊接结构疲劳评估及实例分析(下篇)

作者 | 裴宪军 &#xff0c;仿真秀专栏作者 一、写在文前 焊接技术作为现代制造业中的支柱技术之一&#xff0c;是制造强国的关键保障。由于其整体性强、轻量化、经济性好等优点&#xff0c;焊接结构被广泛应用于轨道交通、航空航天&#xff0c;船舶、重型装备等领域&#xf…