【机器学习】python实现随机森林

news2025/8/17 6:21:14

目录

一、模型介绍

1. 集成学习

2. bagging

3. 随机森林算法

二、随机森林算法优缺点

三、代码实现

四、疑问

五、总结


本文使用mnist数据集,进行随机森林算法。

一、模型介绍

1. 集成学习

集成学习通过训练学习出多个估计器,当需要预测时通过结合器将多个估计器的结果整合起来当作最后的结果输出。

集成学习的优势是提升了单个估计器的通用性与鲁棒性,比单个估计器拥有更好的预测性能。集成学习的另一个特点是能方便的进行并行化操作。

2. bagging

  Bagging 算法是一种集成学习算法,其全称为自助聚集算法(Bootstrap aggregating),顾名思义算法由 Bootstrap 与 Aggregating 两部分组成。

算法的具体步骤为:假设有一个大小为 N 的训练数据集,每次从该数据集中有放回的取选出大小为 M 的子数据集,一共选 K 次,根据这 K 个子数据集,训练学习出 K 个模型。当要预测的时候,使用这 K 个模型进行预测,再通过取平均值或者多数分类的方式,得到最后的预测结果。

3. 随机森林算法

将多个决策树结合在一起,每次数据集是随机有放回的选出,同时随机选出部分特征作为输入,所以该算法被称为随机森林算法。可以看到随机森林算法是以决策树为估计器的Bagging算法。

上图展示了随机森林算法的具体流程,其中结合器在分类问题中,选择多数分类结果作为最后的结果,在回归问题中,对多个回归结果取平均值作为最后的结果。

使用Bagging算法能降低过拟合的情况,从而带来了更好的性能。单个决策树对训练集的噪声非常敏感,但通过Bagging算法降低了训练出的多颗决策树之间关联性,有效缓解了上述问题。


二、随机森林算法优缺点

1. 对于很多种资料,可以产生高准确度的分类器
2. 可以处理大量的输入变量
3. 可以在决定类别时,评估变量的重要性
4. 在建造森林时,可以在内部对于一般化后的误差产生不偏差的估计
5. 包含一个好方法可以估计丢失的资料,并且如果有很大一部分的资料丢失,仍可以维持准确度
6. 对于不平衡的分类资料集来说,可以平衡误差
7. 可被延伸应用在未标记的资料上,这类资料通常是使用非监督式聚类,也可侦测偏离者和观看资料
8. 学习过程很快速
 


三、代码实现

代码:

from sklearn.ensemble import  RandomForestClassifier  # 随机森林分类器
from sklearn.datasets import load_digits  # 数据集
from sklearn.model_selection import train_test_split  # 数据分割模块
from sklearn.metrics import classification_report  # 生产报告
from sklearn.metrics import confusion_matrix

# 1.加载数据
mnist = load_digits()

# 2.分割数据集
x_train, x_test, y_train, y_test, image_train, image_test = train_test_split(mnist.data, mnist.target, mnist.images,
                                                                             test_size=0.25, random_state=33)

# 3.训练分类器
rfc = RandomForestClassifier(n_jobs=-1)
train_history = rfc.fit(x_train, y_train)

# 4.测试
pred = rfc.predict(x_test)
report = classification_report(y_test, pred)
confusion_mat = confusion_matrix(y_test, pred)

print(report)
print(confusion_mat)

 结果:

              precision    recall  f1-score   support

           0       0.97      1.00      0.99        35
           1       0.98      1.00      0.99        54
           2       1.00      0.95      0.98        44
           3       0.98      0.89      0.93        46
           4       0.94      0.94      0.94        35
           5       0.92      0.94      0.93        48
           6       0.98      0.98      0.98        51
           7       0.92      1.00      0.96        35
           8       0.95      0.95      0.95        58
           9       0.93      0.93      0.93        44

    accuracy                           0.96       450
   macro avg       0.96      0.96      0.96       450
weighted avg       0.96      0.96      0.96       450

[[35  0  0  0  0  0  0  0  0  0]
 [ 0 54  0  0  0  0  0  0  0  0]
 [ 1  0 42  0  0  0  0  0  0  1]
 [ 0  0  0 41  0  2  0  1  1  1]
 [ 0  0  0  0 33  0  0  2  0  0]
 [ 0  0  0  0  0 45  1  0  1  1]
 [ 0  0  0  0  1  0 50  0  0  0]
 [ 0  0  0  0  0  0  0 35  0  0]
 [ 0  1  0  0  1  1  0  0 55  0]
 [ 0  0  0  1  0  1  0  0  1 41]]


四、疑问

以上代码是从书上学习的,但是有一些问题:

1. 为什么不划分验证集,结果如何以图片的形式可视化?

2. 为什么不进行数据的预处理,如下代码所示:

def get_mnist_data():

    (x_train_original, y_train_original), (x_test_original, y_test_original) = mnist.load_data()

    # 从训练集中分配验证集
    x_val = x_train_original[50000:] #(10000,28,28)每一个图片
    y_val = y_train_original[50000:] #10000,每个图片的标签
    x_train = x_train_original[:50000]# (50000,28,28)
    y_train = y_train_original[:50000]#50000

    # 将图像转换为四维矩阵(nums,rows,cols,channels), 这里把数据从unint类型转化为float32类型, 提高训练精度。
    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
    #x_train.shape[0]表示x_train的行数。28是图片自身的大小。这里与原本的LeNet-5不同,原有的输入大小是32
    x_val = x_val.reshape(x_val.shape[0], 28, 28, 1).astype('float32')
    x_test = x_test_original.reshape(x_test_original.shape[0], 28, 28, 1).astype('float32')

    # 原始图像的像素灰度值为0-255,为了提高模型的训练精度,通常将数值归一化映射到0-1。
    x_train = x_train / 255
    x_val = x_val / 255
    x_test = x_test / 255

    # 图像标签一共有10个类别即0-9,这里将其转化为独热编码(One-hot)向量
    y_train = np_utils.to_categorical(y_train)#标签都变成为二维
    y_val = np_utils.to_categorical(y_val)
    y_test = np_utils.to_categorical(y_test_original)

    return x_train, y_train, x_val, y_val, x_test, y_test

不进行归一化,不转化为独热编码向量?

3. 所用划分测试集和训练集的代码:

x_train, x_test, y_train, y_test, image_train, image_test = train_test_split(mnist.data, mnist.target, mnist.images,
                                                                             test_size=0.25, random_state=33)

经查询,只能划分测试集和训练集,那如果想要画验证集,怎么办呢?

4. 如何绘制loss曲线、accuracy曲线?按照tensprflow的结构,总是会报错,说不存在history

5. 损失函数如何定义呢?代码中似乎没有。

6. 不需要定义epochs等超参数吗?sklearn库的fit函数:

train_history = rfc.fit(x_train, y_train)

不能添加上面所说的超参数 


五、总结

总之,可能是我对sklearn库了解的不够,感觉和写cnn完全不是一个思路,还需要进一步的学习。如果该代码有进一步的后续改进,会在评论区发出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/33081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[附源码]SSM计算机毕业设计流浪动物救助网站JAVA

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

【百度AI_人脸识别】图片对比相似度、人脸对比登录(调摄像头)

人脸对比 此文档功能: 两张人脸图片相似度对比:比对两张图片中人脸的相似度,并返回相似度分值。存档一张图片与调用的摄像中的人脸进行对比。项目、资源下载:https://download.csdn.net/download/m0_70083523/87150842?spm1001.2…

编译原理—语法制导翻译、S属性、L属性、自上而下、自下而上计算

编译原理—语法制导翻译、S属性、L属性、自上而下、自下而上计算1.语法制导翻译1.1属性文法1.2算术表达式的计数器1.3属性的分类1.4属性依赖图继承属性的计算1.5语义规则的计算方法1.6属性计算次序2. S属性定义2.1 语法树与分析树2.2 语法树与DAG2.2.1构造表达式的语法树(DAG)2…

Android中常见的那些内存泄漏——【问题分析+方案】

1.静态Activity(Activity上下文Context)和View 静态变量Activity和View会导致内存泄漏,在下面代码中对Activity的Context和TextView设置为静态对象,从而产生内存泄漏; public class MemoryTestActivity extends AppCompatActivity {private…

[附源码]SSM计算机毕业设计健身健康规划系统JAVA

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

noexcept说明符/运算符

一、noexcept说明符 1、语法 (1)noexcept 与 noexcept(true) 相同 (2)noexcept(表达式) 如果 表达式 求值为 true,那么声明函数不会抛出任何异常。 (3)throw() //c1…

Ubuntu配置FTP服务

参考目录1.安装FTP服务器软件2.配置FTP服务3.Ubuntud登录ftp服务器4.windows下通过cuteFTPlianjei1.安装FTP服务器软件 (1) FTP文件传送协议(File Transfer Protocol,简称FTP),是一个用于从一台主机到另一台主机传输文件的协议。 (2)Linux下有…

Jetpack 之 LiveData 实现事件总线

事件总线相信大家很多时候都会用到,那大家常用的也就是常青树 EventBus,以及 RxJava 流行起来的后起之秀 RxBus。它们的使用方式都差不多,思想也都是基于观察者模式,正好 LiveData 的核心思想也是观察者模式,因此我们完…

做Android 开发这么久,还不明白 Android Framework 知识重要性?

Framework作为Android的框架层,为App提供了很多API调用,但很多机制都是Framework包装好后直接给App用的,如果不懂这些机制的原理,就很难在这基础上进行优化。 从做Android的第一天起,你一定听过无数次关于Framework的…

计算机音乐-乐理知识(1)

一、节拍 节拍(Beat/Meter),是一个衡量节奏的单位,在音乐中,有一定强弱分别的一系列拍子在每隔一定时间重复出现。如 2 / 4 、 4 / 4 、 3 / 4 拍等。节拍,乐曲中表示固定单位时值和强弱规律的组织形式。 …

测试员工作三年后的工资对比,没达到这个数的都属于拖后腿了

“毕业三年的薪资是职场阶段的一个分水岭。” 不知什么时候开始,这句话深刻的引入了所有打工人的心中,程序员们自然也不例外。 事实上,这句话说的并不无道理,毕业的三年,不仅是学生到职场人身份上的一个转变&#xf…

初阶数据结构学习记录——아홉 二叉树和堆(2)

接着上一篇 之前写过一些关于堆的代码,向下调整,向上调整算法,以及常用的几个函数。这一篇继续完善堆,难度也会有所上升。先来看上一篇文末提到的创建堆算法。 首先要有空间,要有数据,之后再形成堆。我们…

9.5 利用可执行内存挑战DEP

目录 一、实验环境 二、实验思路 三、实验代码 四、实验步骤 1、寻找memcpy函数的地址 2、查看内存中可读可写可执行的内存 3、修复EBP 4、保证memcpy的源地址位于shellcode之前 一、实验环境 操作系统:windows 2000 软件:原版OD、VC6.0 二、实…

删除的数据如何恢复?误删了文件怎么恢复

文件的误删除,相信大部分人都经历过。不过因为很多人删除的文件都不算是很重要,所以有与没有并没有太大的区别。但是一旦你删除的文件正是你最近急需的,删除的数据如何恢复?别着急,可以试试以下的几种方法:…

STM32串口详解

实验一:简单的利用串口接收中断回调函数实现数据的返回 关于串口调试助手,还应知道: 发送英文字符需要用一个字符即8位,发送汉字需要两个字符即16位,如上图,发送汉字“姜”实际是发送“BD AA”而发送英文字…

外卖项目06---套餐管理业务开发(移动端的后台代码编辑开发)

菜品展示、购物车、下单 目录 一、导入用户地址簿相关功能代码 90 1.1需求分析 90 1.2数据模型 90 1.3导入功能代码 90 二、菜品展示 91 2.1需求分析 91 2.2商品展示---代码开发---梳理交互过程 92 2.3菜品展示---代码开发---修改DishController的list方法并测试 93 2…

OpenGL原理与实践——核心模式(二):Shader变量、Shader类的封装以及EBO

目录 Shader内的一些关键字 向量 举例:shader之间的数据传输,并实现渐变颜色 举例:C向shader传输数据的过程 代码整理——shader类的封装 加入颜色信息 索引绘制——EBO 整体代码以及渲染结果 Shader内的一些关键字 in:上…

网站被劫持勒索怎么办

互联网出现后的几十年时间里,世界便由一张张网串联了起来,给我们的生活带来了无限的便利。但在互联网飞速发展的同时,恶意网络攻击也随之而来,近年来,互联网攻击事件频发,不法分子利用常见的DDoS攻击、CC攻…

【生成式网络】入门篇(二):GAN的 代码和结果记录

GAN非常经典,我就不介绍具体原理了,直接上代码。 感兴趣的可以阅读,里面有更多变体。 https://github.com/rasbt/deeplearning-models/tree/master/pytorch_ipynb/gan GAN 在 MINIST上的代码和效果 import os # os.chdir(os.path.dirname(_…

springBoot集成websocket实现消息实时推送提醒

在浏览某些网页的时候,例如 WebQQ、京东在线客服服务、CSDN私信消息等类似的情况下,我们可以在网页上进行在线聊天,或者即时消息的收取与回复,可见,这种功能的需求由来已久,并且应用广泛,和pc端web系统待办…