TabNet: Attentive Interpretable Tabular Learning——一种具有可解释性的注意力表格学习模型
文章提出了一种名为TabNet的新型深度神经网络架构专门用于处理表格数据。该架构旨在结合决策树DT的优势如可解释性、处理表格数据的高效性与深度神经网络DNN的优势如端到端学习、表示学习从而在表格数据学习任务中取得更优的性能和可解释性。主要研究内容与创新点1. 提出 TabNet 架构TabNet 是一种基于序列注意力机制的深度学习模型其核心设计包括实例级稀疏特征选择在每个决策步骤模型使用一个可学习的掩码为每个输入实例动态选择最相关的特征子集进行处理。这确保了模型容量专注于最重要的特征提高了参数效率。顺序多步处理模型通过多个决策步骤顺序处理信息每一步都基于上一步的输出决定下一步需要关注的特征并逐步构建对最终决策的贡献。这种架构模仿了决策树DT的分步推理过程。特征变换块对选中的特征进行非线性处理包含跨步骤共享和步骤独享的全连接层结合批归一化BN和门控线性单元GLU以增强学习能力。端到端学习TabNet 可以直接处理原始表格数据包括类别特征和数值特征无需复杂的预处理或特征工程并通过梯度下降进行优化。2. 增强模型可解释性TabNet 天然具有可解释性这是其在现实应用中至关重要的一点局部可解释性通过可视化每个决策步骤产生的特征选择掩码可以清楚了解对于单个输入样本模型在每一步依赖哪些特征做出决策。全局可解释性通过聚合所有决策步骤的特征掩码并考虑每步的决策贡献可以计算出每个特征在整个训练模型中的全局重要性得分从而理解模型的整体行为。3. 提出表格数据的自监督学习这是论文的另一项重要贡献。作者首次为表格数据设计了一个自监督学习框架方法训练一个编码器-解码器结构的 TabNet任务是从部分特征中预测被掩码的其他特征。例如模型尝试从其他列推断出“教育水平”或“性别”列。效果在大规模无标签表格数据上进行这种预训练后得到的编码器模型在下游的监督学习任务如分类、回归上表现显著提升尤其是在标记数据稀缺的情况下。这类似于 BERT 在 NLP 中的预训练范式。4. 广泛的实验验证论文在多个真实世界和合成数据集上对 TabNet 进行了全面的评估涵盖分类、回归等任务性能对比在合成数据上TabNet 的实例级特征选择能力显著优于全局特征选择方法和 L2X、INVASE 等其他实例级方法。在真实数据集如 Forest Cover Type, Poker Hand, Higgs Boson, Sarcos, Rossmann 销售预测上TabNet 的性能通常超越或媲美最先进的集成树模型如 XGBoost, LightGBM, CatBoost和传统 MLP/DNN 模型。TabNet 在具有高度非线性关系如 Poker Hand和超大规模数据如 Higgs Boson的任务中表现尤为突出。可解释性展示通过可视化特征掩码清晰展示了 TabNet 如何准确聚焦于合成数据中的真正相关特征以及在真实数据如蘑菇分类、人口普查收入预测上给出的特征重要性排名与领域知识及现有可解释性工具如 SHAP的结果高度一致。自监督学习效果实验证明在 Higgs 和 Forest Cover Type 数据集上使用自监督预训练能显著提升小样本场景下的模型准确率和收敛速度。消融研究与超参数指南详细分析了模型各组件如决策步数、特征维度、稀疏正则化强度、批量大小等对性能的影响并为实际应用提供了超参数选择建议。论文的核心贡献是设计了一个高性能、可解释、且能利用自监督学习的深度表格学习模型 TabNet。它成功地将注意力机制、稀疏特征选择和顺序处理融合到 DNN 中有效解决了传统 DNN 在表格数据上性能不佳且不可解释的问题同时在多个任务上证明了其相对于传统树模型的优势并为表格数据的预训练开辟了新的方向。这里是自己的论文阅读记录感兴趣的话可以参考一下如果需要阅读原文的话可以看这里如下所示摘要我们提出了一种新颖的高性能且可解释的规范深度表格数据学习架构——TabNet。TabNet 使用序列注意力机制在每一步决策中选择要推理的特征从而实现可解释性并通过将学习能力集中于最重要的特征来实现更高效的学习。我们证明在各种非性能饱和的表格数据集上TabNet 优于其他变体并产生可解释的特征归因以及对其全局行为的洞察。最后我们展示了针对表格数据的自监督学习当未标记数据充足时可显著提高性能。引言深度神经网络DNN在图像He et al. 2015、文本Lai et al. 2015和音频Amodei et al. 2015领域取得了显著成功。对于这些领域能够高效地将原始数据编码为有意义的表示的规范架构推动了快速的进步。表格数据是尚未见到这种规范架构成功的数据类型之一。尽管表格数据是现实世界人工智能中最常见的数据类型因为它包含任何类别和数值特征Chui et al. 2018但针对表格数据的深度学习仍未得到充分探索集成决策树DT的变体仍然主导着大多数应用Kaggle 2019a。为什么首先因为基于 DT 的方法具有某些优势(i) 它们对于表格数据中常见的、具有近似超平面边界的决策流形具有表示效率(ii) 它们的基本形式高度可解释例如通过跟踪决策节点并且对于其集成形式有流行的后验可解释性方法例如Lundberg, Erion, and Lee 2018——这是许多实际应用中的一个重要考量(iii) 它们训练速度快。其次因为先前提出的 DNN 架构并不非常适合表格数据例如堆叠卷积层或多层感知机MLP参数严重过多——缺乏适当的归纳偏置常常导致它们无法为表格决策流形找到最优解Goodfellow, Bengio, and Courville 2016; Shavitt and Segal 2018; Xu et al. 2019。为什么值得为表格数据探索深度学习一个显而易见的动机是预期的性能提升尤其是在大型数据集上Hestness et al. 2017。此外与树学习不同DNN 能够为表格数据实现基于梯度下降的端到端学习这具有诸多优势(i) 高效编码多种数据类型如图像与表格数据一起(ii) 减少对特征工程的需求这是当前基于树的表格数据学习方法的一个关键方面(iii) 从流式数据中学习也许最重要的是 (iv) 端到端模型允许表示学习这支持许多有价值的应用场景包括数据高效的领域自适应Goodfellow, Bengio, and Courville 2016、生成建模Radford, Metz, and Chintala 2015和半监督学习Dai et al. 2017。我们提出了一种新的规范 DNN 架构用于表格数据称为 TabNet。主要贡献总结如下TabNet 输入原始表格数据无需任何预处理并使用基于梯度下降的优化进行训练从而能够灵活地集成到端到端学习中。TabNet 使用序列注意力机制在每一步决策中选择要推理的特征从而实现可解释性和更好的学习因为学习能力被用于最显著的特征见图 1。这种特征选择是实例级的例如每个输入的选择可能不同并且与Chen et al. 2018或Yoon, Jordon, and van der Schaar 2019等其他实例级特征选择方法不同TabNet 使用单一的深度学习架构进行特征选择和推理。上述设计选择带来了两个有价值的特性(i) 在各种不同领域的分类和回归问题的数据集上TabNet 优于或媲美其他表格学习模型(ii) TabNet 支持两种可解释性局部可解释性可视化特征的重要性及其组合方式以及全局可解释性量化每个特征对训练后模型的贡献。最后首次针对表格数据我们展示了通过使用无监督预训练来预测被掩码的特征可以显著提升性能见图 2。相关工作特征选择特征选择广义上指根据特征对预测的有用性明智地选择一个子集。常用的技术如前向选择法和 Lasso 正则化Guyon and Elisseeff 2003基于整个训练数据归因特征重要性被称为全局方法。实例级特征选择指为每个输入单独选择特征在Chen et al. 2018中通过一个解释器模型来最大化所选特征与响应变量之间的互信息在Yoon, Jordon, and van der Schaar 2019中通过使用演员-评论家框架来模仿基线同时优化选择。与这些不同TabNet 在端到端学习中采用具有可控稀疏性的软特征选择——单个模型联合执行特征选择和输出映射从而以紧凑的表示获得优越的性能。图 1在 Adult Census Income 预测Dua and Graff 2017上示例的 TabNet 稀疏特征选择。稀疏特征选择实现了可解释性和更好的学习因为容量被用于最显著的特征。TabNet 采用多个决策块专注于处理输入特征的一个子集进行推理。图中示例的两个决策块分别处理与职业和投资相关的特征以预测收入水平。图 2自监督表格学习。现实世界的表格数据集具有相互依赖的特征列例如可以从职业猜测教育水平或者从家庭关系猜测性别。通过掩码自监督学习进行无监督表示学习可以为监督学习任务产生一个改进的编码器模型。基于树的学习DT 常用于表格数据学习。其主要优势是高效地挑选具有最大统计信息增益的全局特征Grabczewski and Jankowski 2005。为了提高标准 DT 的性能一种常见的方法是集成以降低方差。在集成方法中随机森林Ho 1998使用随机数据子集和随机选择的特征来生长许多树。XGBoostChen and Guestrin 2016和 LightGBMKe et al. 2017是最近两种主流的集成 DT 方法主导了最近的大多数数据科学竞赛。我们的实验结果显示在各种数据集上当通过深度学习提高表示能力同时保留其特征选择特性时可以超越基于树的模型。DNN 与 DT 的集成如Humbird, Peterson, and McClaren 2018中那样用 DNN 构建块表示 DT 会导致表示冗余和学习效率低下。软神经DTWang, Aggarwal, and Liu 2017; Kontschieder et al. 2015使用可微分的决策函数而不是不可微分的轴对齐分割。然而失去自动特征选择通常会降低性能。在Yang, Morillo, and Hospedales 2018中提出了一种软分箱函数来模拟 DNN 中的 DT但需要低效地枚举所有可能的决策。Ke et al. 2019提出了一种通过显式利用表达性特征组合的 DNN 架构然而学习是基于从梯度提升 DT 转移知识。Tanno et al. 2018提出了一种 DNN 架构通过从原始块自适应地生长同时将表示学习到边、路由函数和叶节点中。TabNet 与这些不同它通过序列注意力嵌入了具有可控稀疏性的软特征选择。图 3使用传统 DNN 模块左和相应决策流形右的类 DT 分类示意图。通过对输入使用乘法稀疏掩码来选择相关特征。所选特征被线性变换并在添加偏置表示边界后ReLU 通过将区域置零来执行区域选择。多个区域的聚合基于加法。随着 \(C_1\) 和 \(C_2\) 增大决策边界变得更锐利。自监督学习无监督表示学习可以改善监督学习尤其是在小数据 regime 下Raina et al. 2007。最近针对文本Devlin et al. 2018和图像Trinh, Luong, and Le 2019数据的工作显示了显著的进步——这得益于对无监督学习目标掩码输入预测和基于注意力的深度学习的明智选择。用于表格学习的 TabNetDT 在从现实世界表格数据集中学习方面是成功的。通过特定设计传统的 DNN 构建块可以用来实现类 DT 的输出流形例如见图 3。在这种设计中个体特征选择是获得超平面形式决策边界的关键这可以推广到特征的线性组合其中系数决定了每个特征的比例。TabNet 基于这种功能并通过精心设计在优于 DT 的同时获得了它们的优势其设计要点是(i) 使用从数据中学习到的稀疏实例级特征选择(ii) 构建一个序列多步架构其中每一步基于所选特征对决策的一部分做出贡献(iii) 通过所选特征的非线性处理提高学习能力(iv) 通过更高维度和更多步骤模拟集成。图 4(a) TabNet 编码器由一个特征变换器、一个注意力变换器和特征掩码组成。一个分割块将处理后的表示分开一部分供后续步骤的注意力变换器使用另一部分用于整体输出。对于每一步特征选择掩码提供了关于模型功能的可解释信息并且可以聚合掩码以获得全局特征重要性归因。(b) TabNet 解码器由每一步的一个特征变换器块组成。(c) 一个特征变换器块示例——展示了 4 层网络其中 2 层在所有决策步之间共享2 层依赖于决策步。每一层由一个全连接FC层、BN 和 GLU 非线性组成。(d) 一个注意力变换器块示例——一个单层映射通过先验尺度信息进行调制该信息聚合了在当前决策步之前每个特征被使用的程度。使用 sparsemaxMartins and Astudillo 2016对系数进行归一化从而实现对显著特征的稀疏选择。实验我们在广泛的问题上研究 TabNet包括回归或分类任务特别是使用已发布的基准数据集。对于所有数据集类别输入通过可学习的嵌入映射到一维可训练标量 数值列直接输入无需预处理 。我们使用标准的分类softmax 交叉熵和回归均方误差损失函数并训练至收敛。TabNet 模型的超参数在验证集上进行了优化并列于附录中。如附录中的消融研究所示TabNet 的性能对大多数超参数不敏感。在附录中我们还提供了关于各种设计的消融研究以及关键超参数选择的指南。对于所有我们引用的实验我们使用与原始工作相同的训练、验证和测试数据划分。所有模型的训练都使用 Adam 优化算法Kingma and Ba 2014和 Glorot 均匀初始化。实例级特征选择选择显著特征对于实现高性能至关重要尤其是在小型数据集上。我们考虑了Chen et al. 2018中的 6 个表格数据集包含 10k 训练样本。这些数据集的构建方式是只有一部分特征决定输出。对于 Syn1-Syn3显著特征对所有实例都是相同的例如Syn2 的输出取决于特征 X3−X6如果已知显著特征全局特征选择将给出高性能。对于 Syn4-Syn6显著特征是实例相关的例如对于 Syn4输出取决于 X1−X2 或 X3−X6 中的一组具体取决于 X11 的值这使得全局特征选择效果不佳。表 1 显示TabNet 优于其他方法树集成Geurts, Ernst, and Wehenkel 2006、LASSO 正则化、L2XChen et al. 2018并与 INVASEYoon, Jordon, and van der Schaar 2019表现相当。对于 Syn1-Syn3TabNet 的性能接近于全局特征选择——它能找出哪些特征是全局重要的。对于 Syn4-Syn6通过消除实例级冗余特征TabNet 改进了全局特征选择。所有其他方法使用的预测模型有 43k 个参数而由于演员-评论家框架中的另外两个模型INVASE 的总参数数为 101k。TabNet 是一个单一的架构其大小为 Syn1-Syn3 的 26k 和 Syn4-Syn6 的 31k。紧凑的表示是 TabNet 的宝贵特性之一。表 1在Chen et al. 2018的 6 个合成数据集上TabNet 与其他基于特征选择的 DNN 模型的测试接收者操作特征曲线下面积AUC的均值和标准差。模型包括无选择使用所有特征无特征选择、全局仅使用全局显著特征、树集成Geurts, Ernst, and Wehenkel 2006、Lasso 正则化模型、L2XChen et al. 2018和 INVASEYoon, Jordon, and van der Schaar 2019。加粗数字表示每个数据集上的最佳结果。模型Syn1Syn2Syn3Syn4Syn5Syn6无选择.578 ± .004.789 ± .003.854 ± .004.558 ± .021.662 ± .013.692 ± .015树.574 ± .101.872 ± .003.899 ± .001.684 ± .017.741 ± .004.771 ± .031Lasso 正则化.498 ± .006.555 ± .061.886 ± .003.512 ± .031.691 ± .024.727 ± .025L2X.498 ± .005.823 ± .029.862 ± .009.678 ± .024.709 ± .008.827 ± .017INVASE.690 ± .006.877 ± .003.902 ± .003.787 ± .004.784 ± .005.877 ± .003全局.686 ± .005.873 ± .003.900 ± .003.774 ± .006.784 ± .005.858 ± .004TabNet.682 ± .005.892 ± .004.897 ± .003.776 ± .017.789 ± .009.878 ± .004在真实世界数据集上的性能表 2Forest Cover Type 数据集的性能。模型测试准确率 (%)XGBoost89.34LightGBM89.28CatBoost85.14AutoML Tables94.95TabNet96.99Forest Cover TypeDua and Graff 2017任务是从制图变量中对森林覆盖类型进行分类。表 2 显示TabNet 优于已知能实现稳健性能的集成树基方法Mitchell et al. 2018。我们还考虑了 AutoML TablesAutoML 2019这是一个基于模型集成的自动化搜索框架包括 DNN、梯度提升 DT、AdaNetCortes et al. 2016和集成AutoML 2019并进行了非常彻底的超参数搜索。一个未进行精细超参数搜索的单一 TabNet 模型就超越了它。表 3Poker Hand 数据集的性能。模型测试准确率 (%)DT50.0MLP50.0深度神经 DT65.1XGBoost71.1LightGBM70.0CatBoost66.6TabNet99.2基于规则100.0Poker HandDua and Graff 2017任务是从牌的花色和点数原始属性中识别扑克手牌。输入-输出关系是确定性的手工设计的规则可以达到 100% 的准确率。然而传统的 DNN、DT甚至是它们的混合变体——深度神经 DTYang, Morillo, and Hospedales 2018都严重受到数据不平衡的影响无法学习所需的排序和比较操作Yang, Morillo, and Hospedales 2018。经过调优的 XGBoost、CatBoost 和 LightGBM 相比它们只有非常微小的改进。TabNet 优于其他方法因为它能够利用其深度执行高度非线性的处理同时通过实例级特征选择避免过拟合。表 4Sarcos 数据集的性能。考虑了三种不同大小的 TabNet 模型。模型测试 MSE模型大小随机森林2.3916.7K随机 DT2.1128KMLP2.130.14M自适应神经树1.230.60M梯度提升树1.440.99MTabNet-S1.256.3KTabNet-M0.280.59MTabNet-L0.141.75MSarcosVijayakumar and Schaal 2000任务是回归一个拟人机器人手臂的逆动力学。Tanno et al. 2018表明使用随机森林一个非常小的模型也能获得不错的性能。在非常小的模型规模 regime 下TabNet 的性能与Tanno et al. 2018中参数多 100 倍的最佳模型相当。当模型规模不受限制时TabNet 实现了几乎低一个数量级的测试 MSE。表 5Higgs Boson 数据集的性能。两个 TabNet 模型分别用 -S 和 -M 表示。模型测试准确率 (%)模型大小稀疏进化 MLP78.4781K梯度提升树-S74.220.12M梯度提升树-M75.970.69MMLP78.442.04M梯度提升树-L76.986.96MTabNet-S78.2581KTabNet-M78.840.66MHiggs BosonDua and Graff 2017任务是区分希格斯玻色子过程与背景。由于其规模大得多1050 万个实例即使是非常大的集成DNN 也优于 DT 变体。TabNet 以更紧凑的表示优于 MLP。我们还将结果与最先进的进化稀疏化算法Mocanu et al. 2018进行了比较该算法将非结构化稀疏性集成到训练中。凭借其紧凑的表示TabNet 在相同参数数量下产生了与稀疏进化训练几乎相似的性能。与稀疏进化训练不同TabNet 的稀疏性是结构化的——它不会降低操作强度Wen et al. 2016并且可以有效地利用现代多核处理器。表 6Rossmann Store Sales 数据集的性能。模型测试 MSEMLP512.62XGBoost490.83LightGBM504.76CatBoost489.75TabNet485.12Rossmann Store SalesKaggle 2019b任务是根据静态和时变特征预测商店销售额。我们观察到 TabNet 优于常用的方法。时间特征例如日期获得了较高的重要性并且对于像节假日这样销售动态不同的情况观察到了实例级特征选择的好处。可解释性自监督学习表 7在 Higgs 数据集上使用 TabNet-M 模型改变监督微调的训练数据集大小得到的准确率15 次运行均值和标准差。训练数据集大小测试准确率 (%)监督学习带预训练1k57.47 ± 1.7861.37 ± 0.8810k66.66 ± 0.8868.06 ± 0.39100k72.92 ± 0.2173.19 ± 0.15表 7 显示无监督预训练显著提高了监督分类任务的性能尤其是在未标记数据集远大于标记数据集的情况下。如图 7 所示使用无监督预训练模型收敛速度更快。非常快的收敛速度对于持续学习和领域自适应可能很有用。结论我们提出了 TabNet一种用于表格学习的新型深度学习架构。TabNet 使用序列注意力机制在每个决策步选择一个语义上有意义的特征子集进行处理。实例级特征选择实现了高效的学习因为模型容量完全用于最显著的特征并且通过选择掩码的可视化也产生了更可解释的决策过程。我们证明了 TabNet 在不同领域的表格数据集上优于先前的工作。最后我们展示了无监督预训练对于快速适应和性能提升的显著好处。图 5在 Syn2 和 Syn4Chen et al. 2018上的特征重要性掩码 \(\mathbf{M}[i]\)指示第 \(i\) 步的特征选择和聚合特征重要性掩码 \(\mathbf{M}_{\text{agg}}\)显示全局实例级特征选择。较亮的颜色表示较高的值。例如对于 Syn2仅使用了 \(X_3 - X_6\)。图 6Adult 数据集决策流形的前两个 T-SNE 维度以及首要特征“年龄”的影响。图 7在 Higgs 数据集上使用 10k 个样本的训练曲线。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2638294.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!