【前沿】TensorFlow Pytorch Keras代码实现深度学习大神Hinton NIPS2017 Capsule论文

简介: 10月26日,深度学习元老Hinton的NIPS2017 Capsule论文《Dynamic Routing Between Capsules》终于在arxiv上发表。今天相关关于这篇论文的TensorFlow\Pytorch\Keras实现相继开源出来,让我们来看下。

10月26日,深度学习元老Hinton的NIPS2017 Capsule论文《Dynamic Routing Between Capsules》终于在arxiv上发表。今天相关关于这篇论文的TensorFlow\Pytorch\Keras实现相继开源出来,让我们来看下。

719e11ff6aba2d68a205b6f28c68dd180226a7b4

论文地址:https://arxiv.org/pdf/1710.09829.pdf

Capsule 是一组神经元,其活动向量(activity vector)表示特定实体类型的实例化参数,如对象或对象部分。我们使用活动向量的长度表征实体存在的概率,向量方向表示实例化参数。同一水平的活跃 capsule 通过变换矩阵对更高级别的 capsule 的实例化参数进行预测。当多个预测相同时,更高级别的 capsule 变得活跃。我们展示了判别式训练的多层 capsule 系统在 MNIST 数据集上达到了最好的性能效果,比识别高度重叠数字的卷积网络的性能优越很多。为了达到这些结果,我们使用迭代的路由协议机制:较低级别的 capsule 偏向于将输出发送至高级别的 capsule,有了来自低级别 capsule 的预测,高级别 capsule 的活动向量具备较大的标量积。

CapsNet-PyTorch

python依赖包

  • Python 3
  • PyTorch
  • TorchVision
  • TorchNet
  • TQDM
  • Visdom

使用说明

第一步 在capsule_network.py文件中设置训练epochs,batch size等


BATCH_SIZE = 100NUM_CLASSES = 10NUM_EPOCHS = 30NUM_ROUTING_ITERATIONS = 3


Step 2 开始训练. 如果本地文件夹中没有MNIST数据集,将运行脚本自动下载到本地. 确保 PyTorch可视化工具Visdom正在运行。


$ sudo python3 -m visdom.server & python3 capsule_network.py


基准数据

经过30个epoche的训练手写体数字的识别率达到99.48%. 从下图的训练进度和损失图的趋势来看,这一识别率可以被进一步的提高 。

9f94db0c9075c534d2b438d8966d95eb09dcef1f

采用了PyTorch中默认的Adam梯度优化参数并没有用到动态学习率的调整。 batch size 使用100个样本的时候,在雷蛇GTX 1050 GPU上每个Epochs 用时3分钟。

待完成

  • 扩展到除MNIST以外的其他数据集。

Credits

主要借鉴了以下两个 TensorFlow 和 Keras 的实现:

  1. Keras implementation by @XifengGuo
  2. TensorFlow implementation by @naturomics

Many thanks to @InnerPeace-Wu for a discussion on the dynamic routing procedure outlined in the paper.


CapsNet-Tensorflow


Python依赖包

  • Python
  • NumPy
  • Tensorflow (I'm using 1.3.0, not yet tested for older version)
  • tqdm (for displaying training progress info)
  • scipy (for saving image)

使用说明

训练

第一步 用git命令下载代码到本地.


$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow


第二步 下载MNIST数据集(http://yann.lecun.com/exdb/mnist/), 移动并解压到data/mnist 文件夹(当你用复制wget 命令到你的终端是注意渠道花括号里的反斜杠)


$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/{train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
$ gunzip data/mnist/*.gz


第三步 开始训练:


$ pip install tqdm  # install it if you haven't installed yet
$ python train.py


tqdm包并不是必须的,只是为了可视化训练过程。如果你不想要在train.py中将循环for in step ... 改成 ``for step in range(num_batch)就行了。

评估


$ python eval.py --is_training False


结果

错误的运行结果(Details in Issues #8):

  • training loss 
d39afb6441fad222de027dafb71542e163ccc50e
6ff053759b0d6c567fae0fd2cf7417cd859268b4
  • test acc

b73d9cb682dbd62a620309e59815252be2a8854a
f8fec5951f2d31b9db1776a9f85d76caf0b51b56

4d1af1011ad3d14b2186bd503c27c18394fb5927

Results after fixing Issues #8:

关于capsule的一点见解

  1. 一种新的神经单元(输入向量输出向量,而不是标量)
  2. 常规算法类似于Attention机制
  3. 总之是一项很有潜力的工作,有很多工作可以在之上开展


待办:

  • 完成MNIST的实现Finish the MNIST version of capsNet (progress:90%)
  • 在其他数据集上验证capsNet
  • 调整模型结构
  • 一篇新的投稿在ICLR2018上的后续论文(https://openreview.net/pdf?id=HJWLfGWRb) about capsules(submitted to ICLR 2018)

CapsNet-Keras


依赖包

  • Keras
  • matplotlib

使用方法

训练

第一步 安装 Keras:

$ pip install keras

第二步 用 git命令下载代码到本地.

$ git clone https://github.com/xifengguo/CapsNet-Keras.git
$ cd CapsNet-Keras

第三步 训练:

$ python capsulenet.py

一次迭代训练(default 3).

$ python capsulenet.py --num_routing 1

其他参数包括想 batch_size, epochs, lam_recon, shift_fraction, save_dir 可以以同样的方式使用。 具体可以参考 capsulenet.py

测试

假设你已经有了用上面命令训练好的模型,训练模型将被保存在 result/trained_model.h5. 现在只需要使用下面的命令来得到测试结果。

$ python capsulenet.py --is_training 0 --weights result/trained_model.h5

将会输出测试结果并显示出重构后的图片。测试数据使用的和验证集一样 ,同样也可以很方便的在新数据上验证,至于要按照你的需要修改下代码就行了。

如果你的电脑没有GPU来训练模型,你可以从https://pan.baidu.com/s/1hsF2bvY下载预先训练好的训练模型

结果

主要结果
运行 python capsulenet.py: epoch=1 代表训练一个epoch 后的结果 在保存的日志文件中,epoch从0开始。

66cf69a3316550a23f64ca342a079cb31cf07342

损失和准确度:

669b7258294085504f8de6d70bf0e809e35c7005

一次常规迭代后的结果

运行 python CapsNet.py --num_routing 1 


85eefd6b8ec6475ea939419d990ced0f3c063ae8

测试结果每个 epoch 在单卡GTX 1070 GPU上大概需要110s 注释: 训练任然是欠拟合的,欢迎在你自己的机器上验证。学习率decay还没有经过调试, 我只是试了一次,你可以接续微调。

运行 python capsulenet.py --is_training 0 --weights result/trained_model.h5


模型结构:df996811608d0f97086ea1cf59161dd60392293c

04c4167084109477ad7600a8448ed141dea7b8e5

其他实现代码

  • Kaggle (this version as self-contained notebook):
  • MNIST Dataset running on the standard MNIST and predicting for test data
  • MNIST Fashion running on the more challenging Fashion images.
  • TensorFlow:
  • naturomics/CapsNet-Tensorflow
  • Very good implementation. I referred to this repository in my code.
  • InnerPeace-Wu/CapsNet-tensorflow
  • I referred to the use of tf.scan when optimizing my CapsuleLayer.
  • LaoDar/tf_CapsNet_simple
  • PyTorch:
  • nishnik/CapsNet-PyTorch
  • timomernick/pytorch-capsule
  • gram-ai/capsule-networks
  • andreaazzini/capsnet.pytorch
  • leftthomas/CapsNet
  • MXNet:
  • AaronLeong/CapsNet_Mxnet
  • Lasagne (Theano):
  • DeniskaMazur/CapsNet-Lasagne
  • Chainer:
  • soskek/dynamic_routing_between_capsules

参考网址链接:

https://github.com/gram-ai/capsule-networks

https://github.com/naturomics/CapsNet-Tensorflow

https://github.com/XifengGuo/CapsNet-Keras


原文发布时间为:2017-11-5

本文来自云栖社区合作伙伴新智元,了解相关信息可以关注“AI_era”微信公众号

原文链接:【前沿】TensorFlow Pytorch Keras代码实现深度学习大神Hinton NIPS2017 Capsule论文

相关文章
|
2月前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch深度学习 ? 带你从入门到精通!!!
🌟 蒋星熠Jaxonic,深度学习探索者。三年深耕PyTorch,从基础到部署,分享模型构建、GPU加速、TorchScript优化及PyTorch 2.0新特性,助力AI开发者高效进阶。
PyTorch深度学习 ? 带你从入门到精通!!!
|
2月前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow与PyTorch深度对比分析:从基础原理到实战选择的完整指南
蒋星熠Jaxonic,深度学习探索者。本文深度对比TensorFlow与PyTorch架构、性能、生态及应用场景,剖析技术选型关键,助力开发者在二进制星河中驾驭AI未来。
651 13
|
5月前
|
机器学习/深度学习 算法 定位技术
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现裂缝的检测识别(C#代码UI界面版)
本项目基于YOLOv8模型与C#界面,结合Baumer工业相机,实现裂缝的高效检测识别。支持图像、视频及摄像头输入,具备高精度与实时性,适用于桥梁、路面、隧道等多种工业场景。
602 27
|
3月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
196 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
3月前
|
机器学习/深度学习 数据采集 编解码
基于深度学习分类的时相关MIMO信道的递归CSI量化(Matlab代码实现)
基于深度学习分类的时相关MIMO信道的递归CSI量化(Matlab代码实现)
180 1
|
3月前
|
机器学习/深度学习 算法 vr&ar
【深度学习】基于最小误差法的胸片分割系统(Matlab代码实现)
【深度学习】基于最小误差法的胸片分割系统(Matlab代码实现)
|
6月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
288 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
9月前
|
机器学习/深度学习 自然语言处理 算法
PyTorch PINN实战:用深度学习求解微分方程
物理信息神经网络(PINN)是一种将深度学习与物理定律结合的创新方法,特别适用于微分方程求解。传统神经网络依赖大规模标记数据,而PINN通过将微分方程约束嵌入损失函数,显著提高数据效率。它能在流体动力学、量子力学等领域实现高效建模,弥补了传统数值方法在高维复杂问题上的不足。尽管计算成本较高且对超参数敏感,PINN仍展现出强大的泛化能力和鲁棒性,为科学计算提供了新路径。文章详细介绍了PINN的工作原理、技术优势及局限性,并通过Python代码演示了其在微分方程求解中的应用,验证了其与解析解的高度一致性。
2803 5
PyTorch PINN实战:用深度学习求解微分方程
|
9月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
害虫识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了12种常见的害虫种类数据集【"蚂蚁(ants)", "蜜蜂(bees)", "甲虫(beetle)", "毛虫(catterpillar)", "蚯蚓(earthworms)", "蜚蠊(earwig)", "蚱蜢(grasshopper)", "飞蛾(moth)", "鼻涕虫(slug)", "蜗牛(snail)", "黄蜂(wasp)", "象鼻虫(weevil)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Djan
538 1
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
10月前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习工具和框架详细指南:PyTorch、TensorFlow、Keras
在深度学习的世界中,PyTorch、TensorFlow和Keras是最受欢迎的工具和框架,它们为研究者和开发者提供了强大且易于使用的接口。在本文中,我们将深入探索这三个框架,涵盖如何用它们实现经典深度学习模型,并通过代码实例详细讲解这些工具的使用方法。

热门文章

最新文章