使用Deep Replay可视化神经网络学习的过程

简介: 使用Deep Replay可视化神经网络学习的过程

深度学习通常被认为是一种黑盒技术,因为通常无法分析它在后端是如何工作的。例如创建了一个深层神经网络,然后将它与你的数据相匹配,我们知道它会使用不同层次的神经元和所有的激活等其他重要的超参数来进行训练。但是我们无法想象信息是如何被传递的或者模型是如何学习的。

如果有一个python包可以创建模型在每个迭代/轮次中如何工作或学习的可视化。您可以将这种可视化用于教育目的,也可以将其展示给其他人,向他们展示模型是如何学习的,首先我们展示下结果,如果你对创建这样的可视化感兴趣那么请往下阅读。

640.gif

Deep Replay一个开放源代码的python包,设计用于让您可视化再现如何在Keras中执行模型训练过程。

设置Colab

对于本文,我们将使用谷歌colab,复制并运行下面给出的代码,以便准备好你的notebook。

#TorunthisnotebookonGoogleColab, youneedtorunthesetwocommandsfirst#toinstallFFMPEG (togenerateanimations-itmaytakeawhiletoinstall!)
#andtheactualDeepReplaypackage!apt-getinstallffmpeg!pipinstalldeepreplay

这个命令也将安装我们需要的库,即deepreplay。

导入库

因为我们正在创建一个深度神经网络,所以我们需要导入所需的库。

fromkeras.layersimportDensefromkeras.modelsimportSequentialfromkeras.optimizersimportSGDfromkeras.initializersimportglorot_normal, normalfromdeepreplay.callbacksimportReplayDatafromdeepreplay.replayimportReplayfromdeepreplay.plotimportcompose_animations, compose_plotsimportmatplotlib.pyplotaspltfromIPython.displayimportHTMLfromsklearn.datasetsimportmake_moons%matplotlibinline

加载数据和创建回调

在这一步中,我们将加载将要处理的数据,并为可视化的回放创建一个回调。

group_name='moons'X, y=make_moons(n_samples=2000, random_state=27, noise=0.03)
replaydata=ReplayData(X, y, filename='moons_dataset.h5', group_name=group_name)
fig, ax=plt.subplots(1, 1, figsize=(5, 5))
ax.scatter(*X.transpose(), c=y, cmap=plt.cm.brg, s=5)

640.png

创建Keras模型

现在,我们将使用不同的层、激活和所有其他超参数创建Keras模型。同时,我们将打印模型的摘要。

sgd=SGD(lr=0.01)
glorot_initializer=glorot_normal(seed=42)
normal_initializer=normal(seed=42)
model=Sequential()
model.add(Dense(input_dim=2,
units=4,
kernel_initializer=glorot_initializer,
activation='tanh'))
model.add(Dense(units=2,
kernel_initializer=glorot_initializer,
activation='tanh',
name='hidden'))
model.add(Dense(units=1,
kernel_initializer=normal_initializer,
activation='sigmoid',
name='output'))
model.compile(loss='binary_crossentropy',
optimizer=sgd,
metrics=['acc'])
model.summary()

640.png

现在让我们训练模型

在训练模型时,我们将回调传递给fit命令。

model.fit(X, y, epochs=200, batch_size=16, callbacks=[replaydata])

绘图

现在我们将创建一些空的图,我们将在其上绘制与模型学习相关的数据。

fig=plt.figure(figsize=(12, 6))
ax_fs=plt.subplot2grid((2, 4), (0, 0), colspan=2, rowspan=2)
ax_ph_neg=plt.subplot2grid((2, 4), (0, 2))
ax_ph_pos=plt.subplot2grid((2, 4), (1, 2))
ax_lm=plt.subplot2grid((2, 4), (0, 3))
ax_lh=plt.subplot2grid((2, 4), (1, 3))

在下一步中,我们只需要将数据传递到这些图中,并创建所有迭代的视频。视频将包含每个轮次的学习过程。

replay=Replay(replay_filename='moons_dataset.h5', group_name=group_name)
fs=replay.build_feature_space(ax_fs, layer_name='hidden',
xlim=(-1, 2), ylim=(-.5, 1),
display_grid=False)
ph=replay.build_probability_histogram(ax_ph_neg, ax_ph_pos)
lh=replay.build_loss_histogram(ax_lh)
lm=replay.build_loss_and_metric(ax_lm, 'acc')

创建一个示例图

sample_figure=compose_plots([fs, ph, lm, lh], 160)
sample_figure

创建视频

sample_anim=compose_animations([fs, ph, lm, lh])
HTML(sample_anim.to_html5_video()

最终的结果就像我们上面展示的那样。

目录
相关文章
|
4小时前
|
Kubernetes 应用服务中间件 Docker
Kubernetes学习-集群搭建篇(二) 部署Node服务,启动JNI网络插件
Kubernetes学习-集群搭建篇(二) 部署Node服务,启动JNI网络插件
|
4小时前
|
机器学习/深度学习 存储 自然语言处理
【威胁情报挖掘-论文阅读】学习图表绘制 基于多实例学习的网络行为提取 SeqMask: Behavior Extraction Over Cyber Threat Intelligence
【威胁情报挖掘-论文阅读】学习图表绘制 基于多实例学习的网络行为提取 SeqMask: Behavior Extraction Over Cyber Threat Intelligence
7 0
|
4小时前
|
机器学习/深度学习 数据可视化 算法
R语言神经网络与决策树的银行顾客信用评估模型对比可视化研究
R语言神经网络与决策树的银行顾客信用评估模型对比可视化研究
|
4小时前
|
机器学习/深度学习 监控 数据可视化
R语言SOM神经网络聚类、多层感知机MLP、PCA主成分分析可视化银行客户信用数据实例2
R语言SOM神经网络聚类、多层感知机MLP、PCA主成分分析可视化银行客户信用数据实例
|
4小时前
|
机器学习/深度学习 数据可视化 算法
R语言SOM神经网络聚类、多层感知机MLP、PCA主成分分析可视化银行客户信用数据实例1
R语言SOM神经网络聚类、多层感知机MLP、PCA主成分分析可视化银行客户信用数据实例
|
4小时前
|
机器学习/深度学习 数据可视化 数据挖掘
R语言神经网络模型金融应用预测上证指数时间序列可视化
R语言神经网络模型金融应用预测上证指数时间序列可视化
|
4小时前
|
机器学习/深度学习 数据可视化 算法
SPSS Modeler决策树和神经网络模型对淘宝店铺服装销量数据预测可视化|数据分享
SPSS Modeler决策树和神经网络模型对淘宝店铺服装销量数据预测可视化|数据分享
|
4小时前
|
机器学习/深度学习 数据可视化 数据挖掘
R语言软件对房屋价格预测:回归、LASSO、决策树、随机森林、GBM、神经网络和SVM可视化|数据分享
R语言软件对房屋价格预测:回归、LASSO、决策树、随机森林、GBM、神经网络和SVM可视化|数据分享
|
4小时前
|
数据采集 安全 数据处理
疫情期间航空网络演变复杂网络可视化
疫情期间航空网络演变复杂网络可视化
|
4小时前
|
机器学习/深度学习 数据可视化 TensorFlow
Python用线性回归和TensorFlow非线性概率神经网络不同激活函数分析可视化
Python用线性回归和TensorFlow非线性概率神经网络不同激活函数分析可视化