在TPU上运行PyTorch的技巧总结

简介: 在TPU上运行PyTorch的技巧总结

TPU芯片介绍

Google定制的打机器学习专用晶片称之为TPU(Tensor Processing Unit),Google在其自家称,由于TPU专为机器学习所运行,得以较传统CPU、 GPU降低精度,在计算所需的电晶体数量上,自然可以减少,也因此,可从电晶体中挤出更多效能,每秒执行更复杂、强大的机器学习模组,并加速模组的运用,使得使用者更快得到答案,Google最早是计划用FPGA的,但是财大气粗,考虑到自己的特殊应用,就招了很多牛人来做专用芯片TPU。

640.png

TPUs已经针对TensorFlow进行了优化,并且主要用于TensorFlow。但是Kaggle和谷歌在它的一些比赛中分发了免费的TPU时间,并且一个人不会简单地改变他最喜欢的框架,所以这是一个关于我在GCP上用TPU训练PyTorch模型的经验的备忘录(大部分是成功的)。

640.png

PyTorch/XLA是允许这样做的项目。它仍在积极的开发中,问题得到了解决。希望在不久的将来,运行它的体验会更加顺畅,一些bug会得到修复,最佳实践也会得到更好的交流。

https://github.com/pytorch/xla

设置

这里有两种方法可以获得TPU的使用权

GCP计算引擎虚拟机与预构建的PyTorch/XLA映像并按照PyTorch/XLA github页面上的“使用预构建的计算VM映像”部分进行设置。

或者使用最简单的方法,使用google的colab笔记本可以获得免费的tpu使用。

针对一kaggle的比赛您可以在虚拟机上使用以下代码复制Kaggle API令牌并使用它下载竞争数据。还可以使用gsutil cp将文件复制回GS bucket。

gcloudauthlogingsutilcpgs://bucket-name/kaggle-keys/kaggle.json ~/.kagglechmod600~/.kaggle/kaggle.jsonkagglecompetitionsdownload-crecursion-cellular-image-classification

除了谷歌存储之外,我还使用github存储库将数据和代码从我的本地机器传输到GCP虚拟机,然后再返回。

注意,在TPU节点上也有运行的软件版本。它必须匹配您在VM上使用的conda环境。由于PyTorch/XLA目前正在积极开发中,我使用最新的TPU版本:

640.png

使用TPU训练

让我们看看代码。PyTorch/XLA有自己的多核运行方式,由于TPUs是多核的,您希望利用它。但在你这样做之前,你可能想要把你的模型中的device = ' cuda '替换为

importtorch_xla_py.xla_modelasxm...
device=xm.xla_device()...
xm.optimizer_step(optimizer)
xm.mark_step()

仅在TPU的一个核上测试您的模型。上面代码片段中的最后两行替换了常规的optimizer.step()调用。

对于多核训练,PyTorch/XLA使用它自己的并行类。在这里的测试目录中可以找到一个使用并行训练循环的示例(https://github.com/pytorch/xla/blob/master/test/test_train_mnist.py

我想强调与它相关的以下三点。

1) DataParallel并行持有模型对象的副本(每个TPU设备一个),并以相同的权重保持同步。你可以通过访问其中一个模型进行保存,因为权重都是同步的:

torch.save(model_parallel._models[0].state_dict(), filepath)

每个并行内核必须运行相同批数量,并且只允许运行完整批。因此,每个历元在小于100%的样本下运行,剩余部分被忽略。对于数据集变换,这对于训练循环来说不是大问题,但对于推理来说却是个问题。如前所述,我只能使用单核运行进行推理。

直接在jupyter笔记本上运行的DataParallel代码对我来说非常不稳定。它可能运行一段时间,但随后会抛出系统错误、内核崩溃。运行它作为一个脚本似乎是稳定的,所以我们使用以下命令进行转换

!jupyternbconvert--toscriptMyModel.ipynb!pythonMyModel.py

工作的局限性

PyTorch/XLA的设计导致了一系列PyTorch功能的限制。事实上,这些限制一般适用于TPU设备,并且显然也适用于TensorFlow模型,至少部分适用。具体地说

张量形状在迭代之间是相同的,这也限制了mask的使用。

应避免步骤之间具有不同迭代次数的循环。

不遵循准则会导致(严重)性能下降。不幸的是,在损失函数中,我需要同时使用掩码和循环。就我而言,我将所有内容都移到了CPU上,现在速度要快得多。只需对所有张量执行 my_tensor.cpu().detach().numpy() 即可。当然,它不适用于需要跟踪梯度的张量,并且由于迁移到CPU而导致自身速度降低。

性能比较

我的Kaggle比赛队友Yuval Reina非常同意分享他的机器配置和训练速度,以便在本节中进行比较。我还为笔记本添加了一列(这是一台物理机),但它与这些重量级对象不匹配,并且在其上运行的代码未针对性能进行优化。

网络的输入是具有6个通道的512 x 512图像。我们测量了在训练循环中每秒处理的图像,根据该指标,所描述的TPU配置要比Tesla V100好得多。

640.png

如上所述(不带DataParallel)的单核TPU的性能为每秒26张图像,比所有8个核在一起的速度慢约4倍。

由于竞争仍在进行中,我们没有透露Yuval使用的体系结构,但其大小与resnet50并没有太大差异。但是请注意,由于我们没有运行相同的架构,因此比较是不公平的。

尝试将训练映像切换到GCP SSD磁盘并不能提高性能。

总结

总而言之,我在PyTorch / XLA方面的经验参差不齐。我遇到了多个错误/工件(此处未全部提及),现有文档和示例受到限制,并且TPU固有的局限性对于更具创意的体系结构而言可能过于严格。另一方面,它大部分都可以工作,并且当它工作时性能很好。

最后,最重要的一点是,别忘了在完成后停止GCP VM!

image.png

目录
相关文章
|
6月前
|
机器学习/深度学习 存储 PyTorch
【AMP实操】解放你的GPU运行内存!在pytorch中使用自动混合精度训练
【AMP实操】解放你的GPU运行内存!在pytorch中使用自动混合精度训练
240 0
|
5月前
|
Serverless PyTorch 文件存储
函数计算产品使用问题之如何使用并运行PyTorch
函数计算产品作为一种事件驱动的全托管计算服务,让用户能够专注于业务逻辑的编写,而无需关心底层服务器的管理与运维。你可以有效地利用函数计算产品来支撑各类应用场景,从简单的数据处理到复杂的业务逻辑,实现快速、高效、低成本的云上部署与运维。以下是一些关于使用函数计算产品的合集和要点,帮助你更好地理解和应用这一服务。
|
机器学习/深度学习 并行计算 PyTorch
基于Pytorch使用GPU运行模型方法及可能出现的问题解决方法
基于Pytorch使用GPU运行模型方法及可能出现的问题解决方法
2206 0
基于Pytorch使用GPU运行模型方法及可能出现的问题解决方法
|
并行计算 PyTorch 算法框架/工具
基于Pytorch运行中出现RuntimeError: Not compiled with CUDA support此类错误解决方案
基于Pytorch运行中出现RuntimeError: Not compiled with CUDA support此类错误解决方案
1280 0
基于Pytorch运行中出现RuntimeError: Not compiled with CUDA support此类错误解决方案
|
并行计算 PyTorch 算法框架/工具
pytorch在GPU上运行模型实现并行计算
pytorch在GPU上运行模型实现并行计算
190 0
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch基于迁移学习的VGG卷积神经网络-手撕(可直接运行)-部分地方不懂的可以参考我上一篇手撕VGG神经网络的注释 两个基本一样 只是这个网络是迁移过来的
Pytorch基于迁移学习的VGG卷积神经网络-手撕(可直接运行)-部分地方不懂的可以参考我上一篇手撕VGG神经网络的注释 两个基本一样 只是这个网络是迁移过来的
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch手撕VGG神经网络(CIFAR10数据集)-详细注释-完整代码可直接运行-
Pytorch手撕VGG神经网络(CIFAR10数据集)-详细注释-完整代码可直接运行-
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch基于迁移学习的Alexnet卷积神经网络-手撕(可直接运行)-部分地方不懂的可以参考我上一篇手撕Alexnet神经网络的注释 两个基本一样 只是这个网络是迁移过来的
Pytorch基于迁移学习的Alexnet卷积神经网络-手撕(可直接运行)-部分地方不懂的可以参考我上一篇手撕Alexnet神经网络的注释 两个基本一样 只是这个网络是迁移过来的
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch手撕Alexnet神经网络(CIFAR10数据集)-详细注释-完整代码可直接运行
Pytorch手撕Alexnet神经网络(CIFAR10数据集)-详细注释-完整代码可直接运行
|
27天前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
123 2