自动机器学习-auto-sklearn

news2025/6/8 12:35:36

1、前言

自动机器学习(AutoML) 旨在通过让一些通用步骤 (如数据预处理、模型选择和调整超参数) 自动化,来简化机器学习中生成模型的过程。AutoML是指尽量不通过人来设定超参数,而是使用某种学习机制,来调节这些超参数。这些学习机制包括传统的贝叶斯优化,多臂老虎机(multi-armed bandit),进化算法,还有比较新的强化学习。当我们提起AutoML时,我们更多地是说自动化数据准备(即数据的预处理,数据的生成和选择)和模型训练(模型选择和超参数调优)。这个过程的每一步都有非常多的选项(options),根据我们遇到的问题,需要设定各种不同的选项。

Auto-Sklearn 是一个基于 Python 的开源工具包,用于执行 AutoML,它采用著名的 Scikit-Learn 机器学习包进行数据处理和机器学习算法。

在正式介绍Auto-Sklearn前我们先要声明几个问题:
1、Auto-Sklearn不能再Windows上使用,不要试图挣扎了
2、不能使用非数值型数据,也就是还需要特征工程进行处理才行
3、不支持深度学习,后面我们会介绍Auto-PyTorch
4、运行时间比较长,时间设置短的话训练不充分

2、介绍

Auto-Sklearn 是改进了一般的 AutoML 方法,自动机器学习框架采用贝叶斯超参数优化方法,有效地发现给定数据集的性能最佳的模型管道。

这里另外添加了两个组件:

  • 一个用于初始化贝叶斯优化器的元学习(meta-learning)方法
  • 优化过程中的自动集成(automated ensemble)方法

这种元学习方法是贝叶斯优化的补充,用于优化 ML 框架。对于像整个 ML 框架一样大的超参数空间,贝叶斯优化的启动速度很慢。通过基于元学习选择若干个配置来用于种子贝叶斯优化。这种通过元学习的方法可以称为热启动优化方法。再配合多个模型的自动集成方法,使得整个机器学习流程高度自动化,将大大节省用户的时间。从这个流程来看,让机器学习使用者可以有更多的时间来选择数据以及思考要处理的问题本身。

2.1 贝叶斯优化

贝叶斯优化的原理是利用现有的样本在优化目标函数中的表现,构建一个后验模型。该后验模型上的每一个点都是一个高斯分布,即有均值和方差。若该点是已有样本点,则均值就是该点的优化目标函数取值,方差为0。而其他未知样本点的均值和方差是后验概率拟合的,不一定接近真实值。那么就用一个采集函数,不断试探这些未知样本点对应的优化目标函数值,不断更新后验概率的模型。由于采集函数可以兼顾Explore/Exploit,所以会更多地选择表现好的点和潜力大的点。因此,在资源预算耗尽时,往往能够得到不错的优化结果。即找到局部最优的优化目标函数中的参数。

auto-sklearn用的是smac(https://github.com/automl/SMAC3)算法,是贝叶斯优化算法的一种。算法在刚初始化时,的确类似随机搜索,但是随着搜索的进行,算法知道的信息越来越多,就能预知下一次搜索哪个点模型的表现会最好。
在这里插入图片描述

图中浅蓝色的表示95%置信区间的上下界,越宽表示对某个点预测的标准差越大,表示对这个点越不确定。随着搜索过的点越来越多,历史点(红叉)附近的标准差就会降低,表示对附近的点越确定。就这样会拟合出一个参数空间映射到模型表现的函数,从这个空间中找一个点,作为下次的搜索点。

2.2 从时间维度看

在这里插入图片描述
上图是在一个简单的 1D 问题上应用贝叶斯优化的实验图,这些图显示了在经过四次迭代后,高斯过程对目标函数的近似。我们以 t=3 为例分别介绍一下图中各个部分的作用。

上图 2 个 evaluations 黑点和一个红色 evaluations,是三次评估后显示替代模型的初始值估计,会影响下一个点的选择,穿过这三个点的曲线可以画出非常多条。黑色虚线曲线是实际真正的目标函数 (通常未知)。黑色实线曲线是代理模型的目标函数的均值。紫色区域是代理模型的目标函数的方差。绿色阴影部分指的是acquisition function的值,选取最大值的点作为下一个采样点。只有三个点,拟合的效果稍差,黑点越多,黑色实线和黑色虚线之间的区域就越小,误差越小,代理模型越接近真实模型的目标函数。

3、实战

3.1 分类

这里我们使用Titanic数据集来演示,因为这个数据集大家比较熟悉,所以看起来会更加简单点

3.1.1数据导入

说明:可以自动识别NAN值,这个是Auto-Sklearn比较人性化的一点;大部分sklearn原生模型都不能自动处理Nan值。

import pandas as pd
data = pd.read_csv("/datasource/DL数据集demo/titanic.csv")
# 这里对性别用One-Hot编码,其实流程跟单个算法一样
one_hot=OneHotEncoder()
data_temp = pd.DataFrame(one_hot.fit_transform(data[['Sex']]).toarray(),columns=one_hot.get_feature_names(['Sex']),dtype='int32')
data_onehot=pd.concat((data,data_temp),axis=1)    #也可以用merge,join
data_onehot.head()
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFare CabinEmbarkedSex_femaleSex_male
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
1211Cumings, Mrs. John Bradley (Florence Briggs Th…female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
3411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
4503Allen, Mr. William Henrymale35.0003734508.0500NaNS

3.1.2 数据定义

X ,y = (data_onehot[["Pclass","Age","Fare", "Sex_female","Sex_male"]],data_onehot["Survived"])
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y)

3.1.3 模型定义

AutoSklearn 类提供了大量的配置选项作为参数。
大部分参数可以默认,这里只介绍两个参数
time_left_for_this_task:任务的最长时间,以秒为单位;默认是一个小时,所以自己用来演示的话可以设置5-10分钟;
per_run_time_limit :分配给每个模型评估的时间,以秒为单位,需要依照time_left_for_this_task来进行制定,如果大于或者等于time_left_for_this_task,模型一般会提示并自动赋值。
memory_limit=None:如果内存报错的话,建议这样设置;
n_jobs=-1:这个也会报错,具体错误忘记记下来了,建议设置成-1
另外还有其他参数,如ensemble_size、initial_configurations_via_metalearning,可用于微调分类器
如果有需要可以自行配置其他参数
但是既然我们要学习自动机器学习,不建议考虑太多参数。

automl = autosklearn.classification.AutoSklearnClassifier(
       time_left_for_this_task=3*60, 
       per_run_time_limit=2*60,
        n_jobs=-1,
        memory_limit=None
    )

3.1.4 训练预测

预测时间要比一般的sklearn模型要长,一般的模型,比如GBDT,设置1000棵树,基本也在秒级完成,但是这个需要根据设置的时间来运行。
优点就是不需要自己进行网格调参,所以整体看下来,autosklearn是比较节省时间的。

automl.fit(X_train, y_train)
print(automl.score(X_train, y_train))
# 0.9251497005988024
y_AUTO= automl.predict(X_test)

虽然autosklearn有score查看准确率,但是最好跟其他的算法统一,用sklearn的数据进行评估

from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score


acc = accuracy_score(y_test,y_AUTO)#"accuracy"(准确率)
auc = roc_auc_score(y_test, y_AUTO)#从预测分数中,计算ROC曲线的面积(ROC AUC)
precision = precision_score(y_test,y_AUTO)#精确率
recall = recall_score(y_test,y_AUTO)#计算召回率
f1 = f1_score(y_test,y_AUTO)#计算F1评分,也被称为balanced F-score或者F-measure

print(acc,auc,precision,recall,f1)
#0.7937219730941704 0.7673930921052632 0.8888888888888888 0.5894736842105263 0.7088607594936708

3.1.5 模型查看

sprint_statistics()函数总结了上述搜索和选择的最佳模型的性能

print(automl.sprint_statistics())

auto-sklearn results:
Dataset name: b7f831d5-8013-11ed-9911-c3531a603d75
Metric: accuracy
Best validation score: 0.846154
Number of target algorithm runs: 92
Number of successful target algorithm runs: 91
Number of crashed target algorithm runs: 0
Number of target algorithms that exceeded the time limit: 1
Number of target algorithms that exceeded the memory limit: 0

leaderboard()函数查看所有模型打印排行榜

print(automl.leaderboard())
model_idrankensemble_weighttypecostduration
3910.04random_forest0.1538461.876512
2620.06random_forest0.1628961.979716
7630.04random_forest0.1628962.221861
240.04random_forest0.1674212.143343
9250.02random_forest0.1719461.950003
770.04mlp0.1809958.816894
6860.02random_forest0.1809951.625041
6780.06adaboost0.1855200.962228
13100.02random_forest0.1900452.401354
2190.04libsvm_svc0.1900451.238663
4110.06extra_trees0.1945702.202599
18120.02gradient_boosting0.1945701.983646
30130.02random_forest0.1945701.753223
22140.02gradient_boosting0.1990951.955653
19180.02extra_trees0.2036202.184760
36170.04extra_trees0.2036201.547301
72160.02qda0.2036201.264002
74150.18k_nearest_neighbors0.2036200.933298
40190.02lda0.2081451.531620
14210.12adaboost0.2126702.197585
79200.02random_forest0.2126701.693966
20220.02gradient_boosting0.2217191.806265
49230.02random_forest0.2307691.673310
15250.02qda0.2352941.421954
23240.02random_forest0.2352942.352457

show_models()函数可以查看所有模型的信息

print(automl.show_models())

{4: {‘model_id’: 4,
‘rank’: 1,
‘cost’: 0.1945701357466063,
‘ensemble_weight’: 0.08,
‘data_preprocessor’: <autosklearn.pipeline.components.data_preprocessing.DataPreprocessorChoice at 0x7fe8f5e35550>,
‘balancing’: Balancing(random_state=1),
‘feature_preprocessor’: <autosklearn.pipeline.components.feature_preprocessing.FeaturePreprocessorChoice at 0x7fe8f5747070>,
‘classifier’: <autosklearn.pipeline.components.classification.ClassifierChoice at 0x7fe8f5747580>,
‘sklearn_classifier’: ExtraTreesClassifier(max_features=5, min_samples_leaf=3, min_samples_split=11,
n_estimators=512, n_jobs=1, random_state=1,
warm_start=True)},
7: {‘model_id’: 7,
‘rank’: 2,
‘cost’: 0.1809954751131222,
‘ensemble_weight’: 0.06,
‘data_preprocessor’: <autosklearn.pipeline.components.data_preprocessing.DataPreprocessorChoice at 0x7fe766d14b80>,
‘balancing’: Balancing(random_state=1, strategy=‘weighting’),
‘feature_preprocessor’: <autosklearn.pipeline.components.feature_preprocessing.FeaturePreprocessorChoice at 0x7fe88710f970>,
‘classifier’: <autosklearn.pipeline.components.classification.ClassifierChoice at 0x7fe88710f460>,
‘sklearn_classifier’: MLPClassifier(alpha=4.2841884333778574e-06, beta_1=0.999, beta_2=0.9,
hidden_layer_sizes=(263, 263, 263),
learning_rate_init=0.0011804284312897009, max_iter=256,
n_iter_no_change=32, random_state=1, validation_fraction=0.0,
verbose=0, warm_start=True)},

其他参数:
官网:APIs — AutoSklearn 0.15.0 documentation
源码:https://github.com/automl/auto-sklearn

3.2 回归

我们使用常见的波士顿房价预测

import pandas as pd
data = pd.read_csv("/datasource/DL数据集demo/boston_house_prices.csv")

X = data[['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
       'PTRATIO', 'B', 'LSTAT']]
y = data['MEDV']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

流程与分类一致,所以就不废话了

automl = autosklearn.regression.AutoSklearnRegressor(
    time_left_for_this_task=180,  
    per_run_time_limit=30,       
    memory_limit=None,  
    n_jobs=-1,
)

automl.fit(X_train, y_train, dataset_name='boston')
# evaluate the best model
y_pred = automl.predict(X_test)
# 结果评估
print('均方误差: %.2f' % mean_squared_error(y_test,y_pred))
print('确定系数(R^2): %.2f' % r2_score(y_test,y_pred))

均方误差: 15.37
确定系数(R^2): 0.82

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/105157.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

阿里微服务架构到底多牛逼:深入解析Apache Dubbo与实战

本书的由来 在Apache Dubbo (以下简称Dubbo)重新开源之前&#xff0c;Dubbo已经被很多公司广泛用于生产环境并获得了良好的反馈&#xff0c;很多公司内部也会建立私有分支自己维护&#xff0c;其中Dubbox 就是基于Dubbo分支进行扩展并二次维护的。重新开源后&#xff0c;社区维…

【ESXi 7.x内部升级】ESXi 升级 —— 小版本升级(7.X或8.X版本内升级)

目录4. 小版本升级&#xff08;7.X或8.X版本内升级&#xff09;4.1 示例 — 使用 vSphere Lifecycle Manager升级 ESXi目标&#xff1a;将 VMware ESXi 7.0 U2e 升级为 7.0 U3f&#xff08;1&#xff09;在vSphere Client 中查看需要升级的 ESXi 版本&#xff08;2&#xff09;…

【微信篇】PC端微信文件夹里的“微信号“

【微信篇】PC端微信文件夹里的"微信号" 更新记录最敷衍的软件一微信&#xff01;&#xff01;&#xff01;—【蘇小沐】 文章目录【微信篇】PC端微信文件夹里的"微信号"1.实验环境PC端微信文件夹里的"微信号"总结1.实验环境 系统版本Windows 1…

深度学习入门(六十四)循环神经网络——编码器-解码器架构

深度学习入门&#xff08;六十四&#xff09;循环神经网络——编码器-解码器架构前言循环神经网络——编码器-解码器架构课件重新考察CNN重新考察RNN编码器-解码器架构总结教材1 编码器2 解码器3 合并编码器和解码器4 训练模型5 小结参考文献前言 核心内容来自博客链接1博客连…

分布式任务调度 - PowerJob

一、简介 1、介绍 PowerJob&#xff08;原OhMyScheduler&#xff09;是全新一代分布式任务调度与计算框架&#xff0c;其主要功能特性如下&#xff1a; 使用简单&#xff1a;提供前端Web界面&#xff0c;允许开发者可视化地完成调度任务的管理&#xff08;增、删、改、查&am…

数据库原理及MySQL应用 | 约束

约束是保证数据完整性的一种数据库对象&#xff0c;按约束作用不同&#xff0c;分为七种。 约束从字面上来看就是受到限制&#xff0c;它是附加在表上&#xff0c;通过限制列中、行中、表之间数据来保证数据完整性的一种数据库对象。 在MySQL中&#xff0c;有多种约束&#xf…

设计模式原则 - 开闭原则(五)

开闭原则一 官方定义基本介绍二 案例演示普通实现方式案例分析开闭原则实现案例分析三 注意事项一 官方定义 开闭原则&#xff08; Open Close Principle &#xff09;&#xff0c;又称为OCP原则&#xff0c;他的官方定义如下&#xff1a; Software entities like classes,modu…

基于Java+Swing+Mysql实现停车场管理系统

基于JavaSwingMysql实现停车场管理系统一、系统介绍二、系统展示三、其它1.其他系统实现一、系统介绍 1.系统功能 用户 1.登录系统 2.信息查询 包含计费标准&#xff0c;当前在场信息&#xff0c;用户历史信息&#xff0c;用户个人信息&#xff0c;出入场信息&#xff0c;当前…

Win10提示错误代码0xc0000001的解决办法

​有一些朋友在使用Win10系统的时候会遇到蓝屏故障&#xff0c;提示“无法正常启动你的电脑&#xff0c;在多次尝试后&#xff0c;你的电脑上的操作系统仍无法启动&#xff0c;因此需求对其进行修复。” Win10提示恢复无法正常启动你的电脑0xc0000001 故障原因&#xff1a; 错误…

实战案例:初探工程配置 图标组件热身

点击上方卡片“前端司南”关注我您的关注意义重大原创前端司南前言本文是 基于ViteAntDesignVue打造业务组件库[1] 专栏第 3 篇文章【实战案例&#xff1a;初探工程配置 & 图标组件热身】&#xff0c;我将从业务系统中最基础的图标组件入手&#xff0c;带着读者们练练手找找…

websocket的用处及vue和SpringBoot和nginx的引入-入门

websocket的用处及vue和SpringBoot的引入-入门 为什么要有websocket 微信 想一个场景&#xff0c;扫码登录&#xff0c;服务器并不知道用户有没有扫码&#xff0c;怎么办&#xff0c;一种办法是HTTP定时轮询&#xff0c;1-2秒就请求一次服务端&#xff0c;看看用户有没有扫码…

5.3 常见的电感式和电容式感测原理及应用

常见的电感式和电容式感测应用1、电感式和电容式工作原理1.1 电感式感测工作原理1.2 电容式感测工作原理2 FDC&#xff1a;电容式液位感测2.1 电容技术在液位感测中的优势2.2 电容式液位感测入门3 LDC&#xff1a;电感式触控按钮4 LDC&#xff1a;增量编码器和事件计数5 LDC&am…

再学C语言10:字符串(1)

一、字符串定义 字符串&#xff1a;一个或多个字符的序列 "hello world!" 双引号并不是字符串的一部分&#xff0c;只是用于通知编译器其中包含了一个字符串 C没有为字符串定义专门的变量类型&#xff0c;而是将其存储在char数组中 字符串中的字符存放在相邻的存…

Amazon 4.7 星评,领域新经典,了解服务设计就读它

2011 年&#xff0c;Adaptive Path 公司的 Brandon Schauer 粗略估算&#xff0c;美国每年在服务的规划和设计上大约花费 20 亿美元&#xff0c;但其中仅有 7000 万美元&#xff08;大约 3.5%&#xff09;花在了“服务设计”上。做另外 96.5% 的工作的那些人&#xff0c;从不觉…

参加大学生数学建模大赛,Matlab和Python到底哪个更好?

前言 后台的小伙伴经常会问编程过程中&#xff0c;MATLAB和Python到底哪个更好&#xff1f;这个问题一直困惑很多同学&#xff0c;今天小编来给大家从实用型来综合分析一下&#xff1a; 首先从两者各自的应用做个对比。 一、python的优势 Python相对于Matlab最大的优势&…

Mac M2芯 k8s(minikube)超详细实战 - 单节点部分

概述 我使用的电脑是Mac pro M2芯的&#xff0c;使用的虚拟环境是 Ubuntu 22.04 &#xff0c;M2芯兼容性不是特别好&#xff0c;所以尽量跟我博客中的版本保持一致。 虚拟机环境 Ubuntu 22.04docker &#xff1a;20.10.17minikube&#xff1a;v1.25.2 搭建minikube虚拟机环境…

【强化学习基础】强化学习的基本概念:状态、动作、智能体、策略、奖励、状态转移、轨迹、回报、价值函数

文章目录1.状态&#xff08;State&#xff09;2.动作&#xff08;Action&#xff09;3.智能体&#xff08;Agent&#xff09;4.策略&#xff08;Policy&#xff09;5.奖励&#xff08;Reward&#xff09;6.状态转移&#xff08;State transition&#xff09;7.智能体与环境交互…

高效率的Python开发工具——PyCharm v2022.3正式发布

JetBrains PyCharm是一种Python IDE&#xff0c;其带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具。此外&#xff0c;该IDE提供了一些高级功能&#xff0c;以用于Django框架下的专业Web开发。 PyCharm v2022.3官方正式版下载(q技术交流&#xff1a;786598704)…

wireshark抓包数据提取TCP/UDP/RTP负载数据方法

wireshark抓包数据提取TCP_UDP_RTP负载数据方法 文章目录wireshark抓包数据提取TCP_UDP_RTP负载数据方法1 背景2 TCP和UDP负载提取方式3 RTP负载提取方式1 背景 在视频抓包分析过程中&#xff0c;有时候需要从TCP、UDP、RTP中直接提取payload数据&#xff0c;比如较老的摄像机…

微课堂助力在线教育招生引流方式_付费视频系统搭建对在线教育的作用

一、借助优惠码线上线下推广课程 1、线下发传单&#xff1a; 机构先在我们后台创建对应课程的通用优惠码&#xff0c;然后再制作课程传单介绍页。传单上显示出对应课程的通用优惠码&#xff0c;线下派发传单给到用户。 2、线下刮刮卡片推广&#xff1a;将私有码制作成卡片配合…