从0实现线性回归

news2025/7/11 14:58:42

编码题:

按要求完成下面的内容

1请用python完成从0实现线性回归,尝试使用不同的训练参数(学习率,迭代次数), 以及不同的评价方法(MSE,MAE,RMSE,R2)等。

2比较说明sklearn的线性模型和自己实现的线性模型(通过上述代码实现以及训练过程, 比较不同超参数以及评价方法的影响)

第一问:

首先从0开始实现线性回归

1)创建数据

        直接调用load_boston 加载波士顿房价数据集:

2)5种评价指标:

3)自定义模型:

 4)定义损失函数:

5)修改参数进行比较结果如下:

 

6)损失函图:

第二问:在第一问的基础上直接调用sklearn 里面的 LinearRegression()模型即可, 然后在进行比较,改变参数进行比较

结果如下:

根据结果比较得到结论:

  1. b(前进的距离)小一点的话,每次改变的变化就小一点,精准度会比较高一点,适合训练次数比较多的模型,反之则反之
  2. Alpha (学习率)这个值笼统来讲及时比较大的时候学习效果就比较明显,对于b的调整也会比较明显从这里很容易看出:

代码部分:

from sklearn.datasets import load_boston

from sklearn.linear_model import LinearRegression

import numpy as np

import matplotlib.pyplot as plt

import warnings



warnings.filterwarnings("ignore")  # 忽略警告运行

# 数据的加载

def feature_scalling(X):

    mean = X.mean(axis=0)

    std = X.std(axis=0)

    return (X - mean) / std

def load_data():

    data = load_boston()  # 注意波士顿房价这个数据将会在skleran 1.2这个版本移除

    X = data.data

    y = data.target.reshape(-1, 1)

    X = feature_scalling(X)

    return X, y

# 评价指标1 ——  均方误差(MSE)

def MSE(y, y_pre):

    return np.mean((y - y_pre) ** 2)

# 评价指标2 ——  均方根误差(RMSE)

def RMSE(y, y_pre):

    return np.sqrt(MSE(y, y_pre))

# 评价指标3 —— 平均绝对误差(MAE)

def MAE(y, y_pre):

    return np.mean(np.abs(y-y_pre))

# 评价指标4 —— 平均绝对百分比误差(MAPE)

def MAPE(y, y_pre):

    return np.mean(np.abs(y-y_pre)/y)

# 评价指标4 —— R^2 评价指标

def R2(y, y_pre):

    u = np.sum((y-y_pre)**2)

    v = np.sum((y-np.mean(y_pre))**2)

    return 1-(u/v)

# 预测函数

def prediction(X, W, bias):

    return np.matmul(X, W) + bias

# 损失值函数

def cost_function(X, y, W, bias):

    m, n = X.shape

    y_hat = prediction(X, W, bias)

    return 0.5 * (1 / m) * np.sum((y - y_hat) ** 2)

# 自定义模型调整值

def gradient_descent(X, y, W, bias, alpha):

    m, n = X.shape  # m个数据

    y_hat = prediction(X, W, bias)

    grad_w = -(1 / m) * np.matmul(X.T, (y - y_hat))

    grad_b = -(1 / m) * np.sum(y - y_hat)  # 求解梯度

    W = W - alpha * grad_w  # 梯度下降

    bias = bias - alpha * grad_b  # 调整前进的距离

    return W, bias

# 自定义模型 1

'''

 b = 0.1 #前进的距离

    alpha = 0.2 #学习率

'''

def train_by_my1(X, y, ite=200):

    m, n = X.shape  # 506,13

    W = np.random.randn(n, 1)

    b = 0.1  # 前进的距离

    alpha = 0.2  # 学习率

    costs = []  # 每一次的损失函数

    for i in range(ite):  # 训练 ite = 200 轮

        J = cost_function(X, y, W, b)  # 计算损失值

        costs.append(J)

        W, b = gradient_descent(X, y, W, b, alpha)

    y_pre = prediction(X, W, b)

    print("----------my_train1 训练模型--------------")

    print("my_train1 MSE评价指标: ", MSE(y, y_pre))

    print("my_train1 R^2评价指标: ", R2(y, y_pre))

    print("my_train1 MAPE评价指标: ", MAPE(y, y_pre))

    print("my_trian1 RMSE评价指标: ", RMSE(y, y_pre))

    print("my_train1 MAE评价指标; ", MAE(y, y_pre))

    return costs

# 自定义模型超参数修改2

'''

 b = 0.2 #前进的距离

    alpha = 0.5 #学习率

'''

def train_by_my2(X, y, ite=200):

    m, n = X.shape  # 506,13

    W = np.random.randn(n, 1)

    b = 0.05  # 前进的距离

    alpha = 0.1  # 学习率

    costs = []  # 每一次的损失函数

    for i in range(ite):  # 训练 ite = 200 轮

        J = cost_function(X, y, W, b)  # 计算损失值

        costs.append(J)

        W, b = gradient_descent(X, y, W, b, alpha)

    y_pre = prediction(X, W, b)

    print("----------my_train2 训练模型--------------")

    print("my_train2 MSE评价指标: ", MSE(y, y_pre))

    print("my_train2 R^2评价指标: ", R2(y, y_pre))

    print("my_train2 MAPE评价指标: ", MAPE(y, y_pre))

    print("my_trian2 RMSE评价指标: ", RMSE(y, y_pre))

    print("my_train2 MAE评价指标; ", MAE(y, y_pre))

    return costs

# skleran 训练模型

def train_by_sklearn(x, y):

    model = LinearRegression()  # 创建模型

    model.fit(x, y)  # fit训练模型

    y_pre = model.predict(x)  # 根据模型预测

    print("----------skleran 训练模型--------------")

    print("sklearn_train MSE评价指标: ", MSE(y, y_pre))

    print("sklearn_train R^2评价指标: ", R2(y, y_pre))

    print("sklearn_train MAPE评价指标: ", MAPE(y, y_pre))

    print("sklearn_train RMSE评价指标: ", RMSE(y, y_pre))

    print("sklearn_train MAE评价指标; ", MAE(y, y_pre))

if __name__ == '__main__':

    x, y = load_data()

#     train_by_sklearn(x, y)

    costs = train_by_my1(x, y)

    train_by_my2(x, y)

    plt.plot(range(len(costs)), costs, label='损失值', c='black')

    plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体

    plt.legend(fontsize=15)

    plt.xlabel('迭代次数', fontsize=15)

    plt.tight_layout()  # 调整子图间距

    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/5609.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

断言(assert)的用法

参考:https://www.runoob.com/w3cnote/c-assert.html 目录作用总结与注意事项Demo作用 assert 是个宏,并且作用并非"报错"。 assert() 的用法像是一种"契约式编程",程序满足我的假设条件,才能正常良好的运作…

做视频素材资源(free视频,音频,图片)

素材资源 一、视频 Videezy :https://www.videezy.com/ Videovo:https://www.videvo.net/ mixkit:https://mixkit.co/,可以 distill:https://wedistill.io/ splitshire:https://www.splitshire.com/ pixa…

Mysql常见指令以及用法(保姆级)

文章目录基础篇通用语法及分类DDL(数据定义语言)数据库操作注意事项表操作DML(数据操作语言)添加数据注意事项更新和删除数据DQL(数据查询语言)基础查询条件查询聚合查询(聚合函数)分…

前端性能-首次加载优化70%

前言 本篇文章,我们来总结归纳下万恶的this以及衍生出来的call/apply/bind对this进行绑定,想了很久,决定用实例演示的方式来讲解this,这样才能够理解this,因为this确实变化莫测,只靠概念,是不能…

【JS 构造|原型|原型链|继承(圣杯模式)|ES6类语法】下篇

⌚️⌚️⌚️个人格言:时间是亳不留情的,它真使人在自己制造的镜子里照见自己的真相! 📖Git专栏:📑Git篇🔥🔥🔥 📖JavaScript专栏:📑js实用技巧篇…

【数据结构】带头双向循环链表基本操作的实现(C语言)

🚀 作者简介:一名在后端领域学习,并渴望能够学有所成的追梦人。 🐌 个人主页:蜗牛牛啊 🔥 系列专栏:🛹初出茅庐C语言、🛴数据结构 📕 学习格言:博…

峰会实录 | StarRocks PMC Chair 赵纯:数据分析的极速统一3.0 时代

作者:StarRocks PMC Chair 赵纯(本文为作者在 StarRocks Summit Asia 2022 上的分享) 一年前,StarRocks 源码开放,StarRocks 社区也正式成立。经过一年发展,社区已经获得了 3400 个 Star,7500 …

如何用windows上架ios到苹果商城

1.苹果账号 1.你需要申请苹果账号 官网有提示:Sign In - Apple 2.登录 登录后店家 account,进入account。 点击证书,进入。 2.开始上架步骤 1.注册标识符(Bundle ID) 进入这个界面后,点击 Identifiers …

Elasticsearch快照备份

目录 1、Repositories 1、配置路径 2、注册快照存储库 2、查看注册的库 3、创建快照 1、为全部索引创建快照 2、为指定索引创建快照 4、查看备份完成的列表 5、删除快照 6、从快照恢复 1、恢复指定索引 2、恢复所有索引(除.开头的系统索引) …

【Redis】 数据结构:Redis对象与编码(底层结构)对应关系详解

【Redis】 数据结构:Redis对象与编码(底层结构)对应关系详解 文章目录【Redis】 数据结构:Redis对象与编码(底层结构)对应关系详解Redis对象与编码(底层结构)对应关系引入Redis数据结构-RedisObjectredisObject数据结构Redis的编码方式五种数据结构Redis…

2022年深信服杯四川省大学生信息安全技术大赛-CTF-Reverse复现(部分)

Rush B 开始先设置一下数字以16进制格式显示 看主函数 __int64 __fastcall main(int a1, char **a2, char **a3) {int v3; // eaxsize_t v4; // raxint v5; // ecxchar v6; // alint v7; // ecxint v9; // [rsp3Ch] [rbp-404h]char s[1000]; // [rsp40h] [rbp-400h] BYREFchar …

免杀技术(详细)

恶意软件 ● 病毒、木马、蠕虫、键盘记录、僵尸程序、流氓软件、勒索软件、广告程序 ● 在用户非资源的情况下执行安装 ● 出于某种恶意的目的:控制、窃取、勒索、偷窥、推送、攻击。。。。。 恶意程序最重要的防护手段 ● 杀毒软件 / 防病毒软件 ● 客户端 / 服…

c# .net MAUI基础篇 环境安装、新建项目、安卓模拟器安装、项目运行

c# .net MAUI基础篇 环境安装、新建项目、安卓模拟器安装、项目运行 免费教学视频地址由趣编程ACE老师提供: 1..NET MAUI优势及安装和创建_哔哩哔哩_bilibili 一、介绍 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架,用于使用 C# 和 XAML 创建本机移…

【面经】之小鼠喝药问题

题目 现在有 10 只小白鼠和 1000 支药水,1000 支药水中有且仅有一支药水有毒,如果小白鼠喝下毒药,那么毒发的时间是两小时。 现在只给你两小时的时间,请问如何用这 10 只小白鼠测出哪支药水有毒?(忽略小白…

【Java编程进阶】标识符和关键字

在学习Java程序设计基础的时候,主要有标识符,变量,数据类型,流程控制这些主要的内容。 推荐学习专栏:Java 编程进阶之路【从入门到精通】 文章目录1. 标识符2. 关键字1. 标识符 什么是标识符? 标识符是用…

linux下的PPPOE设置

1.打开终端 #sudo pppoeconf 进入配置,输入用户名和密码. 2.建立连接 #sudo pon dsl-provider 3.断开连接 #sudo poff dsl-provider Welcome to the ADSL client setup. First, I will run some checks on your system to make sure the PPPoE client is installed properly.…

The 2022 CCPC Guangzhou Onsite M. XOR Sum(数位dp 数位背包)

题目 给定n,m,k(0<n<1e15,0<m<1e12,1<k<18)&#xff0c; 求长度为k的数组a&#xff0c;ai为[0,m]的整数&#xff0c; 满足的方案数 答案对1e97取模 题解 第一反应想起了hdu3693&#xff0c;但比对了一下&#xff0c;感觉那个题难很多&#xff0c; 两年…

一看就会的Java方法

文章目录一、方法的定义和使用&#x1f351;1、为什么引入方法&#xff1f;&#x1f351;2、方法的定义&#x1f351;3、方法调用的执行过程&#x1f351;4、实参和形参的关系二、方法重载&#x1f351;1、为什么需要方法重载&#x1f351;2、方法重载的概念和特点&#x1f351…

四旋翼无人机学习第8节--OpenMV电路分析

这里写目录标题0 前言1 openmv优秀作品介绍2 stm32单片机原理图绘制3 stm32单片机外接电容分析3 stm32单片机外接电容绘制4 stm32单片机外接晶振分析5 stm32单片机外接晶振绘制6 stm32单片机复位电路分析7 stm32单片机复位电路设计0 前言 简单的说一下&#xff0c;openmv模块是…

微信小程序 | 吐血整理的日历及日程时间管理

&#x1f4cc;个人主页&#xff1a;个人主页 ​&#x1f9c0; 推荐专栏&#xff1a;小程序开发成神之路 --【这是一个为想要入门和进阶小程序开发专门开启的精品专栏&#xff01;从个人到商业的全套开发教程&#xff0c;实打实的干货分享&#xff0c;确定不来看看&#xff1f; …