前向传播与反向传播参数的更新方式(略高于高中数学水平)(附公式、代码)

news2025/7/29 12:16:07

前向传播与反向传播意义及其参数的更新方式

文章目录

  • 前向传播与反向传播意义及其参数的更新方式
      • 一、前言
      • 二、前反向传播的作用
      • 三、前向传播
      • 四、反向传播
        • 代码

一、前言

因为本身非科班出身,数学又学的很差,一直都是傻瓜式地用tensorflow和pytorch搭网络。前一段时间竞赛的时候尝试着用简单神经网络做了个题,同学突然问起反向传播的具体原理,一时语塞,遂下决心把这个问题搞明白。这篇学习笔记将以我的认知顺序也就是由浅至深的顺序叙述,里面可能涉及到一些神经网络的基础知识,比如学习率、激活函数、损失函数等,详情可以看看这里,本文不再赘述

写文章的时候查阅了一些资料,感觉写得最好的是这篇文章,我的一些思路也有所参考,推荐去看看,记得给大佬点star : )

二、前反向传播的作用

这个问题应该大部分接触过神经网络的人都有所了解,我最开始的认知也就停留在这一步

前向传播,也叫正向传播,其实就是参数在神经网络中从输入层到输出层传输过程

反向传播,其实就是根据输出层的输出实际值的差距,更新神经网络中参数的过程

而一次正向传播加上一次反向传播就是一次网络的学习

话虽如此,参数在网络中到底是如何变化的呢

三、前向传播

首先我们来看一个神经网络,这个神经网络是如此的简单,这种简单结构的网络可以使我们更好地理解神经网络的工作方式。

请添加图片描述

所谓前向传播,其实就是将神经网络的上一层作为下一层的输入,并计算下一层的输出,一直到输出层位置

如上图,假如输入层输入x,那么参数前向传播到隐藏层其实就是输入x权重矩阵相乘加上偏置项之和再通过激活函数,假设我们使用的激活函数为
f ( x ) = x × 2 f(x)=x\times2 f(x)=x×2
此时输入层的输出就是
f ( x × w 1 + b 1 ) = ( x × w 1 + b 1 ) × 2 f(x\times w_1 + b_1) = (x\times w_{1}+b_{1})\times2 f(x×w1+b1)=(x×w1+b1)×2
当参数继续向前传播,通过隐藏层的输出到输出层,其值为
∑ f ( f ( x × w 1 + b 1 ) × w 2 + b 2 ) = 2 × ( 2 × ( w 1 x + b 1 ) × w 2 + b 2 ) ( w 2 和 b 2 是一个 1 ∗ 3 的向量,比较复杂,就不展开了) \sum f(f(x\times w_1 + b_1)\times w_2 + b_2)=2\times(2\times(w_1x+b_1)\times w_2+b_2)(w2和b2是一个1 * 3的向量,比较复杂,就不展开了) f(f(x×w1+b1)×w2+b2)=2×(2×(w1x+b1)×w2+b2)w2b2是一个13的向量,比较复杂,就不展开了)
上面的式子的值其实就是神经网络的输出了,这样两个算式描述了一次前向传播的全部过程

四、反向传播

由于反向传播涉及到导数运算,而我的数学能力已经退化到小学水平了,所以这里我们直接使用一个1 * 1 * 1的 “神经网络” 来做演示
请添加图片描述

这里我们的损失函数选择使用最常见的均方误差(MSE),即定义损失值为预测值与实际值的差的平方除以样本数,这个损失函数对异常值比较敏感,适用于回归问题
L O S S = M S E ( y _ , y ) = ∑ i = 1 n ( y − y _ ) 2 n LOSS=MSE({y_\_},y) = \frac{{\sum\nolimits_{i = 1}^n {{{(y - y_\_)}^{2}}} }}{n} LOSS=MSE(y_,y)=ni=1n(yy_)2
而更新参数的依据,就是使最后预测的结果朝着损失函数值减小的方向移动,故我们用损失函数对每一个参数求偏导,让各个参数往损失函数减小的方向变化。假设我们这里的激活函数为
f ( x ) = x f(x) = x f(x)=x

损失函数对各参数求偏导的结果如下
定义输入层为输出为 h 1 ,隐藏层输出为 h 2 , y 预测值为 y _ 定义输入层为输出为h_1,隐藏层输出为h_2,y预测值为y_\_ 定义输入层为输出为h1,隐藏层输出为h2y预测值为y_
∂ L ∂ y = 2 ( y _ − y ) / / 单样本情况下, n = 1 \frac{\partial L}{\partial y} =2(y_\_-y) \quad//单样本情况下,n=1 yL=2(y_y)//单样本情况下,n=1
∂ L ∂ w 2 = ∂ L ∂ y × ∂ y ∂ h 2 × ∂ h 2 ∂ w 2 = 2 ( y _ − y ) × 1 × h 1 = 2 ( y _ − y ) × 1 × ( w 1 x + b 1 ) \frac{ \partial L }{ \partial w_2 } =\frac{ \partial L }{ \partial y }\times\frac{ \partial y }{ \partial h_2 } \times\frac{ \partial h_2 }{ \partial w_2 } =2(y_\_-y)\times1\times h_1 =2(y_\_-y)\times1\times (w_1x+b_1) w2L=yL×h2y×w2h2=2(y_y)×1×h1=2(y_y)×1×(w1x+b1)
∂ L ∂ b 2 = ∂ L ∂ y × ∂ y ∂ h 2 × ∂ h 2 ∂ b 2 = 2 ( y _ − y ) × 1 × 1 = 2 ( y _ − y ) \frac{\partial L}{\partial b_2} =\frac{\partial L}{\partial y} \times \frac{\partial y}{\partial h_2} \times \frac{\partial h_2}{\partial b_2} =2(y_\_-y)\times1\times 1 =2(y_\_-y) b2L=yL×h2y×b2h2=2(y_y)×1×1=2(y_y)
∂ L ∂ w 1 = ∂ L ∂ y × ∂ y ∂ h 2 × ∂ h 2 ∂ h 1 × ∂ h 1 ∂ w 1 = 2 ( y _ − y ) × 1 × w 2 × x \frac{\partial L}{\partial w_1} =\frac{\partial L}{\partial y} \times \frac{\partial y}{\partial h_2} \times \frac{\partial h_2}{\partial h_1} \times \frac{\partial h_1}{\partial w_1} =2(y_\_-y)\times1\times w_2\times x w1L=yL×h2y×h1h2×w1h1=2(y_y)×1×w2×x

∂ L ∂ b 1 = ∂ L ∂ y × ∂ y ∂ h 2 × ∂ h 2 ∂ h 1 × ∂ h 1 ∂ b 1 = 2 ( y _ − y ) × 1 × w 2 × 1 \frac{\partial L}{\partial b_1} =\frac{\partial L}{\partial y} \times \frac{\partial y}{\partial h_2} \times \frac{\partial h_2}{\partial h_1} \times \frac{\partial h_1}{\partial b_1} =2(y_\_-y)\times1\times w_2\times 1 b1L=yL×h2y×h1h2×b1h1=2(y_y)×1×w2×1

反向传播算法建立在梯度下降法的基础上,已经算出各参数偏导的情况下,需要使用梯度下降法进行参数更新,我们以学习率为μ为例,各参数的更新如下

Δ w 2 = − μ ∂ L ∂ w 2 Δ L O S S = − μ × 2 ( y _ − y ) × 1 × ( w 1 x + b 1 ) ) × ( y _ − y ) \Delta w_2 = -\mu \frac{ \partial L }{ \partial w_2 } \Delta LOSS =-\mu\times2(y_\_-y)\times1\times (w_1x+b_1))\times(y_{\_}-y) Δw2=μw2LΔLOSS=μ×2(y_y)×1×(w1x+b1))×(y_y)

Δ b 2 = − μ ∂ L ∂ b 2 Δ L O S S = − μ × 2 ( y _ − y ) × ( y _ − y ) \Delta b_2 =-\mu \frac{\partial L}{\partial b_2}\Delta LOSS =-\mu\times2(y_\_-y)\times(y_\_-y) Δb2=μb2LΔLOSS=μ×2(y_y)×(y_y)

Δ w 1 = − μ ∂ L ∂ w 1 Δ L O S S = − μ × 2 ( y _ − y ) × w 2 × x × ( y _ − y ) \Delta w_1 =-\mu\frac{\partial L}{\partial w_1}\Delta LOSS =-\mu \times 2(y_\_-y)\times w_2\times x\times (y_\_-y) Δw1=μw1LΔLOSS=μ×2(y_y)×w2×x×(y_y)

Δ b 1 = − μ ∂ L ∂ b 1 Δ L O S S = − μ × 2 ( y _ − y ) × w 2 × ( y _ − y ) \Delta b_1 =-\mu \frac{\partial L}{\partial b_1}\Delta LOSS =-\mu \times 2(y_\_-y)\times w_2\times(y_\_-y) Δb1=μb1LΔLOSS=μ×2(y_y)×w2×(y_y)

为什么这里要引入学习率的概念呢,有一篇博客非常形象的说明了这个问题,感兴趣的可以看看原文,省流量的可以看下面这个表格,这个表格说明了当学习率等于1的时候可能遇到的困境

轮数当前轮参数值梯度x学习率更新后参数值
152x5x1=105-10=-5
2-52x-5x1=-10-5-(-10)=5
352x5x1=105-10=-5

很明显,这里参数没有更新,输出结果就像大禹治水,三过家门而不入,训练也就毫无意义

代码

自己懒得写了,在网上找了一个,出处:CSDN
其实这个代码还挺难找的,各位也知道现在CSDN的内容环境,可以用一拖四来形容,但这部分代码写的挺不错,已经向原作者征求使用许可了,但作者现在还没回,如果他的回复是不同意,我会删除这部分代码再自己写一个

import numpy as np
import matplotlib.pyplot as plt

# 激活函数
def sigmoid(z):
    return 1 / (1 + np.exp(-z))
 
# 向前传递
def forward(X, W1, W2, W3, b1, b2, b3):
    # 隐藏层1
    Z1 = np.dot(W1.T,X)+b1  # X=n*m ,W1.T=h1*n,b1=h1*1,Z1=h1*m
    A1 = sigmoid(Z1)  # A1=h1*m
    # 隐藏层2
    Z2 = np.dot(W2.T, A1) + b2  # W2.T=h2*h1,b2=h2*1,Z2=h2*m
    A2 = sigmoid(Z2)  # A2=h2*m
    # 输出层
    Z3=np.dot(W3.T,A2)+b3  # W3.T=(h3=1)*h2,b3=(h3=1)*1,Z3=1*m
    A3=sigmoid(Z3)  # A3=1*m
 
    return Z1,Z2,Z3,A1,A2,A3
 
# 反向传播
def backward(Y,X,A3,A2,A1,Z3,Z2,Z1,W3,W2,W1):
    n,m = np.shape(X)
    dZ3 = A3-Y # dZ3=1*m
    dW3 = 1/m *np.dot(A2,dZ3.T) # dW3=h2*1
    db3 = 1/m *np.sum(dZ3,axis=1,keepdims=True) # db3=1*1
 
    dZ2 = np.dot(W3,dZ3)*A2*(1-A2) # dZ2=h2*m
    dW2 = 1/m*np.dot(A1,dZ2.T) #dw2=h1*h2
    db2 = 1/m*np.sum(dZ2,axis=1,keepdims=True) #db2=h2*1
 
    dZ1 = np.dot(W2, dZ2) * A1 * (1 - A1) # dZ1=h1*m
    dW1 = 1 / m * np.dot(X, dZ1.T)  # dW1=n*h
    db1 = 1 / m * np.sum(dZ1,axis=1,keepdims=True)  # db1=h*m
 
    return dZ3,dZ2,dZ1,dW3,dW2,dW1,db3,db2,db1
 
def costfunction(Y,A3):
    m, n = np.shape(Y)
    J=np.sum(Y*np.log(A3)+(1-Y)*np.log(1-A3))/m
    # J = (np.dot(y, np.log(A2.T)) + np.dot((1 - y).T, np.log(1 - A2))) / m
    return -J
 
# Data = np.loadtxt("gua2.txt")
# X = Data[:, 0:-1]
# X = X.T
# Y = Data[:, -1]
# Y=np.reshape(1,m)
X=np.random.rand(100,200)
n,m=np.shape(X)
Y=np.random.rand(1,m)
n_x=n
n_y=1
n_h1=5
n_h2=4
W1=np.random.rand(n_x,n_h1)*0.01
W2=np.random.rand(n_h1,n_h2)*0.01
W3=np.random.rand(n_h2,n_y)*0.01
b1=np.zeros((n_h1,1))
b2=np.zeros((n_h2,1))
b3=np.zeros((n_y,1))
alpha=0.1
number=10000
for i in range(0,number):
    Z1,Z2,Z3,A1,A2,A3=forward(X,W1,W2,W3,b1,b2,b3)
    dZ3, dZ2, dZ1, dW3, dW2, dW1, db3, db2, db1=backward(Y,X,A3,A2,A1,Z3,Z2,Z1,W3,W2,W1)
    W1=W1-alpha*dW1
    W2=W2-alpha*dW2
    W3=W3-alpha*dW3
    b1=b1-alpha*db1
    b2=b2-alpha*db2
    b3=b3-alpha*db3
    J=costfunction(Y,A3)
    if (i%100==0):
        print(i)
    plt.plot(i,J,'ro')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/367887.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【容器】学习docker容器网络

在前面讲解容器基础时,曾经提到过一个 Linux 容器能看见的“网络栈”,实际上是被隔离在它自己的 Network Namespace 当中的。 而所谓“网络栈”,就包括了:网卡(Network Interface)、回环设备(L…

Mac mini 外接移动硬盘无法写入或者无法显示的解决方法

文章目录1. 背景2. 让NTFS格式的移动硬盘正常读写方法3. 打开“启动安全性实用工具”4. 更改“安全启动”设置1. 背景 刚买mac min(2023年2月3日)不久,发现macOS的玩起来并不容易,勇习惯了windows系统的习惯,感觉 mac…

【storybook】你需要一款能在独立环境下开发组件并生成可视化控件文档的框架吗?(二)

storybook回顾继续说说用法配置文件介绍回顾 上篇博客地址: https://blog.csdn.net/tuzi007a/article/details/129192502说了部分用法。 继续说说用法 配置文件介绍 开发环境的配置都在.storybook目录中,里面包含了2个文件 main.js preview.js先看m…

STM32 触摸屏移植GUI控制控件

目录 1、emWin 支持指针输入设备。 2、 模拟触摸屏驱动 3、实现触摸屏的流程 3.1 实现硬件函数 3.2 实现对GUI_TOUCH_Exec()的定期调用 3.3 使用上一步确定的值,在初始化函数LCD_X_Config()当中添加对GUI_TOUCH_Calibrate()的调用 4、…

Kubernetes入门教程 --- 使用二进制安装

Kubernetes入门教程 --- 使用二进制安装1. Introduction1.1 架构图1.2 关键字介绍1.3 简述2. 使用Kubeadm Install2.1 申请三个虚拟环境2.2 准备安装环境2.3 配置yum源2.4 安装Docker2.4.1 配置docker加速器并修改成k8s驱动2.5 时间同步2.6 安装组件3. 基础知识3.1 Pod3.2 控制…

【一些回忆】2022.02.26-2023.02.26 一个普通男孩蜕变的365天

💃🏼 本人简介:男 👶🏼 年龄:18 🤞 作者:那就叫我亮亮叭 📕 专栏:一些回忆 为什么选择在这个时间节点回忆一下呢? 一是因为今天距离2023高考仅剩1…

双指针法应用总结

一、双指针法(一)概况1.类型:快慢指针(相同方向循环)、对撞指针(相反方向循环)、滑动窗口2.用途:提高效率,通常能将将O(n^2)的时间复杂度,降为O(n)3.可应用的…

selenium基本操作

爬虫与反爬虫之间的斗争爬虫:对某个网站数据或图片感兴趣,开始抓取网站信息;网站:请求次数频繁,并且访问ip固定,user_agent也是python,开始限制访问;爬虫:通过设置user_a…

数据库浅谈之 DuckDB AGG 底层实现

数据库浅谈之 DuckDB AGG 底层实现 HELLO,各位博友好,我是阿呆 🙈🙈🙈 这里是数据库浅谈系列,收录在专栏 DATABASE 中 😜😜😜 本系列阿呆将记录一些数据库领域相关的知…

离线维基百科阅读器Kiwix Serve

本文软件是网友 刘源 推荐的,因为他已经安装成功了,所以老苏拖拖拉拉的就从去年拖到了现在; 😂 什么是 Kiwix ? Kiwix 是一个用于浏览离线内容的自由开源浏览器,最初用于离线浏览维基百科。Kiwix 可以读取以压缩形式存…

[神经网络]基干网络之VGG、ShuffleNet

一、VGG VGG是传统神经网络堆叠能达到的极限深度。 VGG分为VGG16和VGG19,其均有以下特点: ①按2x2的Pooling层,网络可以分成若干段 ②每段之内由若干same卷积操作构成,段内Feature Map数量固定不变; ③Feature Map按2的…

对个人博客系统进行web自动化测试(包含测试代码和测试的详细过程)

目录 一、总述 二、登录页面测试 一些准备工作 验证页面显示是否正确 验证正常登录的情况 该过程中出现的问题 验证登录失败的情况 关于登录界面的总代码 测试视频 三、注册界面的自动化测试 测试代码 过程中出现的bug 测试视频 四、博客列表页测试(…

【Leedcode】数据结构中链表必备的面试题(第四期)

【Leedcode】数据结构中链表必备的面试题(第四期) 文章目录【Leedcode】数据结构中链表必备的面试题(第四期)1.题目2.思路图解(1)思路一(2)思路二3.源代码总结1.题目 相交链表: 如下(示例)&…

小白福利!我开发了一个快速部署库

1、开发背景 很多入门的同学,在跟着视频敲完代码之后,在打包出来的产物犯了难 如果是 hash 路由,要么使用后端部署,要么使用 github 或者 gitee 提供的静态部署服务如果是 history 路由,那只能使用后端框架进行部署&a…

内网渗透(五十三)之域控安全和跨域攻击-利用域信任密钥获取目标域控

系列文章第一章节之基础知识篇 内网渗透(一)之基础知识-内网渗透介绍和概述 内网渗透(二)之基础知识-工作组介绍 内网渗透(三)之基础知识-域环境的介绍和优点 内网渗透(四)之基础知识-搭建域环境 内网渗透(五)之基础知识-Active Directory活动目录介绍和使用 内网渗透(六)之基…

前端学习日记——Vue之Vuex初识(一)

前言 学习前端一段时间了,因为一直是做Python开发,所以凭借着语言的通性学习Javascript、Vue轻快很多,但一些碎片化的知识及插件的使用方法还是需要记录一下,时而复习,形成系统化的知识体系(PS:…

【Linux线程池】

Linux线程池Linux线程池线程池的概念线程池的优点线程池的应用场景线程池的实现Linux线程池 线程池的概念 线程池是一种线程使用模式。 线程过多会带来调度开销,进而影响缓存局部和整体性能,而线程池维护着多个线程,等待着监督管理者分配可并…

JavaScript if…else 语句

条件语句用于基于不同的条件来执行不同的动作。条件语句通常在写代码时,您总是需要为不同的决定来执行不同的动作。您可以在代码中使用条件语句来完成该任务。在 JavaScript 中,我们可使用以下条件语句:if 语句 - 只有当指定条件为 true 时&a…

【企业云端全栈开发实践-3】Spring Boot文件上传服务+拦截器

本节目录一、静态资源访问二、文件上传原理三、拦截器3.1 拦截器定义代码3.2 拦截器注册一、静态资源访问 使用IDEA创建Spring Boot项目时,会默认创建classpath://static/目录,静态资源一般放在这个目录下即可。 如果默认的静态资源过滤策略不能满足开…

做独立开发者,能在AppStore赚到多少钱?

成为一名独立开发者,不用朝九晚五的上班,开发自己感兴趣的产品,在AppStore里赚美金,这可能是很多程序员的梦想,今天就来盘一盘,这个梦想实现的概率有多少。 先来了解一些数据: 2022年5月26日&am…