DeepTables:用于表格数据的深度学习工具包
简介
MLP(也称为全连接神经网络)已被证明在学习分布表示方面效率低下。 事实证明,感知器层的“Add”操作在探索乘法特征交互时性能较差。 在大多数情况下,必须进行手动特征工程,并且这项工作需要广泛的领域知识并且非常繁琐。 如何在神经网络中有效地学习功能交互成为最重要的问题。
目前为止,业界已经提出了各种模型来进行CTR预测,并且这些模型在最近几年中将一直优于现有的最新技术。 众所周知的示例包括FM,DeepFM,Wide&Deep,DCN,PNN等。这些模型还可以在合理利用的情况下为表格数据提供良好的性能。
DT旨在利用最新的研究结果为用户提供表格数据的端到端工具包。
DT的设计考虑了以下主要目标:
- 易于使用,非专家也可以使用。
- 开箱即用地提供良好的性能。
- 灵活的架构,易于用户扩展。
教程
安装
cpu安装命令:
pip install deeptables
gpu安装命令:
pip install deeptables[gpu]
简单实例
下面是DT用于二分类任务的简单例子:
import numpy as np from deeptables.models import deeptable, deepnets from deeptables.datasets import dsutils from sklearn.model_selection import train_test_split # 加载数据 df = dsutils.load_bank() df_train, df_test = train_test_split(df, test_size=0.2, random_state=42) y = df_train.pop('y') y_test = df_test.pop('y') #训练 config = deeptable.ModelConfig(nets=deepnets.DeepFM) dt = deeptable.DeepTable(config=config) model, history = dt.fit(df_train, y, epochs=10) #评估 result = dt.evaluate(df_test,y_test, batch_size=512, verbose=0) print(result) #预测 preds = dt.predict(df_test)
DeepTables在Kaggle Categorical Feature Encoding Challenge II 比赛中取得了第一的好成绩,方案链接,大家可以尝试使用~