【李沐】十分钟从 PyTorch 转 MXNet

简介: PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。MXNet通过ndarray和 gluon模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法。
+关注继续查看

PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。例如 Caffe2 最近就并入了 PyTorch。

可能大家不是特别知道的是,MXNet 通过 ndarray 和 gluon 模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法

89e0e8d5de21311740959c69f9ae2fe0258d52af

安装

PyTorch 默认使用 conda 来进行安装,例如

03192dec910f50e049d5fecb3109e8b09f6cdf9b

而 MXNet 更常用的是使用 pip。我们这里使用了 --pre 来安装 nightly 版本

83d9786fd8b0f46bc693765c60c2e0544ec118a7

多维矩阵

对于多维矩阵,PyTorch 沿用了 Torch 的风格称之为 tensor,MXNet 则追随了 NumPy 的称呼 ndarray。下面我们创建一个两维矩阵,其中每个元素初始化成 1。然后每个元素加 1 后打印。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

b472e4f3dada3709d53edf6608ab47f322089ca2

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

28436e80f4887a6ef0ba21b7dedafad23e823410

忽略包名的不一样的话,这里主要的区别是 MXNet 的形状传入参数跟 NumPy 一样需要用括号括起来。

模型训练

下面我们看一个稍微复杂点的例子。这里我们使用一个多层感知机(MLP)来在 MINST 这个数据集上训练一个模型。我们将其分成 4 小块来方便对比。

读取数据

这里我们下载 MNIST 数据集并载入到内存,这样我们之后可以一个一个读取批量。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

a3657b3dbcca68c3b62521edd3f0dd3082a15389

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d50f77b5a334e2bd2072180a29c72c9a74ed18dc

这里的主要区别是 MXNet 使用 transform_first 来表明数据变化是作用在读到的批量的第一个元素,既 MNIST 图片,而不是第二个标号元素。

定义模型

下面我们定义一个只有一个单隐层的 MLP 。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

01818faab6a9ae66f6daae74be169b64c894f344

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

8e95eac5ee43682f1ca882e782de98079be13d9e

我们使用了 Sequential 容器来把层串起来构造神经网络。这里 MXNet 跟 PyTorch 的主要区别是:

8481c8f592b7f349aa84a1de5c171db681516edf不需要指定输入大小,这个系统会在后面自动推理得到
8481c8f592b7f349aa84a1de5c171db681516edf全连接和卷积层可以指定激活函数
8481c8f592b7f349aa84a1de5c171db681516edf需要创建一个 name_scope 的域来给每一层附上一个独一无二的名字,这个在之后读写模型时需要
8481c8f592b7f349aa84a1de5c171db681516edf我们需要显示调用模型初始化函数。


大家知道 Sequential 下只能神经网络只能逐一执行每个层。PyTorch 可以继承 nn.Module 来自定义 forward 如何执行。同样,MXNet 可以继承 nn.Block 来达到类似的效果。

损失函数和优化算法

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

483451f3193b8143e4fe7c180da0a03baff4fc71

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d126effd55a80c7df82aa0e96cb0f5cf7f1c5785

这里我们使用交叉熵函数和最简单随机梯度下降并使用固定学习率 0.1

训练

最后我们实现训练算法,并附上了输出结果。注意到每次我们会使用不同的权重和数据读取顺序,所以每次结果可能不一样。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch

37274d74bd5215f00a2a585afaea92d1eb809284

8481c8f592b7f349aa84a1de5c171db681516edfMXNet

fa0addb6b0a7d824307feb70fcd9eae4ea9e209a

MXNet 跟 PyTorch 的不同主要在下面这几点:

8481c8f592b7f349aa84a1de5c171db681516edf不需要将输入放进 Variable, 但需要将计算放在 mx.autograd.record() 里使得后面可以对其求导
8481c8f592b7f349aa84a1de5c171db681516edf不需要每次梯度清 0,因为新梯度是写进去,而不是累加
8481c8f592b7f349aa84a1de5c171db681516edfstep 的时候 MXNet 需要给定批量大小
8481c8f592b7f349aa84a1de5c171db681516edf需要调用 asscalar() 来将多维数组变成标量。
8481c8f592b7f349aa84a1de5c171db681516edf这个样例里 MXNet 比 PyTorch 快两倍。当然大家对待这样的比较要谨慎。

下一步

8481c8f592b7f349aa84a1de5c171db681516edf更详细的 MXNet 的教程:http://zh.gluon.ai/

8481c8f592b7f349aa84a1de5c171db681516edf欢迎给我们留言哪些 PyTorch 的方便之处你希望 MXNet 应该也可以有



原文发布时间为:2018-04-3

本文作者:李沐

本文来自云栖社区合作伙伴新智元,了解相关信息可以关注“AI_era”微信公众号

原文链接:【李沐】十分钟从 PyTorch 转 MXNet

相关文章
|
4月前
|
数据采集 机器学习/深度学习 存储
Pytorch葵花宝典(建议收藏)
Pytorch葵花宝典(建议收藏)
32 0
|
5月前
|
机器学习/深度学习 编译器 TensorFlow
瞎聊深度学习——TensorFlow的基本应用
瞎聊深度学习——TensorFlow的基本应用
|
存储 数据可视化 PyTorch
学懂 ONNX,PyTorch 模型部署再也不怕!
在把 PyTorch 模型转换成 ONNX 模型时,我们往往只需要轻松地调用一句 torch.onnx.export 就行了。这个函数的接口看上去简单,但它在使用上还有着诸多的“潜规则”。在这篇教程中,我们会详细介绍 PyTorch 模型转 ONNX 模型的原理及注意事项。除此之外,我们还会介绍 PyTorch 与 ONNX 的算子对应关系,以教会大家如何处理 PyTorch 模型转换时可能会遇到的算子支持问题。
3024 0
学懂 ONNX,PyTorch 模型部署再也不怕!
|
机器学习/深度学习 人工智能 Kubernetes
PyTorch重大更新再战TensorFlow,AWS也来趟深度学习框架的浑水?
刚刚,Facebook联合AWS 宣布了PyTorch的两个重大更新:TorchServe和TorchElastic。而不久前Google刚公布DynamicEmbedding。两大阵营又开战端,Facebook亚马逊各取所长联手对抗Google!
172 0
PyTorch重大更新再战TensorFlow,AWS也来趟深度学习框架的浑水?
|
机器学习/深度学习 人工智能 PyTorch
2019 深度学习框架大盘点!看 PyTorch、TensorFlow 如何强势上榜?
2019 深度学习框架大盘点!看 PyTorch、TensorFlow 如何强势上榜?
2019 深度学习框架大盘点!看 PyTorch、TensorFlow 如何强势上榜?
|
机器学习/深度学习 安全 前端开发
2020,PyTorch真的赶上TensorFlow了吗?
几天前,OpenAI 通过官方博客宣布了「全面转向 PyTorch」的消息,计划将自家平台的所有框架统一为 PyPyTorch。这一消息再次引发了社区关于两个框架优劣的讨论。作为后起之秀,PyTorch 真的已经全面赶超 TensorFlow 了吗?为了研究这个问题,数据科学家 Jeff Hale 从在线职位数量、顶会论文中的出现次数、在线搜索结果、开发者使用情况四个方面对两个框架的现状进行了调研。
2020,PyTorch真的赶上TensorFlow了吗?
|
机器学习/深度学习 人工智能 JavaScript
AlphaGo Zero你也来造一只,PyTorch实现五脏俱全| 附代码
遥想当年,AlphaGo的Master版本,在完胜柯洁九段之后不久,就被后辈AlphaGo Zero (简称狗零) 击溃了。
2047 0
|
机器学习/深度学习 TensorFlow API
Tensorflow快餐教程(12) - 用机器写莎士比亚的戏剧
如何用TFLearn,Keras等高层框架来学习自动生成莎翁的戏剧或者尼采的哲学文章
2307 0
热门文章
最新文章
推荐文章
更多