Hinton胶囊网络代码正式开源,5天GitHub fork超1.4万

简介: 万众期待中,Hinton胶囊网络论文《Dynamic Routing between Capsules》的代码正式公布,仅仅5天,Github上fork数量就超过了1.4万。Capsule真能取代CNN吗?接下来是你动手的时间了。

Hinton胶囊网络论文《Dynamic Routing between Capsules》的一作Sara Sabour日前在GitHub公布了代码,使用TensorFlow和NumPy实现,只有一台GPU也行,仅仅5天,fork的数量就超过了1.4万。

实际上,在官方代码公布前,已经有很多其他版本和实现。新智元也对胶囊网络的概念做过详细介绍:

[1]【大神Hinton】深度学习要另起炉灶,彻底抛弃反向传播

[2]【重磅】Hinton 大神 Capsule 论文首次公布,深度学习基石 CNN 或被取代

[3] Reddit 讨论:Hinton 的 Capsule 网络真的比 CNN 效果更好吗?

[4]【Hinton 碰撞 LeCun】CNN 有两大缺陷,要用 capsule 做下一代 CNN

[5]【一文读懂 Hinton 最新 Capsules 论文】CNN 未来向何处去

[6]【一文读懂 Hinton 最新论文】胶囊网络 9 大优势 4 大缺陷(视频 + PPT)

不过,在看代码前,还是有必要再次回顾这篇Hinton革新CNN的论文,Jonathan Hui在他的博客上对这篇论文做过拆解,从基本概念开始,读来非常友好。

用“Capsule”作为下一代CNN的理由 

在深度学习中,神经元的激活水平通常被解释为检测特定特征的可能性。

471be7a6aa83d463a89dc0f94296a8a9cd8ece87

但是,CNN善于检测特征,却在探索特征(视角,大小,方位)之间的空间关系方面效果较差。例如,下面这张图片可能会骗过一个简单的CNN模型,让CNN模型相信这是一张真实的人脸。

2b9f82c6166ca1351844009f88c4ed832fc9b7a4

一个简单的CNN模型可以正确提取鼻子、眼睛和嘴巴的特征,但会错误地激活神经元进行人脸检测。如果不了解空间方向,大小不匹配,那么对于人脸检测的激活将会太高,比如下图95%。

4098e35e6d55fbab17d9133d86d86f935359a3c4

现在,假设每个神经元都包含特征的可能性和属性。例如,神经元输出的是一个包含 [可能性,方向,大小] 的向量。利用这种空间信息,就可以检测鼻子、眼睛和耳朵特征之间的方向和大小的一致性,因此对于人脸检测的激活输出就会低很多。

5d7659a5c379b4fdeecc6594cc621652cfd78fd2

在Hinton的胶囊网络的论文中,就使用“胶囊”(capsule)来指代这样的神经元。

从概念上讲,我们可以将CNN看成是训练神经元来处理不同方向的视角,并在最顶层有一层人脸检测神经元。

4925a0a31b76f4c1efd7fef9ae640dcda317920b

如上所述,为了CNN能够处理不同的视角或变体,我们添加了更多的卷积图层和特征图。尽管如此,这种方法倾向于记忆数据集,而不是得出一个比较通用的解决方案,它需要大量的训练数据来覆盖不同的变体,并避免过拟合。MNIST数据集包含55,000个训练数据,也即每个数字都有5,500个样本。但是,儿童看过几次就能记住数字。现有的包括CNN在内的深度学习模式在利用数据方面效率十分低下。引用Geoffrey Hinton的一句话:

It (convolutional network) works depressingly well.

胶囊网络不是训练来捕捉特定变体的特征,而是捕捉特征及其变体的可能性。所以胶囊的目的不仅在于检测特征,还在于训练模型来学习变体。

这样,相同的胶囊就可以检测不同方向的同一个物体类别(例如,顺时针旋转):

1ab2835cc475e8bc7acfb29db896738dfcb705d6

其中,Invariance对应特征检测,特征是不变的。例如,检测鼻子的神经元不管什么方向,都检测鼻子。但是,神经元空间定向的损失最终会损害这种invariance模型的有效性。

Equivariance对应变体检测,也即可以相互转换的对象(例如检测不同方向的人脸)。直观地说,胶囊网络检测到脸部旋转了20°,而不是实现与旋转了20°的变体相匹配的脸。通过强制模型学习胶囊中的特征变体,我们可以用较少的训练数据,更有效地推断可能的变体。此外,也可以更有效地防止对抗攻击。

计算一个Capsule网络的输出:不同维度的参数 

胶囊是一组神经元,不仅捕捉特征的可能性,还捕捉具体特征的参数。

例如,下面的第一行表示神经元检测到数字“7”的概率。2-D胶囊是组合了2个神经元的网络。这个胶囊在检测数字“7”时输出2-D矢量。对于第二行中的第一个图像,它输出一个向量 v=(0,0.9)v=(0,0.9)。矢量的大小0.9 对应于检测“7”的概率。每行的第二个图像看起来更像是“1”而不是“7”。 因此,其相应的可能性为“7”较小。

42c5abbad4405062d0b485baf562edfaa072ce78

在第三行,旋转图像20°。胶囊将产生具有相同幅度但不同方向的矢量。这里,矢量的角度表示数字“7”的旋转角度。最后,还可以添加2个神经元来捕捉大小和笔画的宽度(见下图)。

a1e1daff762fc013d6af3e5edc38ac78157d6785

我们称胶囊的输出向量为活动向量 ,其幅度代表检测特征的概率,其方向代表其参数(属性)。

在计算一个胶囊网络输出的时候,首先看一个全连接的神经网络:581cdd93881cd636945dec477a8d373b96c98ac7

其中每个神经元的输出是从前一层神经元的输出计算而来的:

55218d586c4f45458e7e20be98661d93edcb1c48

其中640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=都是标量

对于capsule网络,一个capsule的输入640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=和输出640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=都是向量。

c7cc6b1e91fcdce2e9de655979ea46eb89631c88

我们将一个变换矩阵(transformation matrix640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=应用到前一层的capsule输出640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=。例如,用一个640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=矩阵,我们把一个k-D 640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=变换成一个m-D 640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=。然后计算640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=的加权和:

34649dd16feae16d96b6dfe1d6aa51895e3c6774

其中,640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=迭代动态路由过程( iterative dynamic routing process )训练的耦合系数(coupling coefficients),640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=被设计来求和到1。


我们不使用ReLU函数,而是使用一个挤压函数squashing function)来缩短0和单位长度之间的向量。

101657177e878e398c8d536ee2cfe9c41a431ad9

它将短向量缩小到接近0,将长向量缩小为接近单位向量( unit vectors)。因此,每个capsule的似然性在0到1之间。

4c4263d469ab31d0df892355abe6c2ce5a83399b

迭代动态路由规则与重要性

在深度学习中,我们使用反向传播来训练模型参数。转换矩阵 Wij 在胶囊中仍然用反向传播训练。不过,耦合系数 cij 用新的迭代动态路由方法进行计算。

d9f8028f8a01abf2ea80c627c70d86a54a12be6f

以下是动态路由的最终伪代码:

21bbbb8288536b645c7a1fb6c336a7d88b55e918

在深度学习中,我们使用反向传播来训练基于成本函数的模型参数。这些参数(权重)控制信号从一层到另一层的路由。如果两个神经元之间的权重为零,则神经元的激活不会传播到该神经元。

迭代动态路由提供了如何根据特征参数来路由信号的替代方案。通过利用特征参数,理论上,可以更好地将胶囊分组,形成一个高层次的结构。例如,胶囊层可能最终表现为探索“部分-整体”关系的分析树。例如,脸部由眼睛、鼻子和嘴组成。迭代动态路由利用变换矩阵、可能性和特征的性质,控制向上传播到上面胶囊的信号的多少。

最后,就到了应用胶囊构建CapsNet,进而对MNIST数字进行分类和重构的时候了。下面是CapsNet的架构。一个CapsNet共有3层,两个卷积层和一个全连接层。

6b036030f14de7a05764e9fe2eab9cc983cd437d

论文提到的MNIST数字重构任务:

7b4a7d5bf6622531f5d9c647d172f67a14943417

Github代码


Capsule模型代码在以下论文中使用:

  • "Dynamic Routing between Capsules”(胶囊间的动态路由) by Sara Sabour, Nickolas Frosst, Geoffrey E. Hinton.

要求:

  • TensorFlow(请参阅http://www.tensorflow.org了解如何安装/升级)
  • NumPy(请参阅http://www.numpy.org/)
  • GPU

运行测试验证设置是否正确,例如:

f8e85bcdde69a746313e37b321018361f6d582f6

快速MNIST测试结果:

  • 从以下网址下载并提取MNIST记录到 $DATA_DIR/:https://storage.googleapis.com/capsule_toronto/mnist_data.tar.gz
  • 从以下网址下载并提取MNIST模型checkpoint到$CKPT_DIR:https://storage.googleapis.com/capsule_toronto/mnist_checkpoints.tar.gz

6c1cc5eebad1bbcabbf84aa1f3120894bde35f9f

快速CIFAR10 ensemble测试结果:

  • 从以下网址下载并提取cifar10二进制版本到 $DATA_DIR/:https://www.cs.toronto.edu/~kriz/cifar.html
  • 从以下网址下载并提取cifar10模型checkpoint到 $CKPT_DIR:https://storage.googleapis.com/capsule_toronto/cifar_checkpoints.tar.gz
  • 将提取的二进制文件的目录作为 data_dir 传递给($ DATA_DIR)

beebcab12f07ce5ea2c3e26ddb023dc7ebf3cfbf

Sample CIFAR10训练命令:

aa037141100fa3ec87e0143afee4633c712cbe8b

Sample MNIST的完整训练命令:

  • 在 training-validation pass 训练,validate=true 也是如此
  • 要在一个以上的GPU pass训练,num_gpus = NUM_GPUS

640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=

Sample MNIST基线训练命令:

da92d25d4e42908e211c7ad7a7f68fa132fa3a6e

在上述模型的训练期间对validation进行测试:

训练过程中连续运行的注意事项

  • 在训练中也要注意pass --validate=true
  • 总共需要2个GPU:一个用于训练,一个用于验证
  • 如果在同一台机器上进行训练和验证,则需要限制每个任务的RAM消耗,因为TensorFlow会填满第一个任务的所有RAM,从而导致第二个任务失败。

b84e4cf0bf8574763a303934a706ad2e009b501f

  • 要测试/训练 MultiMNIST pass --num_targets = 2 以及 --data_dir = $DATA_DIR/multitest_6shifted_mnist.tfrecords@10。 
  • 生成 multiMNIST / MNIST 记录的代码位于input_data/mnist/mnist_shift.py。

生成multiMNIST测试的示例代码:

8088f581ae75c03feceb9640a8257866f57c9f40

为 affNIST 的泛化能力建立 expanded_mnist: --shift = 6 --pad = 6。

读取affNIST的代码将遵循。

代码由Sara Sabour(sarasra, sasabour@google.com)维护。


原文发布时间为:2018-02-1

本文作者:文强,马文

本文来自云栖社区合作伙伴新智元,了解相关信息可以关注“AI_era”微信公众号

原文链接:Hinton胶囊网络代码正式开源,5天GitHub fork超1.4万

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
15天前
|
存储 JavaScript 网络架构
【开源图床】使用Typora+PicGo+Github+CDN搭建个人博客图床
【开源图床】使用Typora+PicGo+Github+CDN搭建个人博客图床
27 3
|
1月前
|
机器学习/深度学习 算法 PyTorch
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
233 1
|
1月前
|
人工智能 文字识别 异构计算
关于github开源ocr项目的疑问
小白尝试Python OCR学习,遇到报错。尝试Paddle OCR部署失败,Tesseract OCR在Colab误操作后恢复失败。EasyOCR在Colab和阿里天池Notebook成功,但GPU资源不足。其他平台部署不顺,决定使用WebUI或阿里云轻应用。求教OCR项目部署到本地及简单OCR项目推荐。
29 2
|
1月前
|
Web App开发 前端开发 数据库
推荐GitHub上开源的一款独立开发者出海技术栈和工具合集
推荐GitHub上开源的一款独立开发者出海技术栈和工具合集
|
1月前
|
机器学习/深度学习 人工智能 API
『GitHub项目圈选06』推荐5款本周 超火 的开源AI项目
『GitHub项目圈选06』推荐5款本周 超火 的开源AI项目
|
1月前
|
自然语言处理 并行计算 PyTorch
GitHub 开源神器 Bark模型,让文本转语音更简单!
GitHub 开源神器 Bark模型,让文本转语音更简单!
|
1月前
|
数据采集 人工智能 Rust
『GitHub项目圈选周刊01』一款构建AI数字人项目开源了!自动实现音视频同步!
『GitHub项目圈选周刊01』一款构建AI数字人项目开源了!自动实现音视频同步!
205 0
|
14天前
|
JSON Kubernetes 网络架构
Kubernetes CNI 网络模型及常见开源组件
【4月更文挑战第13天】目前主流的容器网络模型是CoreOS 公司推出的 Container Network Interface(CNI)模型
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
卷积神经元网络中常用卷积核理解及基于Pytorch的实例应用(附完整代码)
卷积神经元网络中常用卷积核理解及基于Pytorch的实例应用(附完整代码)
20 0
|
1月前
|
机器学习/深度学习 数据采集 人工智能
m基于深度学习网络的手势识别系统matlab仿真,包含GUI界面
m基于深度学习网络的手势识别系统matlab仿真,包含GUI界面
41 0