【0基础学机器学习】2.决策树
决策树模型笔记1. 基础知识基本模型形式决策树是一种常见的监督学习模型既可以做分类也可以做回归。它通过一系列“如果…那么…”的规则不断划分特征空间最终在叶子节点给出预测结果。对于分类任务模型会根据样本特征逐层判断例如如果花瓣长度小于某个阈值进入左子树否则进入右子树最终到达某个叶子节点后叶子节点中占比最高的类别就是预测类别。核心目标决策树的核心目标是在每一次节点划分时找到一个最优特征和最优切分点让划分后的子节点尽可能“纯”。分类任务中常见目标包括让同一类别样本尽量落到同一个叶子节点降低节点的不确定性提升整体分类准确率损失函数决策树通常不直接写成统一的全局损失函数最小化问题而是在每个节点上贪心地选择最优划分标准。常见划分指标有基尼指数Gini Index信息熵Entropy以基尼指数为例Gini(D) 1 - Σ(p_k)^2其中p_k表示样本集合D中第k类样本所占比例。基尼指数越小说明节点越纯。参数求解决策树的参数求解过程本质上是一个递归划分过程在当前节点中遍历候选特征为每个特征尝试不同划分阈值计算划分后的不纯度下降选择收益最大的划分方式递归生成左右子树直到满足停止条件常见停止条件包括达到最大树深度节点样本数过少节点已经足够纯应用示例Python实现本项目使用scikit-learn中的DecisionTreeClassifier实现一个经典的鸢尾花三分类任务fromsklearn.treeimportDecisionTreeClassifier modelDecisionTreeClassifier(max_depth3,random_state42)model.fit(x_train,y_train)y_predmodel.predict(x_test)注意要点决策树容易过拟合需要通过max_depth、min_samples_split等参数控制复杂度决策树对特征缩放不敏感一般不强制要求标准化树结构可解释性强适合教学演示和规则分析单棵树性能通常不如集成模型但更容易理解2. 代码实践model.pymodel.py负责定义决策树模型、训练模型和预测接口。这里统一封装了build_model()创建模型train_model()拟合训练数据predict()执行预测fromsklearn.treeimportDecisionTreeClassifierdefbuild_model(criterion:strgini,max_depth:int3,random_state:int42,)-DecisionTreeClassifier:创建决策树分类模型。returnDecisionTreeClassifier(criterioncriterion,max_depthmax_depth,random_staterandom_state,)deftrain_model(x_train,y_train,criterion:strgini,max_depth:int3,random_state:int42,)-DecisionTreeClassifier:训练决策树分类模型。modelbuild_model(criterioncriterion,max_depthmax_depth,random_staterandom_state,)model.fit(x_train,y_train)returnmodeldefpredict(model:DecisionTreeClassifier,x_test):使用训练好的模型进行预测。returnmodel.predict(x_test)train.pytrain.py负责训练流程包括训练集和测试集划分调用train_model()完成训练代码中使用了stratifyy保证分类任务中训练集和测试集的类别分布更加稳定。fromsklearn.model_selectionimporttrain_test_splitfrommodelimporttrain_modeldefsplit_data(x,y,test_size:float0.2,random_state:int42,):划分训练集和测试集。returntrain_test_split(x,y,test_sizetest_size,random_staterandom_state,stratifyy,)defrun_train(x,y,test_size:float0.2,random_state:int42,criterion:strgini,max_depth:int3,):完成数据划分和模型训练。x_train,x_test,y_train,y_testsplit_data(x,y,test_sizetest_size,random_staterandom_state,)modeltrain_model(x_train,y_train,criterioncriterion,max_depthmax_depth,random_staterandom_state,)returnmodel,x_train,x_test,y_train,y_testeval.pyeval.py负责评估模型效果输出准确率accuracy混淆矩阵confusion_matrix分类报告classification_report这些指标能帮助我们同时观察总体表现和各类别的精确率、召回率、F1 值。fromsklearn.metricsimportaccuracy_score,classification_report,confusion_matrixfrommodelimportpredictdefevaluate_model(model,x_test,y_test)-dict:评估决策树分类模型效果。y_predpredict(model,x_test)return{accuracy:accuracy_score(y_test,y_pred),confusion_matrix:confusion_matrix(y_test,y_pred),classification_report:classification_report(y_test,y_pred),}dataload.pydataload.py从sklearn.datasets中加载鸢尾花数据集特征x4 个花萼/花瓣数值特征标签y3 个类别标签target_names类别名称用于可视化展示importpandasaspdfromsklearn.datasetsimportload_irisdefload_data():加载 sklearn 自带的 iris 分类数据集。datasetload_iris()xpd.DataFrame(dataset.data,columnsdataset.feature_names)ypd.Series(dataset.target,nametarget)returnx,y,dataset.target_namesrun.pyrun.py是项目入口负责串联整个流程加载数据训练模型评估模型保存可视化结果可视化部分包含决策树结构图混淆矩阵图frompathlibimportPathimportmatplotlib matplotlib.use(Agg)importmatplotlib.pyplotaspltfromsklearn.metricsimportConfusionMatrixDisplayfromsklearn.treeimportplot_treefromdataloadimportload_datafromevalimportevaluate_modelfrommodelimportpredictfromtrainimportrun_traindefsave_plots(model,x_test,y_test,class_names)-list[Path]:保存决策树结构图和混淆矩阵图。current_dirPath(__file__).resolve().parent output_dircurrent_dir/figureoutput_dir.mkdir(exist_okTrue)tree_pathoutput_dir/decision_tree_structure.pngcm_pathoutput_dir/decision_tree_confusion_matrix.pngfig,axplt.subplots(figsize(16,10))plot_tree(model,feature_nameslist(x_test.columns),class_nameslist(class_names),filledTrue,roundedTrue,axax,)fig.tight_layout()fig.savefig(tree_path,dpi150,bbox_inchestight)plt.close(fig)fig,axplt.subplots(figsize(6,5))ConfusionMatrixDisplay.from_predictions(y_test,predict(model,x_test),display_labelsclass_names,cmapBlues,axax,)fig.tight_layout()fig.savefig(cm_path,dpi150,bbox_inchestight)plt.close(fig)return[tree_path,cm_path]defmain()-None:x,y,class_namesload_data()model,x_train,x_test,y_train,y_testrun_train(x,y)metricsevaluate_model(model,x_test,y_test)plot_pathssave_plots(model,x_test,y_test,class_names)print(Decision Tree Demo)print(fTrain size:{len(x_train)}, Test size:{len(x_test)})print(fAccuracy:{metrics[accuracy]:.4f})print(Confusion Matrix:)print(metrics[confusion_matrix])print(Classification Report:)print(metrics[classification_report])print(Saved plots:)forplot_pathinplot_paths:print(plot_path)if__name____main__:main()运行结果运行python run.py后终端会输出训练集/测试集大小、准确率、混淆矩阵和分类报告。图片会保存在当前目录下的figure/文件夹中通常包括decision_tree_structure.pngdecision_tree_confusion_matrix.png如果分类结果接近满分这是因为鸢尾花数据集本身比较经典且较容易划分适合作为决策树入门 demo。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2440084.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!