前向传播与反向传播参数的更新方式(附公式代码)

news2025/7/24 3:52:40

前向传播与反向传播意义及其参数的更新方式

文章目录

  • 前向传播与反向传播意义及其参数的更新方式
      • 一、前言
      • 二、前反向传播的作用
      • 三、前向传播
      • 四、反向传播
        • 代码

一、前言

因为本身非科班出身,数学又学的很差,一直都是傻瓜式地用tensorflow和pytorch搭网络。前一段时间竞赛的时候尝试着用简单神经网络做了个题,同学突然问起反向传播的具体原理,一时语塞,遂下决心把这个问题搞明白。这篇学习笔记将以我的认知顺序也就是由浅至深的顺序叙述,里面可能涉及到一些神经网络的基础知识,比如学习率、激活函数、损失函数等,详情可以看看这里,本文不再赘述

写文章的时候查阅了一些资料,感觉写得最好的是这篇文章,我的一些思路也有所参考,推荐去看看,记得给大佬点star : )

二、前反向传播的作用

这个问题应该大部分接触过神经网络的人都有所了解,我最开始的认知也就停留在这一步

前向传播,也叫正向传播,其实就是参数在神经网络中从输入层到输出层传输过程

反向传播,其实就是根据输出层的输出实际值的差距,更新神经网络中参数的过程

而一次正向传播加上一次反向传播就是一次网络的学习

话虽如此,参数在网络中到底是如何变化的呢

三、前向传播

首先我们来看一个神经网络,这个神经网络是如此的简单,这种简单结构的网络可以使我们更好地理解神经网络的工作方式。

请添加图片描述

所谓前向传播,其实就是将神经网络的上一层作为下一层的输入,并计算下一层的输出,一直到输出层位置

如上图,假如输入层输入x,那么参数前向传播到隐藏层其实就是输入x权重矩阵相乘加上偏置项之和再通过激活函数,假设我们使用的激活函数为
f ( x ) = x × 2 f(x)=x\times2 f(x)=x×2
此时输入层的输出就是
f ( x × w 1 + b 1 ) = ( x × w 1 + b 1 ) × 2 f(x\times w_1 + b_1) = (x\times w_{1}+b_{1})\times2 f(x×w1+b1)=(x×w1+b1)×2
当参数继续向前传播,通过隐藏层的输出到输出层,其值为
∑ f ( f ( x × w 1 + b 1 ) × w 2 + b 2 ) = 2 × ( 2 × ( w 1 x + b 1 ) × w 2 + b 2 ) ( w 2 和 b 2 是一个 1 ∗ 3 的向量,比较复杂,就不展开了) \sum f(f(x\times w_1 + b_1)\times w_2 + b_2)=2\times(2\times(w_1x+b_1)\times w_2+b_2)(w2和b2是一个1 * 3的向量,比较复杂,就不展开了) f(f(x×w1+b1)×w2+b2)=2×(2×(w1x+b1)×w2+b2)w2b2是一个13的向量,比较复杂,就不展开了)
上面的式子的值其实就是神经网络的输出了,这样两个算式描述了一次前向传播的全部过程

四、反向传播

由于反向传播涉及到导数运算,而我的数学能力已经退化到小学水平了,所以这里我们直接使用一个1 * 1 * 1的 “神经网络” 来做演示
请添加图片描述

这里我们的损失函数选择使用最常见的均方误差(MSE),即定义损失值为预测值与实际值的差的平方除以样本数,这个损失函数对异常值比较敏感,适用于回归问题
L O S S = M S E ( y _ , y ) = ∑ i = 1 n ( y − y _ ) 2 n LOSS=MSE({y_\_},y) = \frac{{\sum\nolimits_{i = 1}^n {{{(y - y_\_)}^{2}}} }}{n} LOSS=MSE(y_,y)=ni=1n(yy_)2
而更新参数的依据,就是使最后预测的结果朝着损失函数值减小的方向移动,故我们用损失函数对每一个参数求偏导,让各个参数往损失函数减小的方向变化。假设我们这里的激活函数为
f ( x ) = x f(x) = x f(x)=x

损失函数对各参数求偏导的结果如下
定义输入层为输出为 h 1 ,隐藏层输出为 h 2 , y 预测值为 y _ 定义输入层为输出为h_1,隐藏层输出为h_2,y预测值为y_\_ 定义输入层为输出为h1,隐藏层输出为h2y预测值为y_
∂ L ∂ y = 2 ( y _ − y ) / / 单样本情况下, n = 1 \frac{\partial L}{\partial y} =2(y_\_-y) \quad//单样本情况下,n=1 yL=2(y_y)//单样本情况下,n=1
∂ L ∂ w 2 = ∂ L ∂ y × ∂ y ∂ h 2 × ∂ h 2 ∂ w 2 = 2 ( y _ − y ) × 1 × h 1 = 2 ( y _ − y ) × 1 × ( w 1 x + b 1 ) \frac{ \partial L }{ \partial w_2 } =\frac{ \partial L }{ \partial y }\times\frac{ \partial y }{ \partial h_2 } \times\frac{ \partial h_2 }{ \partial w_2 } =2(y_\_-y)\times1\times h_1 =2(y_\_-y)\times1\times (w_1x+b_1) w2L=yL×h2y×w2h2=2(y_y)×1×h1=2(y_y)×1×(w1x+b1)
∂ L ∂ b 2 = ∂ L ∂ y × ∂ y ∂ h 2 × ∂ h 2 ∂ b 2 = 2 ( y _ − y ) × 1 × 1 = 2 ( y _ − y ) \frac{\partial L}{\partial b_2} =\frac{\partial L}{\partial y} \times \frac{\partial y}{\partial h_2} \times \frac{\partial h_2}{\partial b_2} =2(y_\_-y)\times1\times 1 =2(y_\_-y) b2L=yL×h2y×b2h2=2(y_y)×1×1=2(y_y)
∂ L ∂ w 1 = ∂ L ∂ y × ∂ y ∂ h 2 × ∂ h 2 ∂ h 1 × ∂ h 1 ∂ w 1 = 2 ( y _ − y ) × 1 × w 2 × x \frac{\partial L}{\partial w_1} =\frac{\partial L}{\partial y} \times \frac{\partial y}{\partial h_2} \times \frac{\partial h_2}{\partial h_1} \times \frac{\partial h_1}{\partial w_1} =2(y_\_-y)\times1\times w_2\times x w1L=yL×h2y×h1h2×w1h1=2(y_y)×1×w2×x

∂ L ∂ b 1 = ∂ L ∂ y × ∂ y ∂ h 2 × ∂ h 2 ∂ h 1 × ∂ h 1 ∂ b 1 = 2 ( y _ − y ) × 1 × w 2 × 1 \frac{\partial L}{\partial b_1} =\frac{\partial L}{\partial y} \times \frac{\partial y}{\partial h_2} \times \frac{\partial h_2}{\partial h_1} \times \frac{\partial h_1}{\partial b_1} =2(y_\_-y)\times1\times w_2\times 1 b1L=yL×h2y×h1h2×b1h1=2(y_y)×1×w2×1

反向传播算法建立在梯度下降法的基础上,已经算出各参数偏导的情况下,需要使用梯度下降法进行参数更新,我们以学习率为μ为例,各参数的更新如下

Δ w 2 = − μ ∂ L ∂ w 2 Δ L O S S = − μ × 2 ( y _ − y ) × 1 × ( w 1 x + b 1 ) ) × ( y _ − y ) \Delta w_2 = -\mu \frac{ \partial L }{ \partial w_2 } \Delta LOSS =-\mu\times2(y_\_-y)\times1\times (w_1x+b_1))\times(y_{\_}-y) Δw2=μw2LΔLOSS=μ×2(y_y)×1×(w1x+b1))×(y_y)

Δ b 2 = − μ ∂ L ∂ b 2 Δ L O S S = − μ × 2 ( y _ − y ) × ( y _ − y ) \Delta b_2 =-\mu \frac{\partial L}{\partial b_2}\Delta LOSS =-\mu\times2(y_\_-y)\times(y_\_-y) Δb2=μb2LΔLOSS=μ×2(y_y)×(y_y)

Δ w 1 = − μ ∂ L ∂ w 1 Δ L O S S = − μ × 2 ( y _ − y ) × w 2 × x × ( y _ − y ) \Delta w_1 =-\mu\frac{\partial L}{\partial w_1}\Delta LOSS =-\mu \times 2(y_\_-y)\times w_2\times x\times (y_\_-y) Δw1=μw1LΔLOSS=μ×2(y_y)×w2×x×(y_y)

Δ b 1 = − μ ∂ L ∂ b 1 Δ L O S S = − μ × 2 ( y _ − y ) × w 2 × ( y _ − y ) \Delta b_1 =-\mu \frac{\partial L}{\partial b_1}\Delta LOSS =-\mu \times 2(y_\_-y)\times w_2\times(y_\_-y) Δb1=μb1LΔLOSS=μ×2(y_y)×w2×(y_y)

为什么这里要引入学习率的概念呢,有一篇博客非常形象的说明了这个问题,感兴趣的可以看看原文,省流量的可以看下面这个表格,这个表格说明了当学习率等于1的时候可能遇到的困境

轮数当前轮参数值梯度x学习率更新后参数值
152x5x1=105-10=-5
2-52x-5x1=-10-5-(-10)=5
352x5x1=105-10=-5

很明显,这里参数没有更新,输出结果就像大禹治水,三过家门而不入,训练也就毫无意义

代码

自己懒得写了,在网上找了一个,出处:CSDN
其实这个代码还挺难找的,各位也知道现在CSDN的内容环境,可以用一拖四来形容

import numpy as np
import matplotlib.pyplot as plt

# 激活函数
def sigmoid(z):
    return 1 / (1 + np.exp(-z))
 
# 向前传递
def forward(X, W1, W2, W3, b1, b2, b3):
    # 隐藏层1
    Z1 = np.dot(W1.T,X)+b1  # X=n*m ,W1.T=h1*n,b1=h1*1,Z1=h1*m
    A1 = sigmoid(Z1)  # A1=h1*m
    # 隐藏层2
    Z2 = np.dot(W2.T, A1) + b2  # W2.T=h2*h1,b2=h2*1,Z2=h2*m
    A2 = sigmoid(Z2)  # A2=h2*m
    # 输出层
    Z3=np.dot(W3.T,A2)+b3  # W3.T=(h3=1)*h2,b3=(h3=1)*1,Z3=1*m
    A3=sigmoid(Z3)  # A3=1*m
 
    return Z1,Z2,Z3,A1,A2,A3
 
# 反向传播
def backward(Y,X,A3,A2,A1,Z3,Z2,Z1,W3,W2,W1):
    n,m = np.shape(X)
    dZ3 = A3-Y # dZ3=1*m
    dW3 = 1/m *np.dot(A2,dZ3.T) # dW3=h2*1
    db3 = 1/m *np.sum(dZ3,axis=1,keepdims=True) # db3=1*1
 
    dZ2 = np.dot(W3,dZ3)*A2*(1-A2) # dZ2=h2*m
    dW2 = 1/m*np.dot(A1,dZ2.T) #dw2=h1*h2
    db2 = 1/m*np.sum(dZ2,axis=1,keepdims=True) #db2=h2*1
 
    dZ1 = np.dot(W2, dZ2) * A1 * (1 - A1) # dZ1=h1*m
    dW1 = 1 / m * np.dot(X, dZ1.T)  # dW1=n*h
    db1 = 1 / m * np.sum(dZ1,axis=1,keepdims=True)  # db1=h*m
 
    return dZ3,dZ2,dZ1,dW3,dW2,dW1,db3,db2,db1
 
def costfunction(Y,A3):
    m, n = np.shape(Y)
    J=np.sum(Y*np.log(A3)+(1-Y)*np.log(1-A3))/m
    # J = (np.dot(y, np.log(A2.T)) + np.dot((1 - y).T, np.log(1 - A2))) / m
    return -J
 
# Data = np.loadtxt("gua2.txt")
# X = Data[:, 0:-1]
# X = X.T
# Y = Data[:, -1]
# Y=np.reshape(1,m)
X=np.random.rand(100,200)
n,m=np.shape(X)
Y=np.random.rand(1,m)
n_x=n
n_y=1
n_h1=5
n_h2=4
W1=np.random.rand(n_x,n_h1)*0.01
W2=np.random.rand(n_h1,n_h2)*0.01
W3=np.random.rand(n_h2,n_y)*0.01
b1=np.zeros((n_h1,1))
b2=np.zeros((n_h2,1))
b3=np.zeros((n_y,1))
alpha=0.1
number=10000
for i in range(0,number):
    Z1,Z2,Z3,A1,A2,A3=forward(X,W1,W2,W3,b1,b2,b3)
    dZ3, dZ2, dZ1, dW3, dW2, dW1, db3, db2, db1=backward(Y,X,A3,A2,A1,Z3,Z2,Z1,W3,W2,W1)
    W1=W1-alpha*dW1
    W2=W2-alpha*dW2
    W3=W3-alpha*dW3
    b1=b1-alpha*db1
    b2=b2-alpha*db2
    b3=b3-alpha*db3
    J=costfunction(Y,A3)
    if (i%100==0):
        print(i)
    plt.plot(i,J,'ro')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/368403.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

「RISC-V Arch」SBI 规范解读(下)

第六章 定时器扩展(EID #0x54494D45"TIME") 这个定时器扩展取代了遗留定时器扩展(EID #0x00),并遵循 v0.2 中定义的调用规约。 6.1 函数:设置定时器(FID #0) struct sbi…

同花顺面试(部分)

又没录上,只能凭零星记忆来记录了 知识图谱有了解吗,知道用在你们系统哪个环节吗tomcat内存设置的多大。32位系统的话有限制吗?复盘:后面一想,可能是说32位系统地址寻址空间有限,内存最多4Ggprc你们用的序…

消息队列--Kafka

Kafka简介集群部署配置Kafka测试Kafka1.Kafka简介 数据缓冲队列。同时提高了可扩展性。具有峰值处理能力,使用消息队列能够使关键组件顶住突发的访问压力,而不会因为突发的超负荷的请求而完全崩溃。 Kafka是一个分布式、支持分区的(partition…

C/C++开发,无可避免的内存管理(篇一)-内存那些事

一、内存管理机制 任何编程语言在访问和操作内存时都会涉及大量的计算工作。但相对其他语言,c/c开发者必须自行采取措施确保所访问的内存是有效的,并且与实际物理存储相对应,以确保正在执行的任务不会访问不应该访问的内存位置。C/C语言及编译…

【Java】volatile

一、volatile volatile是Java虚拟机提供的轻量级的同步机制,它有3个特性: 1)保证可见性 2)不保证原子性 3)禁止指令重排 当写一个volatile变量时,JMM会把该…

openEuler部署Ceph集群(块存储)

openEuler部署Ceph集群1 目标2 环境2.1 服务器信息2.2 软件信息3 部署流程3.1 获取系统镜像3.2 创建虚拟机3.3 配置虚拟机3.3.1 配置互信3.3.2 关闭防火墙3.3.3 配置免密登录3.3.4 配置NTP3.3.4.1 安装NTP服务3.3.4.2 配置NTP服务端3.3.4.3 配置NTP客户端3.3.4.4 启动NTP服务3.…

pyqt5通过CANoe COM Server来操作CANoe仿真工程

文章目录前言一、COM接口技术二、UI界面设计三、功能实现四、工程运行测试前言 继续学习《CANoe开发从入门到精通》。 今天在《CANoe仿真工程开发》的基础上,开发实现pyqt5应用程序来操控CANoe工程。 一、COM接口技术 COM(Component Object Model&…

Linux基础命令-find搜索文件位置

文章目录 find 命令介绍 语法格式 命令基本参数 参考实例 1)在root/data目录下搜索*.txt的文件名 2)搜索一天以内最后修改时间的文件;并将文件删除 3)搜索777权限的文件 4)搜索一天之前变动的文件复制到test…

不懂什么是智慧工厂,看这篇文章就够了!

一、智慧工厂是什么? 一直以来,自动化在某种程度上始终是工厂的一部分,甚至高水平的自动化也非新生事物。然而,“自动化”一词通常表示单一且独立的任务或流程的执行。过去,机器自行“决策”的情况往往是以自动化为基…

【基础篇】9 # 排序:冒泡排序(Bubble Sort)、插入排序(Insertion Sort)、选择排序(Selection Sort)

说明 【数据结构与算法之美】专栏学习笔记 如何分析一个排序算法? 1、排序算法的执行效率 最好情况、最坏情况、平均情况时间复杂度时间复杂度的系数、常数 、低阶比较次数和交换(或移动)次数 2、排序算法的内存消耗 3、排序算法的稳定…

Fabric.js使用说明Part 2

目录一、Fabric.js使用说明Part 1Fabric.js简介 开始方法事件canvas常用属性对象属性图层层级操作复制和粘贴二、Fabric.js使用说明Part 2锁定拖拽和缩放画布分组动画图像滤镜渐变右键菜单删除三、Fabric.js使用说明Part 3自由绘画绘制背景图片绘制文本绘制线和路径一、锁定Fab…

传统豪华品牌引领?智能座舱进入「沉浸式娱乐体验」新周期

智能座舱正在进入硬件定型、软件(功能)升级以及多应用融合的新周期。 高工智能汽车研究院监测数据显示,2022年中国市场(不含进出口)乘用车搭载智能数字座舱(大屏语音车联网OTA)前装标配交付795…

【死磕数据库专栏启动】在CentOS7中安装 MySQL5.7版本实战

文章目录前言实验环境一. 安装MySQL1.1 配置yum源1.2 安装之前的环境检查1.3 下载MySQL的包1.4 开始使用yum安装1.5 启动并测试二. 设置新密码并重新启动2.1 设置新密码2.2 重新登录测试总结前言 学习MySQL是一件比较枯燥的事情,学习开始之前要先安装MySQL数据库&a…

【Linux修炼】14.磁盘结构/文件系统/软硬链接/动静态库

每一个不曾起舞的日子,都是对生命的辜负。 磁盘结构/文件系统/软硬链接/动静态库前言一.磁盘结构1.1 磁盘的物理结构1.2 磁盘的存储结构1.3 磁盘的逻辑结构二.理解文件系统2.1 对IO单位的优化2.2 磁盘分区与分组2.3 分组的管理方法2.4 文件操作三.软硬链接3.1理解硬…

测试4年裸辞失业,面试17k的测试岗被按在地上摩擦,结局让我崩溃大哭...

作为IT行业的大热岗位——软件测试,只要你付出了,就会有回报。说它作为IT热门岗位之一是完全不虚的。可能很多人回说软件测试是吃青春饭的,但放眼望去,哪个工作不是这样的呢?会有哪家公司愿意养一些闲人呢?…

「smardaten」上架钉钉应用中心!让进步再一次发生

使用钉钉的团队小伙伴们,smardaten给您送来福利啦~为了给更多团队提供更优质的应用开发体验,方便用户在线、快速使用无代码,数睿数据近期在【钉钉应用中心】发布smardaten在线版本。继与华为云、亚马逊云建立战略合作之后,smardat…

微信小程序实现分享到朋友圈的功能

分享朋友圈官方API:分享到朋友圈 1、分享到朋友圈接口设置事项 2、onShareTimeline()注意事项 3、分享朋友圈后,测试发现,没有数据请求。 用户在朋友圈打开分享的小程序页面,并不会真正打开小程序,而是进入一个“小程…

浏览器缓存策略

先走强缓存,再走协商缓存 强缓存 不发送请求,直接使用缓存的内容 状态码200 当前会话没有关闭的话就是走memory cache,否则就是disk cache 由响应头的 Pragma(逐渐废弃,优先级最高),catch-…

LeetCode 817. 链表组件

LeetCode 817. 链表组件 难度:middle\color{orange}{middle}middle 题目描述 给定链表头结点 headheadhead,该链表上的每个结点都有一个 唯一的整型值 。同时给定列表 numsnumsnums,该列表是上述链表中整型值的一个子集。 返回列表 numsnu…

自动驾驶仿真:ECU TEST 、VTD、VERISTAND连接配置

文章目录一、ECU TEST 连接配置简介二、TBC配置 test bench configuration三、TCF配置 test configuration提示:以下是本篇文章正文内容,下面案例可供参考 一、ECU TEST 连接配置简介 1、ECU TEST(简称ET),用于HIL仿…