Google的神经网络表格处理模型TabNet介绍

本文涉及的产品
模型训练 PAI-DLC,100CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: Google的神经网络表格处理模型TabNet介绍

Google Research的TabNet于2019年发布,在预印稿中被宣称优于表格数据的现有方法。它是如何工作的,又如何可以尝试呢?

640.png

表格数据可能构成当今大多数业务数据。考虑诸如零售交易,点击流数据,工厂中的温度和压力传感器,银行使用的KYC (Know Your Customer) 信息或制药公司使用的模型生物的基因表达数据之类的事情。

论文称为TabNet: Attentive Interpretable Tabular Learning(https://arxiv.org/pdf/1908.07442.pdf),很好地总结了作者正在尝试做的事情。“Net”部分告诉我们这是一种神经网络,“Attentive ”部分表示它正在使用一种注意力机制,旨在实现可解释性,并用于表格数据的机器学习。

它是如何工作的?

TabNet使用一种软功能选择将重点仅放在对当前示例很重要的功能上。这是通过顺序的多步骤决策机制完成的。即,以多个步骤自上而下地处理输入信息。正如论文所指出的那样,“自上而下关注的思想是从处理视觉和语言数据或强化学习中得到的启发,可以在高维输入中搜索一小部分相关信息。”

尽管它们与BERT等流行的NLP模型中使用的transformer 有些不同,但执行这种顺序关注的构件却称为transformer 块。这些transformer 使用自注意力机制,试图模拟句子中不同单词之间的依赖关系。这里使用的transformer类型试图使用“软”特性选择,一步一步地消除与示例无关的那些特性,这是通过使用sparsemax函数完成的。

这篇论文的第一个图,如下重现,描绘了信息是如何聚集起来形成预测的。

640.png

TabNet的一个好特性是它不需要特性预处理。另一个原因是,它具有内置的可解释性,即为每个示例选择最相关的特性。这意味着您不必应用外部解释模块,如shap或LIME。

在阅读本文时,要理解这个架构中发生了什么并不容易,但幸运的是,已经发表的代码稍微澄清了一些问题,并表明它并不像您可能认为的那样复杂。

我怎么使用它?

现在TabNet有了更好的实现,如下所述:一个是PyTorch的接口,它有一个类似scikit学习的接口,还有一个是FastAI的接口。

根据作者readme描述要点如下:

为每个数据集创建新的train.csv,val.csv和test.csv文件,我不如读取整个数据集并在内存中进行拆分(当然,只要可行),所以我写了一个在我的代码中为Pandas提供了新的输入功能。

修改data_helper.py文件可能需要一些工作,至少在最初不确定您要做什么以及应该如何定义功能列时(至少我是这样)。还有许多参数需要更改,但它们位于主训练循环文件中,而不是数据帮助器文件中。有鉴于此,我还尝试在我的代码中概括和简化此过程。

我添加了一些快速的代码来进行超参数优化,但到目前为止仅用于分类。

还值得一提的是,作者提供的示例代码仅显示了如何进行分类,而不是回归,因此用户也必须编写额外的代码。我添加了具有简单均方误差损失的回归功能。

使用命令行运行测试

pythontrain_tabnet.py\--csv-pathdata/adult.csv\--target-name"<=50K"\--categorical-featuresworkclass,education,marital.status,\occupation,relationship,race,sex,native.country\--feature_dim16\--output_dim16\--batch-size4096\--virtual-batch-size128\--batch-momentum0.98\--gamma1.5\--n_steps5\--decay-every2500\--lambda-sparsity0.0001\--max-steps7700

强制性参数包括--csv-path(指向CSV文件的位置),-target-name(具有预测目标的列的名称)和-category-featues(逗号分隔列表) 应该视为分类的功能)。其余输入参数是需要针对每个特定问题进行优化的超参数。但是,上面显示的值直接取自TabNet论文,因此作者已经针对成人普查数据集对其进行了优化。

默认情况下,训练过程会将信息写入执行脚本的位置的tflog子文件夹。您可以将tensorboard指向此文件夹以查看训练和验证统计信息:

tensorboard--logdirtflog

如果您没有GPU ...

…您可以尝试这款Colaboratory笔记(https://colab.research.google.com/drive/1AWnaS6uQVDw0sdWjfh-E77QlLtD0cpDa)。请注意,如果您想查看Tensorboard日志,最好的选择是创建一个Google Storage存储桶,并让脚本在其中写入日志。这可以通过使用tb-log-location参数来完成。例如。如果您的存储桶名称是camembert-skyscrape,则可以在脚本的调用中添加--tb-log-location gs:// camembert-skyscraper。(不过请注意,您必须正确设置存储桶的权限。这可能有点麻烦。)

然后可以将tensorboard从自己的本地计算机指向该存储桶:

tensorboard--logdirgs://camembert-skyscraper

超参数优化

在存储库(opt_tabnet.py)中也有一个用于完成超参数优化的快捷脚本。同样,在协作笔记本中显示了一个示例。该脚本仅适用于到目前为止的分类,值得注意的是,某些训练参数虽然实际上并不需要,但仍进行了硬编码(例如,用于尽早停止的参数[您可以继续执行多少步,而 验证准确性没有提高]。)

优化脚本中变化的参数为N_steps,feature_dim,batch-momentum,gamma,lambda-sparsity。(正如下面的优化技巧所建议的那样,output_dim设置为等于feature_dim。)

论文中具有以下有关超参数优化的提示:

大多数数据集对N_steps∈[3,10]产生最佳结果。通常,更大的数据集和更复杂的任务需要更大的N_steps。N_steps的非常高的值可能会过度拟合并导致不良的泛化。

调整Nd [feature_dim]和Na [output_dim]的值是获得性能与复杂性之间折衷的最有效方法。Nd = Na是大多数数据集的合理选择。Nd和Na的非常高的值可能会过度拟合,导致泛化效果差。

γ的最佳选择对整体性能具有重要作用。通常,较大的N_steps值有利于较大的γ。

批量较大对性能有利-如果内存限制允许,建议最大训练数据集总大小的1-10%。虚拟批次大小通常比批次大小小得多。

最初,较高的学习率很重要,应逐渐降低直至收敛。

结果

我已经通过此命令行界面尝试了TabNet的多个数据集,作者提供了他们在那里找到的最佳参数设置。使用这些设置重复运行后,我注意到最佳验证误差(和测试误差)往往在86%左右,类似于不进行超参数调整的CatBoost。作者报告论文中测试集的性能为85.7%。当我使用hyperopt进行超参数优化时,尽管使用了不同的参数设置,但我毫不奇怪地达到了约86%的相似性能。

对于其他数据集,例如Poker Hand 数据集,TabNet被认为远远击败了其他方法。我还没有花很多时间,但是当然每个人都应邀请他们自己对各种数据集进行超参数优化的TabNet!

TabNet是一个有趣的体系结构,似乎有望用于表格数据分析。它直接对原始数据进行操作,并使用顺序注意机制对每个示例执行显式特征选择。此属性还使其具有某种内置的可解释性。

我试图通过围绕它编写一些包装器代码来使TabNet稍微容易一些。下一步是将其与各种数据集中的其他方法进行比较。

tabnet的各种实现

google官方:https://github.com/google-research/google-research/tree/master/tabnet

pytorch:https://github.com/dreamquark-ai/tabnet

本文作者的一些改进:https://github.com/hussius/tabnet_fork


目录
相关文章
|
2月前
|
C++
基于Reactor模型的高性能网络库之地址篇
这段代码定义了一个 InetAddress 类,是 C++ 网络编程中用于封装 IPv4 地址和端口的常见做法。该类的主要作用是方便地表示和操作一个网络地址(IP + 端口)
164 58
|
2月前
|
网络协议 算法 Java
基于Reactor模型的高性能网络库之Tcpserver组件-上层调度器
TcpServer 是一个用于管理 TCP 连接的类,包含成员变量如事件循环(EventLoop)、连接池(ConnectionMap)和回调函数等。其主要功能包括监听新连接、设置线程池、启动服务器及处理连接事件。通过 Acceptor 接收新连接,并使用轮询算法将连接分配给子事件循环(subloop)进行读写操作。调用链从 start() 开始,经由线程池启动和 Acceptor 监听,最终由 TcpConnection 管理具体连接的事件处理。
62 2
|
2月前
基于Reactor模型的高性能网络库之Tcpconnection组件
TcpConnection 由 subLoop 管理 connfd,负责处理具体连接。它封装了连接套接字,通过 Channel 监听可读、可写、关闭、错误等
87 1
|
2月前
|
JSON 监控 网络协议
干货分享“对接的 API 总是不稳定,网络分层模型” 看电商 API 故障的本质
本文从 OSI 七层网络模型出发,深入剖析电商 API 不稳定的根本原因,涵盖物理层到应用层的典型故障与解决方案,结合阿里、京东等大厂架构,详解如何构建高稳定性的电商 API 通信体系。
|
9天前
|
机器学习/深度学习 并行计算 算法
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
|
4月前
|
域名解析 网络协议 安全
计算机网络TCP/IP四层模型
本文介绍了TCP/IP模型的四层结构及其与OSI模型的对比。网络接口层负责物理网络接口,处理MAC地址和帧传输;网络层管理IP地址和路由选择,确保数据包准确送达;传输层提供端到端通信,支持可靠(TCP)或不可靠(UDP)传输;应用层直接面向用户,提供如HTTP、FTP等服务。此外,还详细描述了数据封装与解封装过程,以及两模型在层次划分上的差异。
694 13
|
4月前
|
网络协议 中间件 网络安全
计算机网络OSI七层模型
OSI模型分为七层,各层功能明确:物理层传输比特流,数据链路层负责帧传输,网络层处理数据包路由,传输层确保端到端可靠传输,会话层管理会话,表示层负责数据格式转换与加密,应用层提供网络服务。数据在传输中经过封装与解封装过程。OSI模型优点包括标准化、模块化和互操作性,但也存在复杂性高、效率较低及实用性不足的问题,在实际中TCP/IP模型更常用。
538 10
|
2月前
基于Reactor模型的高性能网络库之Poller(EpollPoller)组件
封装底层 I/O 多路复用机制(如 epoll)的抽象类 Poller,提供统一接口支持多种实现。Poller 是一个抽象基类,定义了 Channel 管理、事件收集等核心功能,并与 EventLoop 绑定。其子类 EPollPoller 实现了基于 epoll 的具体操作,包括事件等待、Channel 更新和删除等。通过工厂方法可创建默认的 Poller 实例,实现多态调用。
203 60
|
2月前
基于Reactor模型的高性能网络库之Channel组件篇
Channel 是事件通道,它绑定某个文件描述符 fd,注册感兴趣的事件(如读/写),并在事件发生时分发给对应的回调函数。
164 60
|
2月前
|
安全 调度
基于Reactor模型的高性能网络库之核心调度器:EventLoop组件
它负责:监听事件(如 I/O 可读写、定时器)、分发事件、执行回调、管理事件源 Channel 等。
180 57

热门文章

最新文章