KNN-近邻算法 及 模型的选择与调优(facebook签到地点预测)

news2025/7/17 20:24:04

什么是K-近邻算法(K Nearest Neighbors)

1、K-近邻算法(KNN)

1.1 定义

如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。

来源:KNN算法最早是由Cover和Hart提出的一种分类算法

1.2 距离公式

两个样本的距离可以通过如下公式计算,又叫欧式距离

距离公式:
在这里插入图片描述
KNN核心思想:
你的“邻居”来推断出你的类别

  • 计算距离:
    • 距离公式
      • 欧氏距离
      • 曼哈顿距离 绝对值距离
      • 明可夫斯基距离

2、电影类型分析

假设我们有现在几部电影
在这里插入图片描述
其中? 7号电影不知道类别,如何去预测?我们可以利用K近邻算法的思想
下方是根据欧氏距离计算结果:
在这里插入图片描述
K值是重要影响元素:
当我们看,如果k=1,假如只有第二行一个”邻居“,那么就计算的距离误差就比较大,样本量过少:
在这里插入图片描述
如果k=7,也就是再多一行数据,假设是“封神榜”,那么也就是说邻近的算法中,动作片的类占多数,那么我们就会将位置的那行数据预测为动作片,但是实际位置那行数据的接吻镜头是比较多的,应该是个爱情片,所以预测也是错误的!

在这里插入图片描述

  • 总结:
    • k 值取得过小,容易受到异常点的影响
    • k 值取得过大,样本不均衡的影响

3、K-近邻算法API

  • sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm=‘auto’)
    • n_neighbors:int,可选(默认= 5),k_neighbors查询默认使用的邻居数
    • algorithm:{‘auto’,‘ball_tree’,‘kd_tree’,‘brute’},可选用于计算最近邻居的算法:‘ball_tree’将会使用 BallTree,‘kd_tree’将使用 KDTree。‘auto’将尝试根据传递给fit方法的值来决定最合适的算法。 (不同实现方式影响效率)

4、代码:鸢尾花案例

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

"""
用KNN算法对鸢尾花进行分类
:return:
"""
# 1)获取数据
iris = load_iris()
# print(iris)
# print(iris.target_names)
# print(iris.DESCR)

# 2)划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=22)

# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

# 4)KNN算法预估器
estimator = KNeighborsClassifier(n_neighbors=3, algorithm='auto')
estimator.fit(x_train, y_train)

# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)

# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)

在这里插入图片描述

那么说来说去,“邻居的数量” K 到底怎么取呢?
这就涉及到模型的选择与调优了;

5、为什么需要交叉验证

  • 交叉验证目的:为了让被评估的模型更加准确可信

6、什么是交叉验证(cross validation)

  • 交叉验证:将拿到的训练数据,分为训练和验证集。以下图为例:将数据分成4份,其中一份作为验证集。然后经过4次(组)的测试,每次都更换不同的验证集。即得到4组模型的结果,取平均值作为最终结果。又称4折交叉验证。
    • 训练集:训练集+验证集
    • 测试集:测试集
      在这里插入图片描述
      问题:那么这个只是对于参数得出更好的结果,那么怎么选择或者调优参数呢?

7、超参数搜索-网格搜索(Grid Search)

通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,网格搜索帮我们实现了这个调参过程,首先需要对模型预设几种超参数组合,每组超参数都采用交叉验证来进行评估,最后选出最优参数组合建立模型。
在这里插入图片描述

7.1、模型选择与调优 API

  • sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)
    • 对估计器的指定参数值进行详尽搜索
    • estimator:估计器对象
    • param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}
    • cv:指定几折交叉验证
    • fit:输入训练数据
    • score:准确率
  • 结果分析:
    • bestscore:在交叉验证中验证的最好结果_
    • bestestimator:最好的参数模型
    • cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果

7.1、网格搜索与交叉验证代码

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler


"""
用KNN算法对鸢尾花进行分类,添加网格搜索和交叉验证
:return:
"""
# 1)获取数据
iris = load_iris()

# 2)划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=22)

# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

# 4)KNN算法预估器
estimator = KNeighborsClassifier()

# 加入网格搜索与交叉验证
# 参数准备
param_dict = {"n_neighbors": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)
estimator.fit(x_train, y_train)

# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)

# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)

# 最佳参数:best_params_
print("最佳参数:\n", estimator.best_params_)
# 最佳结果:best_score_
print("最佳结果:\n", estimator.best_score_)
# 最佳估计器:best_estimator_
print("最佳估计器:\n", estimator.best_estimator_)
# 交叉验证结果:cv_results_
print("交叉验证结果:\n", estimator.cv_results_)

在这里插入图片描述

8、facebook 签到位置预测

在这里插入图片描述
在这里插入图片描述

  • 数据介绍:将根据用户的位置,准确性和时间戳预测用户正在查看的业务。
  • train.csv
    • row_id:登记事件的ID
    • xy:坐标
    • 准确性:定位准确性
    • 时间:时间戳
    • place_id:业务的ID,这是您预测的目标

官网:https://www.kaggle.com/navoshta/grid-knn/data

8.1、流程分析

对于数据做一些基本处理(这里所做的一些处理不一定达到很好的效果,我们只是简单尝试,有些特征我们可以根据一些特征选择的方式去做处理)

1、缩小数据集范围 DataFrame.query()(选择性处理!)
2、删除没用的日期数据 DataFrame.drop(可以选择保留)
3、将签到位置少于n个用户的删除

place_count = data.groupby('place_id').count()
tf = place_count[place_count.row_id > 3].reset_index()
data = data[data['place_id'].isin(tf.place_id)]

4、分割数据集
5、标准化处理
6、k-近邻预测

8.2、代码

import pandas as pd
# 1、获取数据
data = pd.read_csv("train.csv")
data.head()

在这里插入图片描述

# 1)处理时间特征
time_value = pd.to_datetime(data["time"], unit="s")
date = pd.DatetimeIndex(time_value)
data["day"] = date.day
data["weekday"] = date.weekday
data["hour"] = date.hour
data.head()

在这里插入图片描述

# 2)过滤签到次数少的地点
place_count = data.groupby("place_id").count()["row_id"]
data_final = data[data["place_id"].isin(place_count[place_count > 3].index.values)]
data_final.head()

在这里插入图片描述

# 筛选特征值和目标值
x = data_final[["x", "y", "accuracy", "day", "weekday", "hour"]]
y = data_final["place_id"]

在这里插入图片描述

# 数据集划分
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y)
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV

# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

# 4)KNN算法预估器
estimator = KNeighborsClassifier()

# 加入网格搜索与交叉验证
# 参数准备
param_dict = {"n_neighbors": [3, 5, 7, 9]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=3)
estimator.fit(x_train, y_train)

# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)

# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)

# 最佳参数:best_params_
print("最佳参数:\n", estimator.best_params_)
# 最佳结果:best_score_
print("最佳结果:\n", estimator.best_score_)
# 最佳估计器:best_estimator_
print("最佳估计器:\n", estimator.best_estimator_)
# 交叉验证结果:cv_results_
print("交叉验证结果:\n", estimator.cv_results_)

这个结果数据量比较大,毕竟两千万训练数据了,各位可自行试验及调参;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1102840.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【网络协议】聊聊从物理层到MAC层 ARP 交换机

物理层 物理层其实就是电脑、交换器、路由器、光纤等。组成一个局域网的方式可以使用集线器。可以将多台电脑连接起来,然后进行将数据转发给别的端口。 数据链路层 Hub其实就是广播模式,如果A电脑发出一个包,B、C电脑也可以收到。那么数据…

zk的二阶段提交图解

第一阶段:每次的数据写入事件作为提案广播给所有Follower结点;可以写入的结点返回确认信息ACK;第二阶段:Leader收到一半以上的ACK信息后确认写入可以生效,向所有结点广播COMMIT将提案生效。

Unity 实现一个FPS游戏的全过程

Unity是一款功能强大的游戏引擎,它提供了各种各样的工具和功能,以帮助开发者轻松地创建精美的3D游戏和应用程序。在本文中,我们将使用Unity实现一个FPS游戏的全过程,从场景设计、角色控制、敌人AI到最终的打包发布。 对啦&#x…

开源项目汇总

element-plus 人人开源 人人开源 多租户 若依 jeecg https://gitee.com/jeecg/jeecg?_fromgitee_search#https://gitee.com/link?targethttp%3A%2F%2Fidoc.jeecg.com jeeplus JeePlus快速开发平台 j2eefast Sa-Plus

地震勘探原理部分问题解答

1、二维/三维(陆地/海洋)地震勘探,炮点(激发点)和检波点(接收点)的排布位置如何?画图作答? (1)陆地地震勘探 二维陆地地震野外采集:震…

过滤器(Filter)和拦截器(Interceptor)有什么不同?

过滤器(Filter)和拦截器(Interceptor)是用于处理请求和响应的中间件组件,但它们在实现方式和应用场景上有一些不同。 实现方式: 过滤器是Servlet规范中定义的一种组件,通常以Java类的形式实现。过滤器通过在…

Python文件共享+cpolar内网穿透:轻松实现公网访问

文章目录 1.前言2.本地文件服务器搭建2.1.Python的安装和设置2.2.cpolar的安装和注册 3.本地文件服务器的发布3.1.Cpolar云端设置3.2.Cpolar本地设置 4.公网访问测试5.结语 1.前言 数据共享作为和连接作为互联网的基础应用,不仅在商业和办公场景有广泛的应用&#…

yolov7改进优化之蒸馏(一)

最近比较忙,有一段时间没更新了,最近yolov7用的比较多,总结一下。上一篇yolov5及yolov7实战之剪枝_CodingInCV的博客-CSDN博客 我们讲了通过剪枝来裁剪我们的模型,达到在精度损失不大的情况下,提高模型速度的目的。上一…

6.6 图的应用

思维导图: 6.6.1 最小生成树 ### 6.6 图的应用 #### 主旨:图的概念可应用于现实生活中的许多问题,如网络构建、路径查询、任务排序等。 --- #### 6.6.1 最小生成树 **概念**:要在n个城市中建立通信联络网,则最少需…

【Mysql】Innodb数据结构(四)

概述 MySQL 服务器上负责对表中数据的读取和写入工作的部分是存储引擎 ,而服务器又支持不同类型的存储引擎,比如 InnoDB 、MyISAM 、Memory 等,不同的存储引擎一般是由不同的人为实现不同的特性而开发的,真实数据在不同存储引擎中…

如何让大模型自由使用外部知识与工具

本文将分享为什么以及如何使用外部的知识和工具来增强视觉或者语言模型。 全文目录: 1. 背景介绍 OREO-LM: 用知识图谱推理来增强语言模型 REVEAL: 用多个知识库检索来预训练视觉语言模型 AVIS: 让大模型用动态树决策来调用工具 技术交流群 建了技术交流群&a…

微信小程序会议OA系统

Flex弹性布局 Flex弹性布局是一种 CSS3 的布局模式,也叫Flexbox。它可以让容器中的元素按一定比例自动分配空间,使得它们在不同宽度、高度等情况下仍能保持整齐和密集不间隙地排列。 在使用Flexbox弹性布局时,首先需要创建一个容器和若干个…

JNDI-Injection-Exploit工具安装

从github上下载安装 git clone https://github.com/welk1n/JNDI-Injection-Exploit.git 打开 cd JNDI-Injection-Exploit 编译安装,Maven入门百科_maven中quickstart是什么意思-CSDN博客 mvn clean package -DskipTests 因为提示mvn错误,解决下…

Spring中Setter注入详解

目录 一、setter注入是什么 二、setter注入详解 三、JDK内置类型注入方式 3.1 数组类型 3.2 set集合类型 3.3 list集合 3.4 map集合 3.5 properties类型 四、用户自定义类型 一、setter注入是什么 书接上回,我们发现在Spring配置文件中为类对象的属性赋值时&#x…

java SpringBoot基础

目录 SpringBootWeb快速入门前言需求开发步骤创建SpringBoot工程(需要联网)定义请求处理类运行测试 HTTP协议HTTP概述HTTP-请求协议格式GET方式的请求协议POST方式的请求协议 HTTP-响应协议格式HTTP-协议解析 WEB服务器-Tomcat简介基本使用注意事项 Spri…

智慧渔业方案:AI渔政视频智能监管平台助力水域禁渔执法

一、方案背景 国内有很多水库及河流设立了禁渔期,加强渔政执法监管对保障国家渔业权益、维护渔业生产秩序、保护渔民群众生命财产安全、推进水域生态文明建设具有重要意义。目前,部分地区的监管手段信息化水平低下,存在人员少、职责多、任务…

排序【七大排序】

文章目录 1. 排序的概念及引用1.1 排序的概念1.2 常见的排序算法 2. 常见排序算法的实现2.1 插入排序2.1.1基本思想:2.1.2 直接插入排序2.1.3 希尔排序( 缩小增量排序 ) 2.2 选择排序2.2.1基本思想:2.2.2 直接选择排序:2.2.3 堆排序 2.3 交换排序2.3.1冒…

新一代开源语音库CoQui TTS冲到了GitHub 20.5k Star

Coqui TTS 项目介绍 Coqui 文本转语音(Text-to-Speech,TTS)是新一代基于深度学习的低资源零样本文本转语音模型,具有合成多种语言语音的能力。该模型能够利用共同学习技术,从各语言的训练资料集转换知识,来…

Leetcode刷题详解——将x减到0的最小操作数

1. 题目链接:1658. 将 x 减到 0 的最小操作数 2. 题目描述: 给你一个整数数组 nums 和一个整数 x 。每一次操作时,你应当移除数组 nums 最左边或最右边的元素,然后从 x 中减去该元素的值。请注意,需要 修改 数组以供接下来的操作…