Google的神经网络表格处理模型TabNet介绍

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
简介: Google的神经网络表格处理模型TabNet介绍

Google Research的TabNet于2019年发布,在预印稿中被宣称优于表格数据的现有方法。它是如何工作的,又如何可以尝试呢?

640.png

表格数据可能构成当今大多数业务数据。考虑诸如零售交易,点击流数据,工厂中的温度和压力传感器,银行使用的KYC (Know Your Customer) 信息或制药公司使用的模型生物的基因表达数据之类的事情。

论文称为TabNet: Attentive Interpretable Tabular Learning(https://arxiv.org/pdf/1908.07442.pdf),很好地总结了作者正在尝试做的事情。“Net”部分告诉我们这是一种神经网络,“Attentive ”部分表示它正在使用一种注意力机制,旨在实现可解释性,并用于表格数据的机器学习。

它是如何工作的?

TabNet使用一种软功能选择将重点仅放在对当前示例很重要的功能上。这是通过顺序的多步骤决策机制完成的。即,以多个步骤自上而下地处理输入信息。正如论文所指出的那样,“自上而下关注的思想是从处理视觉和语言数据或强化学习中得到的启发,可以在高维输入中搜索一小部分相关信息。”

尽管它们与BERT等流行的NLP模型中使用的transformer 有些不同,但执行这种顺序关注的构件却称为transformer 块。这些transformer 使用自注意力机制,试图模拟句子中不同单词之间的依赖关系。这里使用的transformer类型试图使用“软”特性选择,一步一步地消除与示例无关的那些特性,这是通过使用sparsemax函数完成的。

这篇论文的第一个图,如下重现,描绘了信息是如何聚集起来形成预测的。

640.png

TabNet的一个好特性是它不需要特性预处理。另一个原因是,它具有内置的可解释性,即为每个示例选择最相关的特性。这意味着您不必应用外部解释模块,如shap或LIME。

在阅读本文时,要理解这个架构中发生了什么并不容易,但幸运的是,已经发表的代码稍微澄清了一些问题,并表明它并不像您可能认为的那样复杂。

我怎么使用它?

现在TabNet有了更好的实现,如下所述:一个是PyTorch的接口,它有一个类似scikit学习的接口,还有一个是FastAI的接口。

根据作者readme描述要点如下:

为每个数据集创建新的train.csv,val.csv和test.csv文件,我不如读取整个数据集并在内存中进行拆分(当然,只要可行),所以我写了一个在我的代码中为Pandas提供了新的输入功能。

修改data_helper.py文件可能需要一些工作,至少在最初不确定您要做什么以及应该如何定义功能列时(至少我是这样)。还有许多参数需要更改,但它们位于主训练循环文件中,而不是数据帮助器文件中。有鉴于此,我还尝试在我的代码中概括和简化此过程。

我添加了一些快速的代码来进行超参数优化,但到目前为止仅用于分类。

还值得一提的是,作者提供的示例代码仅显示了如何进行分类,而不是回归,因此用户也必须编写额外的代码。我添加了具有简单均方误差损失的回归功能。

使用命令行运行测试

pythontrain_tabnet.py\--csv-pathdata/adult.csv\--target-name"<=50K"\--categorical-featuresworkclass,education,marital.status,\occupation,relationship,race,sex,native.country\--feature_dim16\--output_dim16\--batch-size4096\--virtual-batch-size128\--batch-momentum0.98\--gamma1.5\--n_steps5\--decay-every2500\--lambda-sparsity0.0001\--max-steps7700

强制性参数包括--csv-path(指向CSV文件的位置),-target-name(具有预测目标的列的名称)和-category-featues(逗号分隔列表) 应该视为分类的功能)。其余输入参数是需要针对每个特定问题进行优化的超参数。但是,上面显示的值直接取自TabNet论文,因此作者已经针对成人普查数据集对其进行了优化。

默认情况下,训练过程会将信息写入执行脚本的位置的tflog子文件夹。您可以将tensorboard指向此文件夹以查看训练和验证统计信息:

tensorboard--logdirtflog

如果您没有GPU ...

…您可以尝试这款Colaboratory笔记(https://colab.research.google.com/drive/1AWnaS6uQVDw0sdWjfh-E77QlLtD0cpDa)。请注意,如果您想查看Tensorboard日志,最好的选择是创建一个Google Storage存储桶,并让脚本在其中写入日志。这可以通过使用tb-log-location参数来完成。例如。如果您的存储桶名称是camembert-skyscrape,则可以在脚本的调用中添加--tb-log-location gs:// camembert-skyscraper。(不过请注意,您必须正确设置存储桶的权限。这可能有点麻烦。)

然后可以将tensorboard从自己的本地计算机指向该存储桶:

tensorboard--logdirgs://camembert-skyscraper

超参数优化

在存储库(opt_tabnet.py)中也有一个用于完成超参数优化的快捷脚本。同样,在协作笔记本中显示了一个示例。该脚本仅适用于到目前为止的分类,值得注意的是,某些训练参数虽然实际上并不需要,但仍进行了硬编码(例如,用于尽早停止的参数[您可以继续执行多少步,而 验证准确性没有提高]。)

优化脚本中变化的参数为N_steps,feature_dim,batch-momentum,gamma,lambda-sparsity。(正如下面的优化技巧所建议的那样,output_dim设置为等于feature_dim。)

论文中具有以下有关超参数优化的提示:

大多数数据集对N_steps∈[3,10]产生最佳结果。通常,更大的数据集和更复杂的任务需要更大的N_steps。N_steps的非常高的值可能会过度拟合并导致不良的泛化。

调整Nd [feature_dim]和Na [output_dim]的值是获得性能与复杂性之间折衷的最有效方法。Nd = Na是大多数数据集的合理选择。Nd和Na的非常高的值可能会过度拟合,导致泛化效果差。

γ的最佳选择对整体性能具有重要作用。通常,较大的N_steps值有利于较大的γ。

批量较大对性能有利-如果内存限制允许,建议最大训练数据集总大小的1-10%。虚拟批次大小通常比批次大小小得多。

最初,较高的学习率很重要,应逐渐降低直至收敛。

结果

我已经通过此命令行界面尝试了TabNet的多个数据集,作者提供了他们在那里找到的最佳参数设置。使用这些设置重复运行后,我注意到最佳验证误差(和测试误差)往往在86%左右,类似于不进行超参数调整的CatBoost。作者报告论文中测试集的性能为85.7%。当我使用hyperopt进行超参数优化时,尽管使用了不同的参数设置,但我毫不奇怪地达到了约86%的相似性能。

对于其他数据集,例如Poker Hand 数据集,TabNet被认为远远击败了其他方法。我还没有花很多时间,但是当然每个人都应邀请他们自己对各种数据集进行超参数优化的TabNet!

TabNet是一个有趣的体系结构,似乎有望用于表格数据分析。它直接对原始数据进行操作,并使用顺序注意机制对每个示例执行显式特征选择。此属性还使其具有某种内置的可解释性。

我试图通过围绕它编写一些包装器代码来使TabNet稍微容易一些。下一步是将其与各种数据集中的其他方法进行比较。

tabnet的各种实现

google官方:https://github.com/google-research/google-research/tree/master/tabnet

pytorch:https://github.com/dreamquark-ai/tabnet

本文作者的一些改进:https://github.com/hussius/tabnet_fork


目录
相关文章
|
5天前
|
网络协议 安全 网络安全
探索网络模型与协议:从OSI到HTTPs的原理解析
OSI七层网络模型和TCP/IP四层模型是理解和设计计算机网络的框架。OSI模型包括物理层、数据链路层、网络层、传输层、会话层、表示层和应用层,而TCP/IP模型则简化为链路层、网络层、传输层和 HTTPS协议基于HTTP并通过TLS/SSL加密数据,确保安全传输。其连接过程涉及TCP三次握手、SSL证书验证、对称密钥交换等步骤,以保障通信的安全性和完整性。数字信封技术使用非对称加密和数字证书确保数据的机密性和身份认证。 浏览器通过Https访问网站的过程包括输入网址、DNS解析、建立TCP连接、发送HTTPS请求、接收响应、验证证书和解析网页内容等步骤,确保用户与服务器之间的安全通信。
37 1
|
10天前
|
监控 安全 BI
什么是零信任模型?如何实施以保证网络安全?
随着数字化转型,网络边界不断变化,组织需采用新的安全方法。零信任基于“永不信任,永远验证”原则,强调无论内外部,任何用户、设备或网络都不可信任。该模型包括微分段、多因素身份验证、单点登录、最小特权原则、持续监控和审核用户活动、监控设备等核心准则,以实现强大的网络安全态势。
|
29天前
|
机器学习/深度学习 自然语言处理 数据可视化
【由浅到深】从神经网络原理、Transformer模型演进、到代码工程实现
阅读这个文章可能的收获:理解AI、看懂模型和代码、能够自己搭建模型用于实际任务。
108 11
|
2月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于BP神经网络的苦瓜生长含水量预测模型matlab仿真
本项目展示了基于BP神经网络的苦瓜生长含水量预测模型,通过温度(T)、风速(v)、模型厚度(h)等输入特征,预测苦瓜的含水量。采用Matlab2022a开发,核心代码附带中文注释及操作视频。模型利用BP神经网络的非线性映射能力,对试验数据进行训练,实现对未知样本含水量变化规律的预测,为干燥过程的理论研究提供支持。
|
1月前
|
存储 网络协议 安全
30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场
本文精选了 30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场。
90 2
|
1月前
|
运维 网络协议 算法
7 层 OSI 参考模型:详解网络通信的层次结构
7 层 OSI 参考模型:详解网络通信的层次结构
207 1
|
2月前
|
网络协议 前端开发 Java
网络协议与IO模型
网络协议与IO模型
146 4
网络协议与IO模型
|
2月前
|
机器学习/深度学习 网络架构 计算机视觉
目标检测笔记(一):不同模型的网络架构介绍和代码
这篇文章介绍了ShuffleNetV2网络架构及其代码实现,包括模型结构、代码细节和不同版本的模型。ShuffleNetV2是一个高效的卷积神经网络,适用于深度学习中的目标检测任务。
112 1
目标检测笔记(一):不同模型的网络架构介绍和代码
|
1月前
|
网络协议 算法 网络性能优化
计算机网络常见面试题(一):TCP/IP五层模型、TCP三次握手、四次挥手,TCP传输可靠性保障、ARQ协议
计算机网络常见面试题(一):TCP/IP五层模型、应用层常见的协议、TCP与UDP的区别,TCP三次握手、四次挥手,TCP传输可靠性保障、ARQ协议、ARP协议
|
1月前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
87 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型

热门文章

最新文章