007_补充_ Pytorch 反向传播和Neural ODE的反向传播

news2025/7/7 23:28:18

一、Pytorch反向传播

首先是第一个小例子,训练模型拟合 y = true_w * x + true_b,模型的参数为 param_w, param_b

import torch

true_w = torch.Tensor([[2.0, 3.0], [4.0, 5.0]])  # 初始化真实的参数
true_b = torch.Tensor([[1.0, 2.0], [3.0, 4.0]])  # 默认情况下,创建的Tensor requires_grad=False

x = torch.ones(2, 2, requires_grad=False)  # 默认情况下, requires_grad=False, 创建出来的Tensor不会自动计算梯度
param_w = torch.ones(2, 2, requires_grad=True)  # 设置网络要训练的参数 w, b,初始值都为全1
param_b = torch.ones(2, 2, requires_grad=True)

true_y = true_w * x + true_b  # 计算真实应得到的y
predict_y = param_w * x + param_b  # 计算模型预测的y

loss = (true_y - predict_y).mean()  # L1损失

# 这里在backward之前,param_w param_b的grad都是空的,只有在backward之后grad才有值
loss.backward()
# 在backward之后,param_w param_b的grad输出为
# tensor([[-0.2500, -0.2500],
#         [-0.2500, -0.2500]])
# 具体的计算过程就是,这里loss求了平均,而矩阵中有四个数,也就是0.25
# 于是 param_w.grad = 0.25 * x
# 而 param_b.grad 虽然和 param_w.grad 相等,但是计算的过程不同
# param_b.grad = 0.25 * 1 同时要广播到与 param_b 同样的维度


# 如果要更新参数,通常用的是optimizer,里边的最基本的操作便是把要优化的参数减去梯度乘上损失率
loss_rate = 0.01
param_w = param_w - loss_rate * param_w.grad
param_b = param_b - loss_rate * param_b.grad

上面的代码就是一次反向传播并更新参数的过程。这个过程存在一个问题,之后再说。
关于pytorch的反向传播与自动求导,pytorch会对每个操作,存储其反向传播的方式,
比如 对于上面的loss,输出之后是

tensor(4., grad_fn=<MeanBackward0>)

其中的MeanBackward就是指loss是计算的平均值,相应的要用平均值的反向传播方式与求导方式,也就是MeanBackward。

同样的道理,对于中间的predict_y其对应的反向传播和求导方式就是AddBackward。至于计算图这些内容不再赘述,很多文章讲的很详细。

那么以上的代码存在什么问题呢?对于只更新一次的做法没有任何问题, 但是要训练多次的情况,比如在后边再复制几行

true_y = true_w * x + true_b  # 计算真实应得到的y
predict_y = param_w * x + param_b  # 计算模型预测的y
loss = (true_y - predict_y).mean()  # L1损失
loss.backward()

这个时候再输出param_w.grad和param_b.grad会得到:

tensor([[-0.5000, -0.5000],
        [-0.5000, -0.5000]])

相当于两次梯度值的累加,这就会导致之后的梯度值会越来越大,于是需要在第二次loss.backward()之前将梯度清空

用optimizer清空梯度直接用zero_grad,用手动更新的方法就需要手动设置

param_w.grad.data.zero_()
param_b.grad.data.zero_()

二、Neural ODE的反向传播

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/9640.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

linux安装jdk17

登录linux 我使用的是Alibaba Cloud Linux 3.2104 LTS 64位操作系统&#xff0c;登录后结果如下&#xff1a; Welcome to Alibaba Cloud Elastic Compute Service !Updates Information Summary: available7 Security notice(s)5 Important Security notice(s)2 Moderate Sec…

2023最新SSM计算机毕业设计选题大全(附源码+LW)之java校园二手商品交易系统p11v7

毕业设计说实话没有想象当中的那么难&#xff0c;导师也不会说刻意就让你毕设不通过&#xff0c;不让你毕业啥的&#xff0c;你只要不是太过于离谱的&#xff0c;都能通过的。首先你得要对你在大学期间所学到的哪方面比较熟悉&#xff0c;语言比如JAVA、PHP等这些&#xff0c;数…

laravel对于数据量特别大的导出excel的提速方案

背景&#xff1a;一些业务场景需要导出excel的需求&#xff0c;但是面对日益增长的数据&#xff0c;要导出的数据会越来越大。生成表格的时间也会越来越慢。针对这个问题&#xff0c;目前的业务通过两个角度去提速。 1&#xff1a;异步导出 如果数据量达到一定的体量&#xf…

【毕业设计】大数据大众点评评论文本分析 - python 数据挖掘

文章目录0 前言1 爬虫1.1 整体思路1.2 网页爬取和解析1.3 数据存储1.4 反爬虫对抗2 探索性分析与文本数据预处理2.1 探索性分析2.2 数据预处理2.3 词云展示3 文本的情感分析3.1 先上结果3.2 文本特征提取&#xff08;TF-IDF&#xff09;3.3 机器学习建模3.4 最后输出的准确率4 …

java ssh校园拼餐系统

首先在系统前台&#xff0c;游客用户可以经过账号注册&#xff0c;管理员审核通过后&#xff0c;用账号密码登录系统前台&#xff0c;查看拼餐服务、网站公告、文明拼餐员、会员风彩、系统简介、咨询信息、拼餐信息等栏目信息&#xff0c;进行在线咨询和管理员交流&#xff0c;…

LTSPICE使用教程:二极管钳位电路仿真

在我们查看芯片内部的设计电路时&#xff0c;通常会发现以下的电路结构&#xff1a; 当定义pin脚输入电压Vpin&#xff0c; 1.Vpin>VDD,二极管D1导通&#xff0c;D2截止&#xff0c;此时无论怎样继续加大VPIN的输入电压时&#xff0c; 进入到管脚内部的电压会被钳制在Vint…

【RocketMQ中生产者生产消息的高可用机制、消费者消费消息的高可用机制、消息的重试机制、死信队列于死信消息】

一.知识回顾 【0.RocketMQ专栏的内容在这里哟&#xff0c;帮你整理好了&#xff0c;更多内容持续更新中】 【1.Docker安装部署RocketMQ消息中间件详细教程】 【2.RocketMQ生产者发送消息的三种方式:发送同步消息、异步消息、单向消息&案例实战&详细学习流程】 【3.Rock…

野火FPGA入门(5)

文章目录第17讲&#xff1a;触摸按键控制LED灯第18讲&#xff1a;流水灯第19讲&#xff1a;呼吸灯第20讲&#xff1a;状态机第21讲&#xff1a;无源蜂鸣器驱动实验第17讲&#xff1a;触摸按键控制LED灯 触摸按键可分为四大类&#xff1a;电阻式、电容式、红外感应式、表面声波…

调优工具常用命令

语法格式 mysqldumpslow [ OPTS... ] [ LOGS... ] //命令行格式常用到的格式组合 -s 表示按照何种方式排序c 访问次数l 锁定时间r 返回记录t 查询时间al 平均锁定时间ar 平均返回记录数at 平均查询时间 -t 返回前面多少条数据 -g 后边搭配一个正则匹配模式&#xff0c;大小写…

机械专业学子的芯片封装仿真“逆袭之路”

作者&#xff1a;萧显军 导读&#xff1a;近期&#xff0c;ANSYS公司给清华大学集成电路学院捐赠了一批业界领先的计算机辅助工程(CAE)软件及自动化(EDA)软件&#xff0c;为清华大学的芯片设计仿真的教学科研工作提供更强大的软件服务与技术支撑。 捐的仿真软件包括ANSYS涉及…

小白学Java

ip地址&#xff1a;用于唯一识别标记网络中的每一台计算机 查看方法&#xff1a;ipconfig ip地址的表示形式&#xff1a;点分十进制 xx.xx.xx.xx 每个十进制数的范围&#xff1a;0-255 ip地址的组成 网络地址主机地址 ipv4地址分类&#xff1a; &#xff08;特殊&#xff1a;…

一、react简介

目标 理解react这个框架在前端开发中的地位理解react诞生的原因和意义&#xff08;react是一个用于快速构建前端视图的javaScript库&#xff09;理解什么是虚拟dom、原生js模拟出虚拟dom的表示&#xff0c;模拟出创建虚拟dom的方法&#xff0c;模拟出虚拟dom转换成真实dom的方…

什么是甘特图?什么是项目管理?

数字化与信息化早已成为现今人们工作和生活中不可缺少的一部分。尤其是随着科学技术的进步&#xff0c;人们对数字化的期待也越来也高。作为项目管理中常备的工具&#xff0c;甘特图已经成为不少业内人士中常备的“神器”了。然而依旧有人搞不清甘特图与项目管理区别究竟在哪里…

Revit中创建基于线的砌体墙及【快速砌体排砖】

​  墙可以更改内部结构和材质&#xff0c;但是很难画出砌块样式形成的墙体&#xff0c;我们可以用其他方式画出砌体排砖墙么?这里我们用基于线的常规模型做砌体排砖墙。在开始我们需要做两个族&#xff0c;作为砌体排砖墙的基本单位&#xff0c;也就是一个单独的砌体块。 一…

多亏了这份大佬整理的Java进阶笔记,让我斩获7个offer

移动互联网时代&#xff0c;IT 系统变得愈加复杂&#xff0c;对我们程序员的要求也是越来越高&#xff0c;技术不断更新&#xff0c;我们还不能停止学习&#xff0c;停下来了就会被打上一个‘不合格的程序员’的标签&#xff0c;如何成为一位「不那么差」的程序员&#xff1f; …

java.io.IOException: FIS_AUTH_ERROR in Android Firebase

项目里更换完google-services.json文件后&#xff0c;获取 firebase token 时&#xff0c;显示报错&#xff1a; E/FirebaseInstanceId: Topic sync or token retrieval failed on hard failure exceptions: FIS_AUTH_ERROR. Wont retry the operation.D/AndroidRuntime: Sh…

测试行业3年经验,从大厂裸辞后,面试阿里、字节全都一面挂,被面试官说我的水平还不如应届生

测试员可以先在大厂镀金&#xff0c;以后去中小厂毫无压力&#xff0c;基本不会被卡&#xff0c;事实果真如此吗&#xff1f;但是在我身上却是给了我很大一巴掌... 所谓大厂镀金只是不卡简历而已&#xff0c;如果面试答得稀烂&#xff0c;人家根本不会要你。况且要不是大厂出来…

精品基于ssm的足球联赛管理系统的设计与实现vue

《基于ssm的足球联赛管理系统的设计与实现》该项目含有源码、论文等资料、配套开发软件、软件安装教程、项目发布教程等 使用技术&#xff1a; 开发语言&#xff1a;Java 框架&#xff1a;ssm 前端技术&#xff1a;JavaScript、VUE.js&#xff08;2.X&#xff09;、css3 J…

记录一次服务器CPU负载高,利用率正常的处理方法

背景&#xff1a; 在一次查看服务器监控的时候偶然发现其中一台服务器的CPU负载很高&#xff0c;但是CPU利用率基本没有&#xff0c;通过top命令完全看不出来问题所在&#xff0c;经过一些思路的排查发现了原因并处理&#xff0c;现记录下来。 现象&#xff1a; top命令查看…