【李沐】十分钟从 PyTorch 转 MXNet

简介: PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。MXNet通过ndarray和 gluon模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法。

PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。例如 Caffe2 最近就并入了 PyTorch。

可能大家不是特别知道的是,MXNet 通过 ndarray 和 gluon 模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法

89e0e8d5de21311740959c69f9ae2fe0258d52af

安装

PyTorch 默认使用 conda 来进行安装,例如

03192dec910f50e049d5fecb3109e8b09f6cdf9b

而 MXNet 更常用的是使用 pip。我们这里使用了 --pre 来安装 nightly 版本

83d9786fd8b0f46bc693765c60c2e0544ec118a7

多维矩阵

对于多维矩阵,PyTorch 沿用了 Torch 的风格称之为 tensor,MXNet 则追随了 NumPy 的称呼 ndarray。下面我们创建一个两维矩阵,其中每个元素初始化成 1。然后每个元素加 1 后打印。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

b472e4f3dada3709d53edf6608ab47f322089ca2

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

28436e80f4887a6ef0ba21b7dedafad23e823410

忽略包名的不一样的话,这里主要的区别是 MXNet 的形状传入参数跟 NumPy 一样需要用括号括起来。

模型训练

下面我们看一个稍微复杂点的例子。这里我们使用一个多层感知机(MLP)来在 MINST 这个数据集上训练一个模型。我们将其分成 4 小块来方便对比。

读取数据

这里我们下载 MNIST 数据集并载入到内存,这样我们之后可以一个一个读取批量。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

a3657b3dbcca68c3b62521edd3f0dd3082a15389

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d50f77b5a334e2bd2072180a29c72c9a74ed18dc

这里的主要区别是 MXNet 使用 transform_first 来表明数据变化是作用在读到的批量的第一个元素,既 MNIST 图片,而不是第二个标号元素。

定义模型

下面我们定义一个只有一个单隐层的 MLP 。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

01818faab6a9ae66f6daae74be169b64c894f344

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

8e95eac5ee43682f1ca882e782de98079be13d9e

我们使用了 Sequential 容器来把层串起来构造神经网络。这里 MXNet 跟 PyTorch 的主要区别是:

8481c8f592b7f349aa84a1de5c171db681516edf 不需要指定输入大小,这个系统会在后面自动推理得到
8481c8f592b7f349aa84a1de5c171db681516edf 全连接和卷积层可以指定激活函数
8481c8f592b7f349aa84a1de5c171db681516edf需要创建一个  name_scope  的域来给每一层附上一个独一无二的名字,这个在之后读写模型时需要
8481c8f592b7f349aa84a1de5c171db681516edf 我们需要显示调用模型初始化函数。


大家知道 Sequential 下只能神经网络只能逐一执行每个层。PyTorch 可以继承 nn.Module 来自定义 forward 如何执行。同样,MXNet 可以继承 nn.Block 来达到类似的效果。

损失函数和优化算法

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

483451f3193b8143e4fe7c180da0a03baff4fc71

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d126effd55a80c7df82aa0e96cb0f5cf7f1c5785

这里我们使用交叉熵函数和最简单随机梯度下降并使用固定学习率 0.1

训练

最后我们实现训练算法,并附上了输出结果。注意到每次我们会使用不同的权重和数据读取顺序,所以每次结果可能不一样。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch

37274d74bd5215f00a2a585afaea92d1eb809284

8481c8f592b7f349aa84a1de5c171db681516edfMXNet

fa0addb6b0a7d824307feb70fcd9eae4ea9e209a

MXNet 跟 PyTorch 的不同主要在下面这几点:

8481c8f592b7f349aa84a1de5c171db681516edf不需要将输入放进  Variable , 但需要将计算放在  mx.autograd.record()  里使得后面可以对其求导
8481c8f592b7f349aa84a1de5c171db681516edf 不需要每次梯度清 0,因为新梯度是写进去,而不是累加
8481c8f592b7f349aa84a1de5c171db681516edf step  的时候 MXNet 需要给定批量大小
8481c8f592b7f349aa84a1de5c171db681516edf需要调用  asscalar()  来将多维数组变成标量。
8481c8f592b7f349aa84a1de5c171db681516edf 这个样例里 MXNet 比 PyTorch 快两倍。当然大家对待这样的比较要谨慎。

下一步

8481c8f592b7f349aa84a1de5c171db681516edf 更详细的 MXNet 的教程:http://zh.gluon.ai/

8481c8f592b7f349aa84a1de5c171db681516edf欢迎给我们留言哪些 PyTorch 的方便之处你希望 MXNet 应该也可以有



原文发布时间为:2018-04-3

本文作者:李沐

本文来自云栖社区合作伙伴新智元,了解相关信息可以关注“AI_era”微信公众号

原文链接:【李沐】十分钟从 PyTorch 转 MXNet

相关文章
|
机器学习/深度学习 人工智能 PyTorch
李沐动手学深度学习pytorch :问题:找不到d2l包,No module named ‘d2l’
李沐动手学深度学习pytorch :问题:找不到d2l包,No module named ‘d2l’
665 0
|
机器学习/深度学习 存储 算法
【李沐:动手学深度学习pytorch版】第3章:线性神经网络(下)
【李沐:动手学深度学习pytorch版】第3章:线性神经网络
407 0
【李沐:动手学深度学习pytorch版】第3章:线性神经网络(下)
|
机器学习/深度学习 算法 数据可视化
【李沐:动手学深度学习pytorch版】第3章:线性神经网络(上)
【李沐:动手学深度学习pytorch版】第3章:线性神经网络
216 0
【李沐:动手学深度学习pytorch版】第3章:线性神经网络(上)
|
机器学习/深度学习 存储 数据可视化
【李沐:动手学深度学习pytorch版】第2章:预备知识(下)
【李沐:动手学深度学习pytorch版】第2章:预备知识
414 0
【李沐:动手学深度学习pytorch版】第2章:预备知识(下)
|
机器学习/深度学习 数据采集 算法
【李沐:动手学深度学习pytorch版】第2章:预备知识(上)
【李沐:动手学深度学习pytorch版】第2章:预备知识
286 0
|
机器学习/深度学习 人工智能 数据挖掘
李沐《动手学深度学习》PyTorch 实现版开源,瞬间登上 GitHub 热榜!
李沐《动手学深度学习》PyTorch 实现版开源,瞬间登上 GitHub 热榜!
2120 0
李沐《动手学深度学习》PyTorch 实现版开源,瞬间登上 GitHub 热榜!
|
2月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
297 2
|
2月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
63 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
2月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
95 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
3月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
182 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型