从深度学习到深度森林方法(Python)

简介: 本文第一节源于周志华教授《关于深度学习的一点思考》, 在此基础上对深度学习、深度森林做了原理解析并实践。

一、关于深度学习的一点思考



二、深度森林的介绍


目前深度神经网络(DNN)做得好的几乎都是涉及图像视频(CV)、自然语言处理(NLP)等的任务,都是典型的数值建模任务(在表格数据tabular data的表现也是稍弱的),而在其他涉及符号建模、离散建 模、混合建模的任务上,深度神经网络的性能并没有那么好。


深度森林(gcForest)是深度神经网络(DNN)之外的探索的一种深度模型,原文:it may open a door towards alternative to deep neural networks for many tasks。不同于深度神经网络由可微的神经元组成,深度森林的基础构件是不可微的决策树,其训练过程不基于 BP 算 法,甚至不依赖于梯度计算。它初步验证了关于深度学习奏效原因的猜想--即只要能做到逐层加工处理、内置特征变换、模型复杂度够,就能构建出有效的深度学习模型,并非必须使用神经网络。


深度森林主要的特点是:


  • 拥有比其他基于决策树的集成学习方法更好的性能


  • 拥有更少的超参数,并且无需大量的调参


  • 训练效率高,并且能够处理大规模的数据集


深度森林目前还处于探索阶段,评估模型(gcForest)的表现,在MNIST数据集准确率不错:



在CIFAR-10数据集上准确率欠佳:



三、深度森林原理


深度森林其实也就是ensemble of ensemble的模型,可以看作是对集成树(森林)模型,进一步stacking集成学习及优化(Complete Random Forest、shortcut-connection、Multi-Grained Scanning等),以解决深层容易过拟合的问题。


3.1 特征的处理


深度森林借鉴了CNN滑动卷积核的特征提取,通过多粒度扫描(Multi-Grained Scanning)方法,滑动窗口扫描原始特征,生成输入特征。



如上图,假设我们现在有一个400维(序列数据)的样本输入,现在设定采样窗口是100维的,那我们可以通过逐步的采样,最终获得301个子样本(默认采样步长是1,得到的子样本个数 = (400-100)/1 + 1)。



整个特征处理的过程就是:先输入一个完整的P维样本,然后通过一个长度为k的采样窗口进行滑动采样,得到S = (P - K)/1+1 个k维特征子样本向量,接着每个子样本都用于完全随机森林和普通随机森林的训练并在每个森林都获得一个长度为C(类别数)的概率向量,这样每个森林会产生长度为S*C的表征向量(即经过随机森林转换并拼接的概率向量),最后把每层的F个森林的结果拼接在一起得到本层输出。


3.2 深度森林的架构



整个模型的架构采用一个级联结构(如上图), 每一级都采用两种森林构建: random forests (black) and completely-random tree forests(blue),使用completely-random可以增加基模型的多样性,以减少过拟合风险,提高集成学习的效果。


完全随机森林(completely-random tree forests):由多棵树组成,每棵树包含所有的特征,并且随机选择一个特征作为分裂树的分裂节点。一直分裂到每个叶子节点只包含一个类别或者不多于是个样本结束。随机森林(random forests):同样由多棵树构成,每棵树通过随机选取Sqrt(特征总数)个特征 ,然后通过GINI分数来筛选分裂节点。


每一层的级联接收前一层处理的特征信息,并将处理结果(类向量)输出到下一层(如上图)。以三分类为例,输入特征为向量x,经过每个森林学习后(注:每个森林的学习的数据利用k折交叉验证得到,以减少过拟合风险),得到预测类分布,然后求平均,再与之前原始特征拼接(类似shortcut-connection),作为下一层的输入。


扩展完一层后,整个级联结构可在验证集上面测试性能,若没有显著提高,训练过程会终止,故而层数可以自动确定。这也是gcForest能够自动决定模型复杂度的原因。


四、深度森林预测


本节简单使用深度森林模型用于波士顿房价回归预测及癌细胞分类任务。 安装:pip install deep-forest


  • 波士顿房价回归预测,使用默认参数效果还不错:Testing MSE: 8.068


# 回归预测--波士顿房价
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from deepforest import CascadeForestRegressor
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
model = CascadeForestRegressor(random_state=1)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print("\nTesting MSE: {:.3f}".format(mse))


  • 癌细胞分类任务,准确率不错,Testing Accuracy: 95.105 %


# 分类预测--癌细胞分类数据集
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from deepforest import CascadeForestClassifier
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
model = CascadeForestClassifier(random_state=1)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred) * 100
print("\nTesting Accuracy: {:.3f} %".format(acc))
相关文章
|
26天前
|
测试技术 开发者 Python
Python单元测试入门:3个核心断言方法,帮你快速定位代码bug
本文介绍Python单元测试基础,详解`unittest`框架中的三大核心断言方法:`assertEqual`验证值相等,`assertTrue`和`assertFalse`判断条件真假。通过实例演示其用法,帮助开发者自动化检测代码逻辑,提升测试效率与可靠性。
167 1
|
2月前
|
机器学习/深度学习 数据采集 数据挖掘
基于 GARCH -LSTM 模型的混合方法进行时间序列预测研究(Python代码实现)
基于 GARCH -LSTM 模型的混合方法进行时间序列预测研究(Python代码实现)
|
2月前
|
调度 Python
微电网两阶段鲁棒优化经济调度方法(Python代码实现)
微电网两阶段鲁棒优化经济调度方法(Python代码实现)
|
19天前
|
人工智能 数据安全/隐私保护 异构计算
桌面版exe安装和Python命令行安装2种方法详细讲解图片去水印AI源码私有化部署Lama-Cleaner安装使用方法-优雅草卓伊凡
桌面版exe安装和Python命令行安装2种方法详细讲解图片去水印AI源码私有化部署Lama-Cleaner安装使用方法-优雅草卓伊凡
234 8
桌面版exe安装和Python命令行安装2种方法详细讲解图片去水印AI源码私有化部署Lama-Cleaner安装使用方法-优雅草卓伊凡
|
1月前
|
算法 调度 决策智能
【两阶段鲁棒优化】利用列-约束生成方法求解两阶段鲁棒优化问题(Python代码实现)
【两阶段鲁棒优化】利用列-约束生成方法求解两阶段鲁棒优化问题(Python代码实现)
|
2月前
|
机器学习/深度学习 数据采集 算法
【CNN-BiLSTM-attention】基于高斯混合模型聚类的风电场短期功率预测方法(Python&matlab代码实现)
【CNN-BiLSTM-attention】基于高斯混合模型聚类的风电场短期功率预测方法(Python&matlab代码实现)
154 4
|
2月前
|
机器学习/深度学习 数据采集 TensorFlow
基于CNN-GRU-Attention混合神经网络的负荷预测方法(Python代码实现)
基于CNN-GRU-Attention混合神经网络的负荷预测方法(Python代码实现)
|
9月前
|
机器学习/深度学习 运维 安全
深度学习在安全事件检测中的应用:守护数字世界的利器
深度学习在安全事件检测中的应用:守护数字世界的利器
345 22
|
6月前
|
机器学习/深度学习 编解码 人工智能
计算机视觉五大技术——深度学习在图像处理中的应用
深度学习利用多层神经网络实现人工智能,计算机视觉是其重要应用之一。图像分类通过卷积神经网络(CNN)判断图片类别,如“猫”或“狗”。目标检测不仅识别物体,还确定其位置,R-CNN系列模型逐步优化检测速度与精度。语义分割对图像每个像素分类,FCN开创像素级分类范式,DeepLab等进一步提升细节表现。实例分割结合目标检测与语义分割,Mask R-CNN实现精准实例区分。关键点检测用于人体姿态估计、人脸特征识别等,OpenPose和HRNet等技术推动该领域发展。这些方法在效率与准确性上不断进步,广泛应用于实际场景。
785 64
计算机视觉五大技术——深度学习在图像处理中的应用
|
10月前
|
机器学习/深度学习 传感器 数据采集
深度学习在故障检测中的应用:从理论到实践
深度学习在故障检测中的应用:从理论到实践
734 6

推荐镜像

更多
下一篇
oss教程