nn.LSTM个人记录

news2025/5/23 4:11:51

简介

 

nn.LSTM参数

torch.nn.lstm(input_size,   "输入的嵌入向量维度,例如每个单词用50维向量表示,input_size就是50"
              hidden_size,  "隐藏层节点数量,也是输出的嵌入向量维度"
              num_layers,   "lstm 隐层的层数,默认为1"
              bias,         "隐层是否带 bias,默认为 true"
              batch_first,  "True 或者 False,如果是 True,则 input 为(batchsize, len, input_size),默认值为:False(len, batchsize, input_size)"
              dropout,      "除最后一层,每一层的输出都进行dropout,默认值0"
              bidirectional "如果设置为 True, 则表示双向 LSTM,默认为 False"
              )

维度

batch_first=True,输入维度(batchsize,len,input_size)

batch_first=False,输入维度(len,batchsize, input_size)

batch_first=False,输出维度(len,batchsize,hidden_size)

举例嵌入向量维度为1

 假如输入x为(batchsize,len)的序列,即嵌入向量维度为1,进行一个回归预测。

如果将嵌入向量维度维度设为1就不太合理,因为如果len非常长例如几w,那么经过几w的时间步得到的得到的h维度为(batchsize,1),序列太长丢失很多信息,再输入全连接层预测效果不好。并且lstm实际上将嵌入向量维度从input_size规约到hidden_size。

所以在这里我们将len作为input_size,嵌入向量维度1作为len(即对调了一下)

添加一个维度:

x = x.unsqueeze(0)

x维度变为(1,batchsize,len),相当于设置数据的长度为1,嵌入向量维度为len,通过nn.LSTM输入到网络中。

#lstm为定义的网络
#h[-1]为最后输入到全连接层的嵌入矩阵 但是由于此问题中len为1,所以x等于h[-1]
x, (h, c) = lstm(x)

x维度变为(1,batchsize,hidden_size)

h为每层lstm最后一个时间步的输出一般可以输入到后续的全连接层),维度为(num_layers,batchsize,hidden_size)

c为最后一个时间步 LSTM cell 的状态(记忆单元,一般用不到),维度为(num_layers,batchsize,hidden_size)

移除张量中所有尺寸为 1 的维度,即将第一个维度移除掉:

lstm_out = x.squeeze(0)

x维度变为(batchsize,hidden_size) ,输入到全连接层(线性层,维度(hidden_size,num_class))中,最终输出维度(batchsize,num_class)

参考:

Pytorch — LSTM (nn.LSTM & nn.LSTMCell)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1331800.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

02_算法分析

02_算法分析 0.1 算法的时间复杂度分析0.1.1 函数渐近增长概念:输入规模n>2时,算法A1的渐近增长小于算法B1 的渐近增长随着输入规模的增大,算法的常数操作可以忽略不计测试二:随着输入规模的增大,与最高次项相乘的常…

【计数DP】牛客小白月赛19

登录—专业IT笔试面试备考平台_牛客网 题意 思路 首先做法一定是计数 dp 然后状态设计,先设 dp[i] 然后看影响决策的因素:两边的火焰情况,那就 dp[i][0/1][0/1]表示 前 i 个,该位有无火焰,该位右边有无火焰的方案数…

单片机的RTC获取网络时间

理解网络同步校准RTC的原理需要考虑NTP、SNTP、RTC这三个关键组件的作用和交互。下面详细解释这个过程: 1. NTP(Network Time Protocol): 协议目的:NTP是用于同步计算机和设备时钟的协议。它通过在网络上与时间服务器通…

为什么react call api in cDidMount

为什么react call api in cDM 首先,放到constructor或者cWillMount不是语法错误 参考1 参考2 根据上2个参考,总结为: 1、官网就是这么建议的: 2、17版本后的react 由于fiber的出现导致 cWM 会调用多次! cWM 方法已…

【并发设计模式】聊聊两阶段终止模式如何优雅终止线程

在软件设计中,抽象出了23种设计模式,用以解决对象的创建、组合、使用三种场景。在并发编程中,针对线程的操作,也抽象出对应的并发设计模式。 两阶段终止模式- 优雅停止线程避免共享的设计模式- 只读、Copy-on-write、Thread-Spec…

LangChain 30 ChatGPT LLM将字符串作为输入并返回字符串Chat Model将消息列表作为输入并返回消息

LangChain系列文章 LangChain 实现给动物取名字,LangChain 2模块化prompt template并用streamlit生成网站 实现给动物取名字LangChain 3使用Agent访问Wikipedia和llm-math计算狗的平均年龄LangChain 4用向量数据库Faiss存储,读取YouTube的视频文本搜索I…

【线性代数】决定张成空间的最少向量线性无关吗?

答1: 是的,张成空间的最少向量是线性无关的。 在数学中,张成空间(span space)是一个向量空间,它由一组向量通过线性组合(即每个向量乘以一个标量)生成。如果这组向量是线性无关的&…

CV算法面试题学习

本文记录了CV算法题的学习。 CV算法面试题学习 点在多边形内(point in polygon)高斯滤波器 点在多边形内(point in polygon) 参考自文章1,其提供的代码没有考虑一些特殊情况,所以做了改进。 做法&#xff…

网络爬虫之多任务数据采集(多线程、多进程、协程)

进程:是操作系统中资源分配的基本单位 线程:使用进程分配的资源处理具体任务 一个进程中可以有多个线程:进程相当于一个公司,线程就是公司里面的员工。 一 多线程 多线程都是关于功能的并发执行。而异步编程是关于函数之间的非…

持续集成交付CICD:Jira 远程触发 Jenkins 实现更新 GitLab 分支

目录 一、实验 1.环境 2.GitLab 查看项目 3.Jira新建模块 4. Jira 通过Webhook 触发Jenkins流水线 3.Jira 远程触发 Jenkins 实现更新 GitLab 分支 二、问题 1.Jira 配置网络钩子失败 2. Jira 远程触发Jenkins 报错 一、实验 1.环境 (1)主机 …

uniapp自定义头部导航怎么实现?

一、在pages.json文件里边写上自定义属性 "navigationStyle": "custom" 二、在对应的index页面写上以下&#xff1a; <view :style"{ height: headheight px, backgroundColor: #24B7FF, zIndex: 99, position: fixed, top: 0px, width: 100% …

STM32——CAN协议

文章目录 一.CAN协议的基本特点1.1 特点1.2 电平标准1.3 基本的五个帧1.4 数据帧 二.数据帧解析2.1 帧起始和仲裁段2.2 控制段2.3 数据段和CRC段2.4 ACK段和帧结束 三.总线仲裁四.位时序五.STM32CAN控制器原理与配置5.1 STM32CAN控制器介绍5.2 CAN的模式5.3 CAN框图 六 手册寄存…

Hago 的 Spark on ACK 实践

作者&#xff1a;华相 Hago 于 2018 年 4 月上线&#xff0c;是欢聚集团旗下的一款多人互动社交明星产品。Hago 融合优质的匹配能力和多样化的垂类场景&#xff0c;提供互动游戏、多人语音、视频直播、 3D 虚拟形象互动等多种社交玩法&#xff0c;致力于为用户打造高效、多样、…

P4 音频知识点——PCM音频原始数据

目录 前言 01 PCM音频原始数据 1.1 频率 1.2 振幅&#xff1a; 1.3 比特率 1.4 采样 1.5 量化 1.6 编码 02. PCM数据有以下重要的参数&#xff1a; 采样率&#xff1a; 采集深度 通道数 ​​​​​​​ PCM比特率 ​​​​​​​ PCM文件大小计算&#xff1a; ​…

计算机毕业设计 基于SpringBoot的房屋租赁管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

HTML---网页布局

目录 文章目录 一.常见的网页布局 二.标准文档流 标准文档流常见标签 三.display属性 四.float属性 总结 一.常见网页布局 二.标准文档流 标准文档流常见标签 标准文档流的组成 块级元素<div>、<p>、<h1>-<h6>、<ul>、<ol>等内联元素<…

Hadoop入门学习笔记——一、VMware准备Linux虚拟机

视频课程地址&#xff1a;https://www.bilibili.com/video/BV1WY4y197g7 课程资料链接&#xff1a;https://pan.baidu.com/s/15KpnWeKpvExpKmOC8xjmtQ?pwd5ay8 Hadoop入门学习笔记&#xff08;汇总&#xff09; 目录 一、VMware准备Linux虚拟机1.1. VMware安装Linux虚拟机1.…

机密计算容器前沿探索与 AI 场景应用

作者&#xff1a;壮怀、朱江云 企业与个人对数据隐私保护日益关切&#xff0c;从数据&#xff0c;网络的可信基础设施扩展到闭环可信的计算基础设施&#xff0c;可信的计算&#xff0c;存储&#xff0c; 网络基础设施必定成为云计算的标配。 机密计算技术应运而生&#xff0c;…

Python遥感影像深度学习指南(1)-使用卷积神经网络(CNN、U-Net)和 FastAI进行简单云层检测

【遥感影像深度学习】系列的第一章,Python遥感影像深度学习的入门课程,介绍如何使用卷积神经网络(CNN)从卫星图像中分割云层 1、数据集 在本项目中,我们将使用 Kaggle 提供的 38-Cloud Segmentation in Satellite Images数据集。 该数据集由裁剪成 384x384 (适用…

AWD认识和赛前准备

AWD介绍 AWD: Attack With Defence, 北赛中每个队伍维护多台服务器&#xff0c;服务器中存在多个漏洞&#xff0c;利 用漏洞攻击其他队伍可以进行得分,修复漏洞可以避免被其他队伍攻击失分。 一般分配Web服务器&#xff0c;服务器(多数为Linux)某处存在flag(一般在根目录下)&am…