李沐--动手学深度学习--GRU

news2025/6/10 13:08:30

1.GRU从零开始实现

#9.1.2GRU从零开始实现
import torch
from torch import nn
from d2l import torch as d2l

#首先读取 8.5节中使用的时间机器数据集
batch_size,num_steps = 32,35
train_iter,vocab = d2l.load_data_time_machine(batch_size,num_steps)
#初始化模型参数
def get_params(vocab_size,num_hiddens,device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape,device=device)*0.01

    def three():
        return (normal((num_inputs,num_hiddens)),
                normal((num_hiddens,num_hiddens)),
                torch.zeros(num_hiddens,device=device))

    W_xz,W_hz,b_z = three()  #更新门参数
    W_xr,W_hr,b_r = three()  #重置门参数
    W_xh,W_hh,b_h = three()  #候选隐状态参数
    #输出层参数
    W_hq = normal((num_hiddens,num_outputs))
    b_q = torch.zeros(num_outputs,device=device)
    #附加梯度
    params = [W_xz,W_hz,b_z,W_xr,W_hr,b_r,W_xh,W_hh,b_h,W_hq,b_q]
    for param in params:
        param.requires_grad_(True)
    return params
#定义隐状态的初始化函数init_gru_state
def init_gru_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device),)
#门控循环单元模型
def gru(inputs,state,params):
    W_xz,W_hz,b_z,W_xr,W_hr,b_r,W_xh,W_hh,b_h,W_hq,b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz)+(H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr)+(H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh)+((R*H) @ W_hh) + b_h)
        H = Z * H + (1-Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs,dim=0),(H,)
#训练与预测:打印输出训练集的困惑度,以及前缀“time traveler”和“traveler”的预测序列上的困惑度。
vocab_size,num_hiddens,device = len(vocab),256,d2l.try_gpu()
num_epochs,lr = 500,1
model = d2l.RNNModelScratch(len(vocab),num_hiddens,device,get_params,
                            init_gru_state,gru)
print(d2l.train_ch8(model,train_iter,vocab,lr,num_epochs,device))
d2l.plt.show()

2.GRU简洁实现

#9.1.3简洁实现
import torch
from torch import nn
from d2l import torch as d2l
#首先读取 8.5节中使用的时间机器数据集
batch_size,num_steps = 32,35
train_iter,vocab = d2l.load_data_time_machine(batch_size,num_steps)

vocab_size,num_hiddens,device = len(vocab),256,d2l.try_gpu()
num_epochs,lr = 500,1

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs,num_hiddens)
model = d2l.RNNModel(gru_layer,len(vocab))
model = model.to(device)
print(d2l.train_ch8(model,train_iter,vocab,lr,num_epochs,device))
d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2406656.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EasyRTC音视频实时通话功能在WebRTC与智能硬件整合中的应用与优势

一、WebRTC与智能硬件整合趋势​ 随着物联网和实时通信需求的爆发式增长,WebRTC作为开源实时通信技术,为浏览器与移动应用提供免插件的音视频通信能力,在智能硬件领域的融合应用已成必然趋势。智能硬件不再局限于单一功能,对实时…

【版本控制】GitHub Desktop 入门教程与开源协作全流程解析

目录 0 引言1 GitHub Desktop 入门教程1.1 安装与基础配置1.2 核心功能使用指南仓库管理日常开发流程分支管理 2 GitHub 开源协作流程详解2.1 Fork & Pull Request 模型2.2 完整协作流程步骤步骤 1: Fork(创建个人副本)步骤 2: Clone(克隆…

Android屏幕刷新率与FPS(Frames Per Second) 120hz

Android屏幕刷新率与FPS(Frames Per Second) 120hz 屏幕刷新率是屏幕每秒钟刷新显示内容的次数,单位是赫兹(Hz)。 60Hz 屏幕:每秒刷新 60 次,每次刷新间隔约 16.67ms 90Hz 屏幕:每秒刷新 90 次,…

【PX4飞控】mavros gps相关话题分析,经纬度海拔获取方法,卫星数锁定状态获取方法

使用 ROS1-Noetic 和 mavros v1.20.1, 携带经纬度海拔的话题主要有三个: /mavros/global_position/raw/fix/mavros/gpsstatus/gps1/raw/mavros/global_position/global 查看 mavros 源码,来分析他们的发布过程。发现前两个话题都对应了同一…

ubuntu中安装conda的后遗症

缘由: 在编译rk3588的sdk时,遇到编译buildroot失败,提示如下: 提示缺失expect,但是实测相关工具是在的,如下显示: 然后查找借助各个ai工具,重新安装相关的工具,依然无解。 解决&am…

【Java多线程从青铜到王者】单例设计模式(八)

wait和sleep的区别 我们的wait也是提供了一个还有超时时间的版本,sleep也是可以指定时间的,也就是说时间一到就会解除阻塞,继续执行 wait和sleep都能被提前唤醒(虽然时间还没有到也可以提前唤醒),wait能被notify提前唤醒&#xf…

react菜单,动态绑定点击事件,菜单分离出去单独的js文件,Ant框架

1、菜单文件treeTop.js // 顶部菜单 import { AppstoreOutlined, SettingOutlined } from ant-design/icons; // 定义菜单项数据 const treeTop [{label: Docker管理,key: 1,icon: <AppstoreOutlined />,url:"/docker/index"},{label: 权限管理,key: 2,icon:…

VSCode 使用CMake 构建 Qt 5 窗口程序

首先,目录结构如下图: 运行效果: cmake -B build cmake --build build 运行: windeployqt.exe F:\testQt5\build\Debug\app.exe main.cpp #include "mainwindow.h"#include <QAppli

Win系统权限提升篇UAC绕过DLL劫持未引号路径可控服务全检项目

应用场景&#xff1a; 1、常规某个机器被钓鱼后门攻击后&#xff0c;我们需要做更高权限操作或权限维持等。 2、内网域中某个机器被钓鱼后门攻击后&#xff0c;我们需要对后续内网域做安全测试。 #Win10&11-BypassUAC自动提权-MSF&UACME 为了远程执行目标的exe或者b…

Qwen系列之Qwen3解读:最强开源模型的细节拆解

文章目录 1.1分钟快览2.模型架构2.1.Dense模型2.2.MoE模型 3.预训练阶段3.1.数据3.2.训练3.3.评估 4.后训练阶段S1: 长链思维冷启动S2: 推理强化学习S3: 思考模式融合S4: 通用强化学习 5.全家桶中的小模型训练评估评估数据集评估细节评估效果弱智评估和民间Arena 分析展望 如果…

RushDB开源程序 是现代应用程序和 AI 的即时数据库。建立在 Neo4j 之上

一、软件介绍 文末提供程序和源码下载 RushDB 改变了您处理图形数据的方式 — 不需要 Schema&#xff0c;不需要复杂的查询&#xff0c;只需推送数据即可。 二、Key Features ✨ 主要特点 Instant Setup: Be productive in seconds, not days 即时设置 &#xff1a;在几秒钟…

表单设计器拖拽对象时添加属性

背景&#xff1a;因为项目需要。自写设计器。遇到的坑在此记录 使用的拖拽组件时vuedraggable。下面放上局部示例截图。 坑1。draggable标签在拖拽时可以获取到被拖拽的对象属性定义 要使用 :clone, 而不是clone。我想应该是因为draggable标签比较特。另外在使用**:clone时要将…

CSS 工具对比:UnoCSS vs Tailwind CSS,谁是你的菜?

在现代前端开发中&#xff0c;Utility-First (功能优先) CSS 框架已经成为主流。其中&#xff0c;Tailwind CSS 无疑是市场的领导者和标杆。然而&#xff0c;一个名为 UnoCSS 的新星正以其惊人的性能和极致的灵活性迅速崛起。 这篇文章将深入探讨这两款工具的核心理念、技术差…

Qt的学习(二)

1. 创建Hello Word 两种方式&#xff0c;实现helloworld&#xff1a; 1.通过图形化的方式&#xff0c;在界面上创建出一个控件&#xff0c;显示helloworld 2.通过纯代码的方式&#xff0c;通过编写代码&#xff0c;在界面上创建控件&#xff0c; 显示hello world&#xff1b; …

工厂方法模式和抽象工厂方法模式的battle

1.案例直接上手 在这个案例里面&#xff0c;我们会实现这个普通的工厂方法&#xff0c;并且对比这个普通工厂方法和我们直接创建对象的差别在哪里&#xff0c;为什么需要一个工厂&#xff1a; 下面的这个是我们的这个案例里面涉及到的接口和对应的实现类&#xff1a; 两个发…

鸿蒙Navigation路由导航-基本使用介绍

1. Navigation介绍 Navigation组件是路由导航的根视图容器&#xff0c;一般作为Page页面的根容器使用&#xff0c;其内部默认包含了标题栏、内容区和工具栏&#xff0c;其中内容区默认首页显示导航内容&#xff08;Navigation的子组件&#xff09;或非首页显示&#xff08;Nav…

CMS内容管理系统的设计与实现:多站点模式的实现

在一套内容管理系统中&#xff0c;其实有很多站点&#xff0c;比如企业门户网站&#xff0c;产品手册&#xff0c;知识帮助手册等&#xff0c;因此会需要多个站点&#xff0c;甚至PC、mobile、ipad各有一个站点。 每个站点关联的有站点所在目录及所属的域名。 一、站点表设计…

ZYNQ学习记录FPGA(二)Verilog语言

一、Verilog简介 1.1 HDL&#xff08;Hardware Description language&#xff09; 在解释HDL之前&#xff0c;先来了解一下数字系统设计的流程&#xff1a;逻辑设计 -> 电路实现 -> 系统验证。 逻辑设计又称前端&#xff0c;在这个过程中就需要用到HDL&#xff0c;正文…

Java中HashMap底层原理深度解析:从数据结构到红黑树优化

一、HashMap概述与核心特性 HashMap作为Java集合框架中最常用的数据结构之一&#xff0c;是基于哈希表的Map接口非同步实现。它允许使用null键和null值&#xff08;但只能有一个null键&#xff09;&#xff0c;并且不保证映射顺序的恒久不变。与Hashtable相比&#xff0c;Hash…

【记录坑点问题】IDEA运行:maven-resources-production:XX: OOM: Java heap space

问题&#xff1a;IDEA出现maven-resources-production:operation-service: java.lang.OutOfMemoryError: Java heap space 解决方案&#xff1a;将编译的堆内存增加一点 位置&#xff1a;设置setting-》构建菜单build-》编译器Complier