lesson04-简单回归案例实战(理论+代码)

news2025/6/4 10:57:35

理解线性回归及梯度下降优化

引言

在机器学习的基础课程中,我们经常遇到的一个重要概念就是线性回归。今天,我们将深入探讨这一主题,并通过具体的例子来了解如何利用梯度下降方法对模型进行优化。

线性回归简介

线性回归是一种统计方法,用于确定两个变量之间的关系。简单来说,如果我们有一个自变量 XX 和因变量 YY,线性回归可以帮助我们找到一条最佳拟合直线,这条直线可以用公式 Y=WX+bY=WX+b 来表示,其中 WW 是权重,bb 是偏置。

损失函数

为了评估模型的好坏,我们需要定义一个损失函数。对于线性回归而言,通常使用平方误差作为损失函数,即 loss=(WX+b−y)2loss=(WX+b−y)2。

梯度下降优化

梯度下降是一种迭代优化算法,用来最小化损失函数。每次迭代过程中,我们会更新参数 WW 的值,具体更新规则为 w′=w−lr×∇loss/∇ww′=w−lr×∇loss/∇w,这里的 lrlr 表示学习率,控制着每一步调整的幅度。

迭代优化

通过不断调整 WW 和 bb 的值,使得损失函数逐渐减小,直到达到局部或全局最小值点。这个过程需要多次迭代计算,直至满足预设的停止条件为止。

下一课时预告

接下来的一课时,我们将一起探索著名的MNIST手写数字识别任务,敬请期待!

结语

感谢大家的关注与支持,希望今天的分享能够加深您对线性回归以及梯度下降算法的理解。让我们共同期待下一节课的到来吧!

实战代码

import numpy as np

# y = wx + b
def compute_error_for_line_given_points(b, w, points):
    totalError = 0
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        totalError += (y - (w * x + b)) ** 2
    return totalError / float(len(points))

def step_gradient(b_current, w_current, points, learningRate):
    b_gradient = 0
    w_gradient = 0
    N = float(len(points))
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        b_gradient += -(2/N) * (y - ((w_current * x) + b_current))
        w_gradient += -(2/N) * x * (y - ((w_current * x) + b_current))
    new_b = b_current - (learningRate * b_gradient)
    new_m = w_current - (learningRate * w_gradient)
    return [new_b, new_m]

def gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations):
    b = starting_b
    m = starting_m
    for i in range(num_iterations):
        b, m = step_gradient(b, m, np.array(points), learning_rate)
    return [b, m]

def run():
    points = np.genfromtxt("data.csv", delimiter=",")
    learning_rate = 0.0001
    initial_b = 0 # initial y-intercept guess
    initial_m = 0 # initial slope guess
    num_iterations = 1000
    print("Starting gradient descent at b = {0}, m = {1}, error = {2}"
          .format(initial_b, initial_m,
                  compute_error_for_line_given_points(initial_b, initial_m, points))
          )
    print("Running...")
    [b, m] = gradient_descent_runner(points, initial_b, initial_m, learning_rate, num_iterations)
    print("After {0} iterations b = {1}, m = {2}, error = {3}".
          format(num_iterations, b, m,
                 compute_error_for_line_given_points(b, m, points))
          )

if __name__ == '__main__':
    run()

🧠 一、代码概述

这段代码的主要目的是:

  • 使用一个简单的线性模型:y = mx + b
  • 给定一个二维数据集 data.csv,其中每行有两个值:x 和 y
  • 使用梯度下降算法迭代地更新 m 和 b,使得预测的 y 尽可能接近真实值
  • 最终输出经过多次迭代后的最优 m 和 b 值,并计算最终误差

📁 二、文件结构说明

  1. 导入库

    import numpy as np
    • 引入 NumPy 库,用于高效的数值计算和数组操作。
  2. 函数定义

    • compute_error_for_line_given_points(b, w, points)
      计算当前直线的平均平方误差(MSE)
    • step_gradient(b_current, w_current, points, learningRate)
      执行一次梯度下降步骤,返回更新后的 b 和 m
    • gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations)
      迭代运行梯度下降过程
    • run()
      主函数,加载数据、调用训练函数、打印结果
  3. 主程序入口

if __name__ == '__main__':
    run()

 

📌 三、函数详解

1. compute_error_for_line_given_points(b, w, points)

功能:

计算当前模型参数下的均方误差(Mean Squared Error, MSE)

公式:

MSE=1N∑i=1N(yi−(wxi+b))2MSE=N1​i=1∑N​(yi​−(wxi​+b))2

参数:
  • b: 当前截距(bias / y-intercept)
  • w: 当前斜率(weight / slope)
  • points: 数据点集合,是一个二维数组,每行表示一个 (x, y) 点
返回值:
  • 平均误差值(越小越好)

2. step_gradient(b_current, w_current, points, learningRate)

功能:

执行一次梯度下降步骤,根据当前的 bm 更新它们的值。

核心公式(梯度下降更新规则):

b′=b−η⋅∂MSE∂bb′=b−η⋅∂b∂MSE​

m′=m−η⋅∂MSE∂mm′=m−η⋅∂m∂MSE​

其中:

  • ηη 是学习率(learning rate)
  • 梯度是通过对损失函数分别对 b 和 m 求导得到的
导数推导:

∂MSE∂b=2N∑i=1N(yi−(mxi+b))⋅(−1)∂b∂MSE​=N2​i=1∑N​(yi​−(mxi​+b))⋅(−1)

∂MSE∂m=2N∑i=1N(yi−(mxi+b))⋅(−xi)∂m∂MSE​=N2​i=1∑N​(yi​−(mxi​+b))⋅(−xi​)

你在代码中实现了这两个梯度的累加。

返回值:
  • [new_b, new_m]:更新后的模型参数

3. gradient_descent_runner(...)

功能:

循环执行 step_gradient 多次,完成完整的梯度下降过程。

参数:
  • points: 数据集
  • starting_bstarting_m: 初始参数
  • learning_rate: 学习率
  • num_iterations: 迭代次数
输出:
  • 最终的 b 和 m

4. run()

功能:
  • 加载 CSV 数据文件
  • 设置初始参数
  • 调用梯度下降函数进行训练
  • 打印训练前后误差和参数变化

输出结果展示

这表明经过 1000 次迭代后,模型已经基本收敛。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2396380.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 面试中的数据库设计深度解析

🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Java 面试中的数据库设计深度解析一、数据库…

国内首发!具有GPU算力的AI扫描仪

奥普思凯重磅推出的具有GPU算力的扫描仪,是一款真正意义上的AI扫描仪,奥普思凯将嵌有OCR发票识别核心的高性能NPU算力棒与高速扫描仪相结合,实现软件硬件相结合,采用一体化外观设计,实现高速扫描、快速识别表单&#x…

【开发技巧指北】IDEA修改默认绑定Maven的仓库地址

【开发技巧指北】IDEA修改默认绑定Maven的仓库地址 Microsoft Windows 11 家庭中文版 IIntelliJ IDEA 2025.1.1.1 默认的IDEA是有自己捆绑的Maven的(这是修改完毕的截图) 修改默认的Maven配置,路径是IDEA安装路径下的plugins D:\Softwares\I…

【2025最新】Java图书借阅管理系统:从课程作业到实战应用的完整解决方案

【2025最新】Java图书借阅管理系统:从课程作业到实战应用的完整解决方案 目录 【2025最新】Java图书借阅管理系统:从课程作业到实战应用的完整解决方案**系统概述** **核心功能模块详解****1. 系统登录与权限控制****2. 借阅管理模块****3. 用户角色管理…

springcloud openfeign 请求报错 java.net.UnknownHostException:

现象 背景 项目内部服务之间使用openfeign通过eureka注册中心进行服务间调用,与外部通过http直接调用。外部调用某个业务方提供的接口需要证书校验,因对方未提供证书故设置了忽略证书校验代码如下 Configuration public class IgnoreHttpsSSLClient {B…

【harbor】--配置https

使用自建的 CA 证书来自签署和启用 HTTPS 通信。 (1)生成 CA认证 使用 OpenSSL 生成一个 2048位的私钥这是 自建 CA(证书颁发机构) 的私钥,后续会用它来签发证书。 # 1创建CA认证 cd 到harbor [rootlocalhost harbo…

OptiStruct实例:消声器前盖ERP分析(2)RADSND基础理论

13.2 Radiated Sound Output Analysis( RADSND ) RADSND 方法通过瑞利积分来求解结构对外的辐射噪声。其基本思路是分为两个阶段,如图 13-12 所示。 图13-12 结构辐射噪声计算示意图 第一阶段采用有限元方法,通过频响分析(模态叠加法、直接法)工况计算结…

barker-OFDM模糊函数原理及仿真

文章目录 前言一、巴克码序列二、barker-OFDM 信号1、OFDM 信号表达式2、模糊函数表达式 三、MATLAB 仿真1、MATLAB 核心源码2、仿真结果①、barker-OFDM 模糊函数②、barker-OFDM 距离分辨率③、barker-OFDM 速度分辨率④、barker-OFDM 等高线图 四、资源自取 前言 本文进行 …

3.RV1126-OPENCV 图像叠加

一.功能介绍 图像叠加:就是在一张图片上放上自己想要的图片,如LOGO,时间等。有点像之前提到的OSD原理一样。例如:下图一张图片,在左上角增加其他图片。 二.OPENCV中图像叠加常用的API 1. copyTo方法进行图像叠加 原理…

使用 HTML + JavaScript 实现一个日历任务管理系统

在现代快节奏的生活中,有效的时间管理变得越来越重要。本项目是一个基于 HTML 和 JavaScript 开发的日历任务管理系统,旨在为用户提供一个直观、便捷的时间管理工具。系统不仅能够清晰地展示当月日期,还支持事件的添加、编辑和删除操作&#…

车载诊断架构SOVD --- 车辆发现与建连

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 钝感力的“钝”,不是木讷、迟钝,而是直面困境的韧劲和耐力,是面对外界噪音的通透淡然。 生活中有两种人,一种人格外在意别人的眼光;另一种人无论…

Notepad++找回自动暂存的文件

场景: 当你没有保存就退出Notepad,下次进来Notepad会自动把你上次编辑的内容显示出来,以便你继续编辑。除非你手动关掉当前页面,这样Notepad就会删除掉自动保存的内容。 问题: Notepad会将自动保存的文件地址,打开Note…

DL00924-基于深度学习YOLOv11的工程车辆目标检测含数据集

文末有代码完整出处 🚗 基于深度学习YOLOv11的工程车辆目标检测——引领智能识别新潮流! 🚀 随着人工智能技术的飞速发展, 目标检测 已经在各个领域取得了显著突破,尤其是在 工程车辆识别 这一关键技术上。今天&#…

Axure RP11安装、激活、汉化

一:注册码 Axure RP11.0.0.4122在2025-5-29日亲测有效: 49bb9513c40444b9bcc3ce49a7a022f9

自编码器Auto-encoder(李宏毅)

目录 编码器的概念: 为什么需要编码器? 编码器什么原理? 去噪自编码器: 自编码器的应用: 特征解耦 离散隐表征 编码器的概念: 重构:输入一张图片,通过编码器转化成向量,要求再…

数据结构之堆(topk问题、堆排序)

一、堆的初步认识 堆虽然是用数组存储数据的数据结构,但是它的底层却是另一种表现形式。 堆分为大堆和小堆,大堆是所有父亲大于孩子,小堆是所有孩子大于父亲。 通过分析我们能得出父子关系的计算公式,parent(child-1)/2&#xff…

SpringBoot使用ffmpeg实现视频压缩

ffmpeg简介 FFmpeg 是一个开源的跨平台多媒体处理工具集,用于录制、转换、编辑和流式传输音频和视频。它功能强大,支持几乎所有常见的音视频格式,是多媒体处理领域的核心工具之一。 官方文档:https://ffmpeg.org/documentation.h…

2025-05-31 Python深度学习9——网络模型的加载与保存

文章目录 1 使用现有网络2 修改网络结构2.1 添加新层2.2 替换现有层 3 保存网络模型3.1 完整保存3.2 参数保存(推荐) 4 加载网络模型4.1 加载完整模型文件4.2 加载参数文件 5 Checkpoint5.1 保存 Checkpoint5.2 加载 Checkpoint 本文环境: Py…

长安链起链调用合约时docker ps没有容器的原因

在调用这个命令的时候,发现并没有出现官方预期的合约容器,这是因为我们在起链的时候没有选择用docker的虚拟环境,实际上这不影响后续的调用,如果想要达到官方的效果那么你只需要在起链的时候输入yes即可,如图三所示

Appium+python自动化(七)- 认识Appium- 上

简介 经过前边的各项准备工作,终于才把appium搞定。 一、appium自我介绍 appium是一款开源的自动化测试工具,可以支持iOS和安卓平台上的原生的,基于移动浏览器的,混合的应用(APP)。 1、 使用appium进…