文章目录
- 什么是随机森林?
 - 随机森林的优缺点
 - 随机森林示例——鸢尾花分类
 
什么是随机森林?
随机森林算法是机器学习、计算机视觉等领域内应用极为广泛的一个算法,它不仅可以用来做分类,也可用来做回归即预测,随机森林机由多个决策树构成,相比于单个决策树算法,它分类、预测效果更好,不容易出现过度拟合的情况。
常应用于以下类型的场景:
- 预测用户贷款是否能够按时还款;
 - 预测用户是否会购买某件商品等等
 
官网:分类和回归
随机森林的优缺点
优点:
-  
可以处理高纬度的数据;
 -  
训练之前不需要特意的做特征选择;
 -  
建立很多树,预防了过拟合风险;
 
缺点:
-  
计算量相对于决策树很大,性能开销很大。
 -  
可能会导致有些数据集没有训练到,但这种几率很小。
 -  
分裂的时候,偏向于选择取值较多的特征。
 
随机森林示例——鸢尾花分类
数据集下载:
链接:
https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l 
提取码:
lz3l
 
数据集介绍:
iris.scale.txt 是 libsvm 格式的鸢尾花数据集,共有五个字段。第一个为标签字段,后四个为特征字段。
 
libsvm 格式参考:机器学习:libsvm数据格式
将数据集中的随机百分之70作为训练集,剩余的作为测试集。
使用 SparkSQL 的方式读取 libsvm 格式的文件会自动生成 label 和 features 结构的数据,如下所示:
val data: DataFrame = spark.read.format("libsvm").load("iris.scale.txt")
data.show()
 
 
需求实现:
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
object Iris {
    def main(args: Array[String]): Unit = {
        val spark: SparkSession = SparkSession.builder().appName("Iris").master("local[*]").getOrCreate()
        // 加载 libsvm 格式文件的数据
        val data: DataFrame = spark.read.format("libsvm").load("C:\\Users\\Administrator\\Desktop\\iris.scale.txt")
        data.show()
        // 1.构建标签列转换对象
        val labelIndexer: StringIndexerModel = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(data)
        // 2.构建特征列转换对象,设置特征列数量
        val featureIndexer: VectorIndexerModel = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                .setMaxCategories(4)
                .fit(data)
        // 3.将随机百分之70作为训练集,其余为测试集
        val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
        // 4.创建随机森林对象,设置标签列与特征列以及决策树的个数
        val rf: RandomForestClassifier = new RandomForestClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("indexedFeatures")
                .setNumTrees(10)
        // 5.设置预测列标签
        val labelConverter: IndexToString = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labelsArray(0))
        // 6.管道组装
        val pipeline: Pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
        // 7.模型训练
        val model: PipelineModel = pipeline.fit(trainingData)
        // 8.模型预测
        val predictions: DataFrame = model.transform(testData)
        // 9.模型评估
        predictions.select("predictedLabel", "label", "features").show()
        // 10.创建错误率的计算对象
        val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy")
        // 11.计算错误率
        val accuracy: Double = evaluator.evaluate(predictions)
        println(s"Test Error = ${(1.0 - accuracy)}")
        // 12.打印随机森林模型
        val rfModel: RandomForestClassificationModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
        println(s"Learned classification forest model:\n ${rfModel.toDebugString}")
    }
}
                


















