Lecture3 梯度下降(Gradient Descent)

news2024/5/18 3:47:52

目录

1 问题背景

2 批量梯度下降 (Batch Gradient Descent)

3 鞍点(Saddle Point)

3 随机梯度下降 (Stochastic Gradient Descent)

4 小批量梯度下降 (Mini-batch Gradient Descent)


1 问题背景

图1 上节课讲述的穷举法求最优权重值

  在Lecture2中,介绍了使用穷举法来确定最优\omega值,然而当遇到\omega范围较大,或者数量过多等情况时,穷举法的时间复杂度过大。因此,我们需要优化该算法。

2 批量梯度下降 (Batch Gradient Descent)

  在这次课中,介绍了一种寻找\omega最优值的算法——批量梯度下降 (Batch Gradient Descent, BGD)

简单介绍下该算法。首先对于下图:

图2 训练过程中权重初始值与最优值的位置

  假设我们目前的起始\omega位于上图红色点,为了找到最优\omega点(位于绿点),那么我们需要向左边移动,这样才能到达最优\omega点。

图3 我们需要计算梯度以向左移动权值点

 

   如何让权值点向左还是向右移动呢?此时我们需要计算当前点的梯度(Gradient),也就是用成本函数对权重进行求导,如果梯度<0,则向函数值递减方向移动;梯度>0,则向函数值递增方向移动。

  因为要移动起来,所以我们每移动一步,就要更新一下\omega值。更新函数如下图Update处:

图4 更新权重值的函数公式

  在这个更新函数中, α代表学习率(Learning Rate),学习率是机器学习中常用的一个超参数,它定义了每次更新参数时步长的大小,即每次更新参数时参数值变化的幅度。如果学习率设置得过大,所求结果可能会在最优解的附近来回震荡,而无法找到全局最优解。如果学习率设置得过小,那么模型的训练将会非常缓慢,甚至找不到最优解。

  这个式子中,梯度前面用了减号,是为了朝函数值递减方向,也就是往最优\omega所在的点移动,所以在梯度前面加负号。 就这样持续一步步地更新\omega,直到找到最优\omega

下面我们来具体讲讲如何去计算更新函数中的\frac{\partial cost}{\partial \omega} :

图5 更新函数

计算过程中,需要用到上节课总结的两个公式:

图6 均方误差MSE

图7 预测值y_hat

  接着把上述两个公式代入原式:

图8 求导过程

   蓝色处,因为cost=MSE,所以直接代入上一节课的MSE公式,然后对\omega求导。

  绿色处,由有理运算法则,和的导数等于导数的和,所以这里可以把\frac{\partial}{\partial \omega}移入求和式子中,对里面先进行求导后,再求和相加。

  黄色处,根据复合导数的链式求导法进行求导。

代码实现

from matplotlib import pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0  # 初始权重,由这个权重开始迭代

'''线性模型,算出预测值y_hat'''
def forward(x):
    return x * w


'''均方误差MSE'''
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)  # 算出y_hat
        cost += (y_pred - y) ** 2  # (y_hat - y)²
    return cost / len(xs)  # 除以样本总数求均值


'''梯度下降公式'''
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)


print('Predict (before training)', 4, forward(4))  # 训练前,模型对输入的4的最终预测结果

cost_list = [] # 保存每轮迭代后的cost值
epoch_list = [] # 保存每轮的迭代后的epoch值
for epoch in range(100):  # 进行100轮训练
    cost_val = cost(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= 0.01 * grad_val # 使用梯度下降法更新权重,0.01表示学习率
    print('Epoch:', epoch, 'w=%.2f' % w, 'loss=%.2f' % cost_val)
    cost_list.append(cost_val)
    epoch_list.append(epoch)
print('Predict (after training)', 4, forward(4))  # 训练后,模型对输入的4的最终预测结果

'''绘图'''
plt.plot(epoch_list, cost_list)
plt.ylabel('Cost')
plt.xlabel('Epoch')
plt.grid()
plt.show()
图9 输出结果图像

  将MSE公式和Linear Model公式代入整合,的最终更新函数:

图10 最终的更新函数

补充

训练后的结果一般来说,cost会趋于收敛情况

图11 通常训练后cost图像会趋于收敛

 如果发生如下情况,说明训练失败,原因有很多,其中之一可能是学习率取得太大:

图12 训练失败

 

  这就是批量梯度下降算法,本质上是一个贪心算法(Greedy Algorithm)。不过该算法有局限性,比如当前的预测值\omega正好位于下图绿线处,因为再往右移动会梯度会发生变化,使得程序直接终止,于是误将红的点作为最优\omega值,而忽略了处于蓝色点的最优\omega值:

图13 局部最优和全局最优示意图

   我们把上图中的红点称为局部最优点(Local Optimum),蓝色点称为全局最优点(Global Optimum)。因此对于该梯度下降算法,很可能会找到局部最优点,而忽略了全局最优点。不过这种现象不必担心,因为在实际训练中,往往很难陷入局部最优点。

3 鞍点(Saddle Point)

  在实际训练中,往往很难陷入局部最优点,而最需要解决的问题是鞍点(Saddle Point),鞍点是机器学习和数学中的一个概念,它指的是一个特殊的局部极小值,在某些方向上是极小值,但在其他方向上是极大值。在一元函数中,梯度=0的点就是鞍点。比如下图中,红色小球所处的位置就在鞍点,此时梯度为零,会导致更新函数无法更新(因为梯度=0,\omega=\omega-α*0相当于没有发生更新):

图14 鞍点示意图

 

 从多维角度来分析,比如下图红球处于马鞍面(Saddle Surface),从一个切面看可以处于最小值,从另一个切面看又处于最大值:

图15 位于马鞍面的鞍点

 

  在优化问题中,鞍点是一种特殊的局部最优解,是一个难以优化的点,因为优化算法可能很难从鞍点附近找到全局最优解。这是因为,如果优化算法在鞍点附近搜索,它可能会被误导到其他附近的局部最优解,而不是真正的全局最优解。所以在深度学习中,需要克服的最大问题就是鞍点而非局部最优问题。

3 随机梯度下降 (Stochastic Gradient Descent)

  随机梯度下降 (Stochastic Gradient Descent, SGD)在深度学习中很常用,和BGD算法的区别是,BGD使用所有的样本的均值的平均损失来作为\omega的更新依据,而SGD是从所有样本中随机选择单个样本的损失值来对\omega进行更新。

  随机梯度下降的优点是,每次仅使用一个数据点的梯度,因此在每次迭代时都有可能沿着非0梯度的方向更新参数,这样就避免陷入到鞍点导致无法更新参数。

图16 BGD到SGD公式上的改变

代码实现

import random

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0


def forward(x):
    return x * w


def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2


def gradient(x, y):
    return 2 * x * (x * w - y)


print('Predict (before training)', 4, forward(4))
for epoch in range(100):
    t = random.randrange(0, 3) # 随机得到一个样本
    x = x_data[t]
    y = y_data[t]
    grad = gradient(x, y)
    w = w - 0.01 * grad
    print("\tgrad: ", x, y, '%.2f' % grad)
    l = loss(x, y)
    print("progress:", epoch, "w=%.2f" % w, "loss=%.2f" % l)
print('Predict (after training)', 4, forward(4))

部分输出结果

Predict (before training) 4 4.0
    grad:  3.0 6.0 -18.00
progress: 0 w=1.18 loss=6.05
    grad:  2.0 4.0 -6.56
progress: 1 w=1.25 loss=2.28
    grad:  3.0 6.0 -13.58
progress: 2 w=1.38 loss=3.44
    grad:  1.0 2.0 -1.24
progress: 3 w=1.39 loss=0.37
    grad:  2.0 4.0 -4.85
progress: 4 w=1.44 loss=1.24

···

    grad:  1.0 2.0 -0.00
progress: 97 w=2.00 loss=0.00
    grad:  1.0 2.0 -0.00
progress: 98 w=2.00 loss=0.00
    grad:  2.0 4.0 -0.00
progress: 99 w=2.00 loss=0.00
Predict (after training) 4 7.999910864525451

4 小批量梯度下降 (Mini-batch Gradient Descent)

  SGD算法虽然可以在一定程度上避免陷入局部最优以及鞍点问题,但是运算所需时间复杂度过高,每次仅使用一个数据点的梯度,因此它的收敛速度通常比较慢。

  因此有一个折中的办法,就是使用小批量梯度下降 (Mini-batch Gradient Descent) 算法。简单来说,小批量梯度下降是一种介于批量梯度下降和随机梯度下降之间的优化算法。结合了这两种方法,通过使用小的随机选择的训练数据子集(称为mini-batch)计算损失函数关于参数的梯度的平均值来更新模型参数。

  总之,小批量梯度下降算法实现了BGD的高计算效率和SGD的良好收敛性之间的平衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/333457.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python+django校园在线拍卖系统pycharm项目源码

登录页和注册页 管理员、用户和拍卖者都可以在此页面登录到该系统 拍卖者注册的页面&#xff0c;拍卖者如果没有账号可以点击注册进入到系统 开发语言&#xff1a;Python 框架&#xff1a;django Python版本&#xff1a;python3.7.7 数据库&#xff1a;mysql 数据库工具&…

在国内有几个CCIE考点?

笔试在VUE指定思科考试中心报考加考试; 实验考试在国内只有北京和香港两个考点。还有一些其他考点&#xff0c;下面让我们一起来看一下。 CCIE笔试考点 思科ccie认证的笔试考试地点都在Vue考试中心,Vue考试中心是思科官方授权的考试地点,在国内大部分城市都有分布 CCIE笔试报…

基于ssm的航空售票系统

博主介绍&#xff1a;java高级开发&#xff0c;从事互联网行业六年&#xff0c;熟悉各种主流语言&#xff0c;精通java、python、php、爬虫、web开发&#xff0c;已经从做了六年的毕业设计程序开发&#xff0c;开发过上千套毕业设计程序&#xff0c;没有什么华丽的语言&#xf…

c# 将数据导出到EXCEL文件

第一步&#xff1a;项目中加入引用。 在鼠标右击项目&#xff0c;点击【添加】弹出菜单列表&#xff0c;选择【项目引用】弹出【引用管理器】对话框&#xff0c;选择【COM】-【Microsoft Excel 16.0 Object Library】&#xff0c;如图所示&#xff1a; 第二步&#xff0c;编辑…

【项目精选】基于Java的敬老院管理系统的设计和实现

本系统主要是针对敬老院工作人员即管理员和员工设计的。敬老院管理系统 将IT技术为养老院提供一个接口便于管理信息,存储老人个人信息和其他信息,查找 和更新信息的养老院档案,节省了员工的劳动时间,大大降低了成本。 其主要功能包括&#xff1a; 系统管理员用户功能介绍&#…

体系结构概述

目录1.系统模型操作模式和状态寄存器和特殊寄存器2.存储器系统3.栈空间操作4.异常和中断5.嵌套向量中断控制器&#xff08;NVIC&#xff09;1.系统模型 操作模式和状态 Cortex-M0处理器包含两种操作模式和两种状态。 处理器在运行程序时处于Thumb状态&#xff0c;在这种状态…

个人信息保护认证

个人信息保护认证是证明个人信息处理者在认证范围内开展的个人信息收集、存储、使用、加工、传输、提供、公开、删除以及跨境等处理活动符合认证依据标准要求。适用范围 本规则依据《中华人民共和国认证认可条例》制定&#xff0c;规定了对个人信息处理者开展个人信息收集、存储…

【漏洞复现】phpStudy 小皮 Windows面板 RCE漏洞

文章目录前言一、漏洞描述二、漏洞复现前言 本篇文章仅用于漏洞复现研究和学习&#xff0c;切勿从事非法攻击行为&#xff0c;切记&#xff01; 一、漏洞描述 Phpstudy小皮面板存在RCE漏洞&#xff0c;通过分析和复现方式发现其实本质上是一个存储型XSS漏洞导致的RCE。通过系…

关于IB课程,你需要知道些什么?

1IB课程到底是什么&#xff1f; IB课程 IB课程是由国际文凭组织IBO&#xff08;International Baccalaureate Organization&#xff09;开设的、针对3-19岁学生的、从幼儿园到大学预科的课程&#xff0c;服务对象为全球3-19岁的学生。IBO历史 1968年IBO组织在瑞士日内瓦成立&am…

php mysql保健品购物商城系统

目 录 1 绪论 1 1.1 开发背景 1 1.2 研究的目的和意义 1 1.3 研究现状 2 2 开发技术介绍 2 2.1 B/S体系结构 2 2.2 PHP技术 3 2.3 MYSQL数据库 4 2.4 Apache 服务器 5 2.5 WAMP 5 2.6 系统对软硬件要求 6 …

2.计算机基础-计算机网络面试题—基础知识、容器、面向对象、并发编程

本文目录如下&#xff1a;计算机基础-计算机网络 面试题一、基础知识简述 TCP 和 UDP 的区别&#xff1f;http与https的区别?Session 和 Cookie 有什么区别&#xff1f;URL是什么&#xff1f;由哪些部分组成&#xff1f;OSI 的 五层模型 都有哪些&#xff1f;get 和 post 请求…

CIAM 如何平衡数据安全与客户体验?| 身份云研究院

普华永道研究表明&#xff0c;32% 的用户会因为一次体验不佳而放弃使用一个产品。无独有偶&#xff0c;据数据分析公司 Preact 研究显示&#xff0c;首次注册、登录或验证是最主要的用户流失环节&#xff0c;占整体流失率的 22.9%。 对于任何在网上做生意的公司来说&#xff0…

MySQL为什么要改进LRU算法?

普通LRU算法 LRU算法介绍 LRU Least Recently Used&#xff08;最近最少使用&#xff09;&#xff1a;也就是末尾淘汰法&#xff0c;新数据从链表头部加入&#xff0c;释放空间时从末尾淘汰数据。 1.当要访问某个页时&#xff0c;如果不在Buffer Pool中&#xff0c;需要把该…

java实现电子发票中的发票税号等信息识别的几种可用方案

先说一下背景&#xff1a;今天领导突然说需要做一个电子发票中发票税号的识别&#xff0c;于是乎就开始去调研看有哪些方案&#xff0c;最先想到的就是OCR文字识别&#xff0c;自己去画框训练模型去识别税号等相关信息话不多说开整思路&#xff1a;思路一&#xff1a;百度AI平台…

逻辑仿真工具VCS的使用-Makefile

Gvim写RTL code&#xff0c;VCS仿真&#xff0c;Verdi看波形&#xff0c;DC做综合下约束&#xff0c;Primetime做STA&#xff0c;Spyglass做异步时序分析。 VCS全称Verilog Computer Simulation &#xff0c;VCS是逻辑仿真EDA工具的编译源代码的命令。要用VCS做编译仿…

C进阶:预处理

&#x1f916;本篇文章主要讲解预处理的知识&#xff0c;即使你是小白也可以看的懂&#xff0c;若你对预处理有所不解&#xff0c;确定不来看看吗&#xff1f;&#x1f63f; 目录 一.代码运行是的两种环境 二.翻译环境 三.预定义符号 四.#define 1.define 定义宏 2.带有…

有哪些必知人工智能著名应用?

当前有哪些人工智能的著名应用&#xff1f; 生成式人工智能 输入要求&#xff0c;输出结果 趣讲大白话&#xff1a;不要小看科技进步哦 *********** 人工智能引领智能时代 见图 【安志强趣讲信息科技】 掌握信息科技常识&#xff0c;未来竞争才不吃亏 世纪之问&#xff1a;做…

一篇文章读懂阿里云企业级数据库最佳实践

今天阿里数据库不再是简单的电商业务&#xff0c;而是涵盖了视频娱乐、IM、地图、在线零售、新零售、物流、在线旅游、音乐、IoT等纵多领域。2017年双十一交易额达1682亿&#xff0c;数据库交易峰值也以数十倍的速度在增长。超大规模的业务压力&#xff0c;在阿里巴巴内部淬炼出…

Chrome 浏览器的四大进程

一个进程就是一个程序的运行实例。详细解释就是&#xff0c;启动一个程序的时候&#xff0c;操作系统会为该程序创建一块内存&#xff0c;用来存放代码、运行中的数据和一个执行任务的主线程&#xff0c;我们把这样的一个运行环境叫进程。 总结来说&#xff0c;进程和线程之间的…

2023年地方两会政府工作报告汇总(各省市23年重点工作)

新年伊始&#xff0c;全国各地两会密集召开&#xff0c;各省、市、自治区2023年政府工作报告相继出炉&#xff0c;各地经济增长预期目标均已明确。相较于2022年&#xff0c;多地经济增长目标放缓&#xff0c;经济不断向“高质量”发展优化转型。今年是二十大后的开局之年&#…