头歌机器学习实验 第7次实验 局部加权线性回归

news2025/9/19 20:36:46

任务描述

本关任务:编写一个利用局部加权计算回归系数的小程序。

相关知识

为了完成本关任务,你需要掌握:1.局部加权算法的思想;2.局部加权的核心算法。

局部加权算法的思想

在局部加权算法中 ,我们给待预测点附近的每个点赋予一定的权重;然后与前面的类似,在这个子集上基于最小均方差来进行普通的回归。与kNN一样,这种算法每次预测均需要事先选取出对应的数据子集。 该算法解出回归系数w的形式如下:

,

其中w是一个矩阵,用来给每个数据点赋予权重。

局部加权的核心算法
 
  1. def lwlr(testPoint,xArr,yArr,k=1.0):
  2. xMat = np.mat(xArr); yMat = np.mat(yArr).T
  3. m = np.shape(xMat)[0]
  4. weights = np.mat(np.eye((m)))
  5. for j in range(m): #next 2 lines create weights matrix
  6. diffMat = testPoint - xMat[j,:] #difference matrix
  7. weights[j,j] = np.exp(diffMat*diffMat.T/(-2.0*k**2)) #weighted matrix
  8. xTx = xMat.T * (weights * xMat)
  9. if np.linalg.det(xTx) == 0.0:
  10. print ("This matrix is singular, cannot do inverse")
  11. return
  12. ws = xTx.I * (xMat.T * (weights * yMat)) #normal equation
  13. return testPoint * w

编程要求

根据提示,在右侧编辑器补充代码,利用局部加权计算回归系数。

测试说明

根据所学完成右侧编程题。

from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
import numpy as np

# 加载数据
def loadDataSet(fileName):
    """
    Parameters:
        fileName - 文件名
    Returns:
        xArr - x数据集
        yArr - y数据集
    """
    numFeat = len(open(fileName).readline().split('\t')) - 1
    xArr = []; yArr = []
    fr = open(fileName)
    for line in fr.readlines():
        lineArr =[]
        curLine = line.strip().split('\t')
        for i in range(numFeat):
            lineArr.append(float(curLine[i]))
        xArr.append(lineArr)
        yArr.append(float(curLine[-1]))
    return xArr, yArr

# 使用局部加权线性回归计算回归系数w
def lwlr(testPoint, xArr, yArr, k = 1.0):
    """
    Parameters:
        testPoint - 测试样本点
        xArr - x数据集
        yArr - y数据集
        k - 高斯核的k,自定义参数
    Returns:
        ws - 回归系数
    """
    xMat = np.mat(xArr); yMat = np.mat(yArr).T
    m = np.shape(xMat)[0]
    weights = np.mat(np.eye((m)))                                   #创建权重对角矩阵
    for j in range(m):                                              #遍历数据集计算每个样本的权重
        ##########
        diffMat = testPoint - xMat[j,:]   #difference matrix
        weights[j,j] = np.exp(diffMat*diffMat.T/(-2.0*k**2))   #weighted matrix
        ##########
    xTx = xMat.T * (weights * xMat)
    if np.linalg.det(xTx) == 0.0:
        print("矩阵为奇异矩阵,不能求逆")
        return
    ws = xTx.I * (xMat.T * (weights * yMat))                        #计算回归系数
    return testPoint * ws

# 局部加权线性回归测试
def lwlrTest(testArr, xArr, yArr, k=1.0):
    """
    Parameters:
        testArr - 测试数据集,测试集
        xArr - x数据集,训练集
        yArr - y数据集,训练集
        k - 高斯核的k,自定义参数
    Returns:
        ws - 回归系数
    """
    m = np.shape(testArr)[0]                                       #计算测试数据集大小
    yHat = np.zeros(m)
    for i in range(m):                                             #对每个样本点进行预测
        yHat[i] = lwlr(testArr[i],xArr,yArr,k)
    return yHat

# 计算回归系数w
def standRegres(xArr,yArr):
    """
    Parameters:
        xArr - x数据集
        yArr - y数据集
    Returns:
        ws - 回归系数
    """
    xMat = np.mat(xArr); yMat = np.mat(yArr).T
    xTx = xMat.T * xMat                                         #根据文中推导的公示计算回归系数
    if np.linalg.det(xTx) == 0.0:
        print("矩阵为奇异矩阵,不能求逆")
        return
    ws = xTx.I * (xMat.T*yMat)
    return ws


def rssError(yArr, yHatArr):
    """
    误差大小评价函数
    Parameters:
        yArr - 真实数据
        yHatArr - 预测数据
    Returns:
        误差大小
    """
    return ((yArr - yHatArr) **2).sum()


if __name__ == '__main__':
    abX, abY = loadDataSet('./机器学习第8章/abalone.txt')
    print('训练集与测试集相同:局部加权线性回归,核k的大小对预测的影响:')
    yHat01 = lwlrTest(abX[0:99], abX[0:99], abY[0:99], 0.1)
    yHat1 = lwlrTest(abX[0:99], abX[0:99], abY[0:99], 1)
    yHat10 = lwlrTest(abX[0:99], abX[0:99], abY[0:99], 10)
    print('k=0.1时,误差大小为:',rssError(abY[0:99], yHat01.T))
    print('k=1  时,误差大小为:',rssError(abY[0:99], yHat1.T))
    print('k=10 时,误差大小为:',rssError(abY[0:99], yHat10.T))
    print('')
    print('训练集与测试集不同:局部加权线性回归,核k的大小是越小越好吗?更换数据集,测试结果如下:')
    yHat01 = lwlrTest(abX[100:199], abX[0:99], abY[0:99], 0.1)
    yHat1 = lwlrTest(abX[100:199], abX[0:99], abY[0:99], 1)
    yHat10 = lwlrTest(abX[100:199], abX[0:99], abY[0:99], 10)
    print('k=0.1时,误差大小为:',rssError(abY[100:199], yHat01.T))
    print('k=1  时,误差大小为:',rssError(abY[100:199], yHat1.T))
    print('k=10 时,误差大小为:',rssError(abY[100:199], yHat10.T))
    print('')
    print('训练集与测试集不同:简单的线性归回与k=1时的局部加权线性回归对比:')
    print('k=1时,误差大小为:', rssError(abY[100:199], yHat1.T))
    ws = standRegres(abX[0:99], abY[0:99])
    yHat = np.mat(abX[100:199]) * ws
    print('简单的线性回归误差大小:', rssError(abY[100:199], yHat.T.A))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1584464.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

头歌-机器学习 第16次实验 EM算法

第1关:极大似然估计 任务描述 本关任务:根据本节课所学知识完成本关所设置的选择题。 相关知识 为了完成本关任务,你需要掌握: 什么是极大似然估计; 极大似然估计的原理; 极大似然估计的计算方法。 什么是极大似然估计 没有接触过或者没有听过”极大似然估计“的同学…

[dvwa] file upload

file upload 0x01 low 直接上传.php 内容写<? eval($_POST[jj]);?> 用antsword连 路径跳两层 0x02 medium 添加了两种验证&#xff0c;格式为图片&#xff0c;大小限制小于1000 上传 POST /learndvwa/vulnerabilities/upload/ HTTP/1.1 Host: dvt.dv Content-Le…

✌2024/4/6—力扣—最长公共前缀✌

代码实现&#xff1a; char *longestCommonPrefix(char **strs, int strsSize) {if (strsSize 0) {return "";}for (int i 0; i < strlen(strs[0]); i) { // 列for (int j 1; j < strsSize; j) { // 行if (strs[0][i] ! strs[j][i]) { // 如果比较字符串的第…

Covalent Network(CQT)推出以太坊质押迁移计划,以增强长期结构化数据可用性、塑造万亿级 LLM 参数体系

作为 Web3 领先的链上数据层&#xff0c;Covalent Network&#xff08;CQT&#xff09;宣布了其将质押操作从 Moonbeam 迁移回以太坊的决定。此举是 Covalent Network&#xff08;CQT&#xff09;走向以太坊时光机&#xff08;EWM&#xff09;的第一步&#xff0c;EWM 是一个为…

Android自定义控件ScrollView实现上下滑动功能

本文实例为大家分享了Android ScrollView实现上下滑动功能的具体代码&#xff0c;供大家参考&#xff0c;具体内容如下 package com.example.zhuang; import android.content.Context; import android.util.AttributeSet; import android.util.DisplayMetrics; import android…

Asterisk 21.2.0编译安装经常遇到的问题和解决办法之pjproject

目录 Asterisk社区官方的说法然而买家秀是这样的pjproject-2.14下载不了的问题如何解决 Asterisk社区官方的说法 编译安装Asterisk 21.2.0版本 按照官网文档&#xff0c;原则上只需要如下几步&#xff1a; ./contrib/scripts/install_prereq install ./configure make make i…

建立一个网站需要多长时间?如何从零开始制作企业网站,建站流程是怎么样的

为了维持你的品牌形象&#xff0c;你必须有一个在线的网站存在&#xff0c;但是创建一个网站需要多长时间呢&#xff1f;从开始到结束&#xff0c;你最期待什么&#xff1f; 我们将介绍网站开发过程的步骤以及每个步骤可能需要多少时间。我们还将探讨您设计和部署网站的选项&a…

手机银行客户端框架之TMF框架介绍

腾讯移动开发平台&#xff08;Tencent Mobile Framework&#xff09;整合了腾讯在移动产品中开发、测试、发布和运营的技术能力&#xff0c;为企业提供一站式、覆盖全生命周期的移动端技术平台。核心服务包括移动客户端开发组件、H5容器、灰度发布、热更新、离线包、网关服务、…

[【JSON2WEB】 13 基于REST2SQL 和 Amis 的 SQL 查询分析器

【JSON2WEB】01 WEB管理信息系统架构设计 【JSON2WEB】02 JSON2WEB初步UI设计 【JSON2WEB】03 go的模板包html/template的使用 【JSON2WEB】04 amis低代码前端框架介绍 【JSON2WEB】05 前端开发三件套 HTML CSS JavaScript 速成 【JSON2WEB】06 JSON2WEB前端框架搭建 【J…

wpf下如何实现超低延迟的RTMP或RTSP播放

技术背景 我们在做Windows平台RTMP和RTSP播放模块对接的时候&#xff0c;有开发者需要在wpf下调用&#xff0c;如果要在wpf下使用&#xff0c;只需要参考C#的对接demo即可&#xff0c;唯一不同的是&#xff0c;视频流数据显示的话&#xff0c;要么通过控件模式&#xff0c;要么…

Python球球大作战

文章目录 写在前面球球大作战程序设计注意事项写在后面 写在前面 安装pygame的命令&#xff1a; pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pygame球球大作战 《球球大作战》是一款简单易上手、充满趣味性和竞技性的休闲手游。游戏的核心玩法可以用一句话概…

李廉洋:4.11黄金原油早盘#行情走势#分析及策略。

美国通胀数据超出预期&#xff0c;抑制了对美联储降息的押注。Coex Partners有限公司宏观经济学家Henrik Gullberg表示&#xff1a;“对新兴市场和风险资产来说&#xff0c;(通胀)高企持续时间更长是个坏消息&#xff0c;还因为它增加了美国和全球经济更明显下滑的风险。CPI数据…

(一)基于IDEA的JAVA基础13

数组遍历 遍历数组就是把数组内的数据一个个的取出来 1.我们可以用for循环&#xff0c;依次把数字类的元素取出来。 2.增强型for循环。 用第一个方法写一下&#xff0c;看一下 public class Test01 { public static void main(String[] args) { //存储一组数据{…

计算机网络 Telnet远程访问交换机和Console终端连接交换机

一、实验要求和内容 1、配置交换机进入特权模式密文密码为“abcd两位班内学号”&#xff0c;远程登陆密码为“123456” 2、验证PC0通过远程登陆到交换机上&#xff0c;看是否可以进去特权模式 二、实验步骤 1、将一台还没配置的新交换机&#xff0c;利用console线连接设备的…

如何在极狐GitLab 使用Docker 仓库功能

本文作者&#xff1a;徐晓伟 GitLab 是一个全球知名的一体化 DevOps 平台&#xff0c;很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab 是 GitLab 在中国的发行版&#xff0c;专门为中国程序员服务。可以一键式部署极狐GitLab。 本文主要讲述了如何在[极狐GitLab…

Unity 获取RenderTexture像素颜色值

拿来吧你~ &#x1f9aa;功能介绍&#x1f32d;Demo &#x1f9aa;功能介绍 &#x1f4a1;不通过Texture2D 而是通过ComputerShader 提取到RenderTexture的像素值&#xff0c;效率有提升哦&#xff01; &#x1f4a1;通过扩展方法调用&#xff0c;方便快捷&#xff1a;xxxRT.G…

借助 Keras 3 轻松上手 Gemma 模型

作者 / Keras 产品经理 Martin Grner Keras 团队非常高兴地宣布&#xff0c;KerasNLP 集合现已支持 Gemma&#xff01;Gemma 是先进的轻量级开放模型系列&#xff0c;采用了与构建 Gemini 模型相同的研究和技术。借助 Keras 3&#xff0c;Gemma 可以在 JAX、PyTorch 和 TensorF…

【MySQL数据库 | 第二十四篇】Limit语句的性能问题和调优策略

前言&#xff1a; MySQL作为最流行的关系型数据库管理系统之一&#xff0c;被广泛应用于各种规模和类型的应用程序中。其强大的功能和灵活的查询语言使得开发人员能够高效地执行各种数据操作和分析。 然而&#xff0c;在处理大量数据或复杂查询时&#xff0c;一些开发人员可能…

【QT+QGIS跨平台编译】175:【QGIS_App跨平台编译】—【错误处理:未定义的class APP_EXPORT】

点击查看专栏目录 文章目录 一、未定义的class APP_EXPORT二、错误处理 一、未定义的class APP_EXPORT 报错信息&#xff1a; 二、错误处理 第18行增加&#xff1a; #include "qgis_app.h"

【MYSQL锁】透彻地理解MYSQL锁

&#x1f525;作者主页&#xff1a;小林同学的学习笔录 &#x1f525;mysql专栏&#xff1a;小林同学的专栏 目录 1.锁 1.1 概述 1.2 全局锁 1.2.1 语法 1.2.1.1 加全局锁 1.2.1.2 数据备份 1.2.1.3 释放锁 1.2.1.4 特点 1.2.1.5 演示 1.3 表级锁 1.3.1 介绍 …