文章目录
- 一、什么是K近邻算法
- 二、KNN算法流程总结
- 三、Scikit-learn工具
- 1、安装
- 2、导入
- 3、简单使用
- 三、距离度量
- 1、欧式距离
- 2、曼哈顿距离
- 3、切比雪夫距离
- 4、闵可夫斯基距离
- 5、K值的选择
- 6、KD树
一、什么是K近邻算法
如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。
二、KNN算法流程总结
- 1、计算已知类别数据集中的点与当前点之间的距离
- 2、按距离递增排序
- 3、选取与当前距离最小的k个点
- 4、统计前k个点所在的类别出现的频率
- 5、返回前k个点出现频率最高的类别作为当前点的预测分类
三、Scikit-learn工具
1、安装
pip3 install scikit-learn
2、导入
import sklearn
3、简单使用
三、距离度量
1、欧式距离
欧式距离是最容易直观理解的距离度量方法,我们小学、初中、高中接触到的两个点在空间中的距离一般都是值欧式距离。
2、曼哈顿距离
3、切比雪夫距离
4、闵可夫斯基距离
5、K值的选择
6、KD树
import numpy as np
# 自己实现kd树
# 一、构建kd树
# 1.确定根据哪一个维度进行划分,求方差,方差越大,数据越分散
# 2.以哪个点为切面,求中位数,离中位数越近的点作为根节点
# 3.比中位数的该维度小的放左边,大的放右边
# 4.重复以上步骤,所有的点就都在树中了
class KdNode(object):
def __init__(self, node_data, split_index, left, right):
self.node_data = np.array(node_data) # 节点的数据
self.split_index = split_index # 分割的维度的序号
self.left = left # 左节点
self.right = right # 右节点
class KdTree(object):
split_index_list = np.array([])
data = np.array([])
rootNode = None
def __init__(self, data):
self.k = len(data[0]) # 获取数据的维度
self.data = np.array(data) # 所有的数据
# 获取分割的维度顺序数组
self.getSplitIndexList()
# 构建树
self.rootNode = self.createNode(0, self.data)
def getSplitIndexList(self):
# 获取方差排序后的下标的数组,最后[::-1来反转]
self.split_index_list = np.argsort([np.var(self.data[:, (i)]) for i in range(self.k)])[::-1]
def closest_to_median_index(self, array):
median = np.median(array)
diff = np.abs(array - median)
return diff.argmin() # 返回第一个最小差值的索引
def createNode(self, index, dataList):
if len(dataList) == 0:
return None
split_index = self.split_index_list[index]
split_next = (index + 1) % self.k
# 获取分割维度的中位数下标
data_index = self.closest_to_median_index(dataList[:,(split_index)])
# 获取该位置的数据
rootData = dataList[data_index]
# 删除找到的这个节点
dataList = np.delete(dataList, data_index, 0)
# 获取左侧的所有数据
leftData = dataList[dataList[:,(split_index)] <= rootData[split_index]]
# 获取右侧所有的数据
rightData = dataList[dataList[:,(split_index)] > rootData[split_index]]
return KdNode(rootData, split_index, self.createNode(split_next, leftData), self.createNode(split_next, rightData))