IOS平台TensorFlow实践:实际应用教程(附源码)(二)

简介: 本文是《从零到一:IOS平台TensorFlow入门及应用》系列二,介绍IOS平台TensorFlow的安装,以及将系列一中开发的模型在IOS app上的实际应用

更多深度文章,请关注云计算频道:https://yq.aliyun.com/cloud

 

作者简介:

1ade0f07bb4ebf3abbcb0fe6bf011337d5b397e3

MATTHIJS HOLLEMANS

荷兰人,独立开发者,专注于底层编码,GPU优化和算法研究。目前研究方向为IOS上的深度学习及其在APP上的应用。

推特地址:https://twitter.com/mhollemans

邮件地址:mailto:matt@machinethink.net

github地址:https://github.com/hollance

个人博客:http://machinethink.net/

 

  上一节中,我们介绍了在如何用TnesorFlow创建一个逻辑斯蒂回归分类器,接下来介绍如何将这个分类器运用在实际的app中。

在IOS上安装TensorFlow

  前面已经训练好模型,下面创建一个利用TensorFlow C++ 库和这个模型的app。坏消息是你不得不从源构建TensorFlow,还需要使用Java环境;好消息是这个过程相对简单。完整的指导在这里,但是下面几步很重要(测试环境为TensorFlow 1.0)。

      首先你得安装好Xcode 8,确定开发者目录指向你安装Xcode的位置并且已经被激活。(如果你在安装Xcode之前已经安装了Homebrew,这可能会指向错误的地址,导致TensorFlow安装失败):

sudo xcode-select -s /Applications/Xcode.app/Contents/Developer

我们将使用名为bazel的工具来安装TensorFlow。先使用Homebrew安装所需要的包:

brew cask install java
brew install bazel
brew install automake
brew install libtool

完成之后,你需要克隆TensorFlow GitHub仓库。注意,一定要保存在没有空格的路径下,否则bazel会拒绝构建。我是克隆到我的主目录下:

cd /Users/matthijs
git clone https://github.com/tensorflow/tensorflow -b r1.0
   -b r1.0 表明克隆的是 r1.0 分支 。当然你也可以随时获取 最新的分支 或者主分支。

Note:MacOS Sierra 上,运行下面的配置脚本报错了,我只能克隆主分支来代替。在OS X EI Caption 上使用r1.0分支就不会有任何问题。

一旦GitHub仓库克隆完毕,你就需要运行配置脚本(configure script:

cd tensorflow
./configure

这里有些地方可能需要你自行配置,比如:

Please specify the location of python. [Default is /usr/bin/python]:	

我写的是/usr/local/bin/python3,因为我使用的是Python 3.6。如果你选择默认选项,就会使用Python 2.7来创建TensorFlow

Please specify optimization flags to use during compilation [Default is 
-march=native]:

这里只需要按Enter键。后面两个问题,只需要选择n(表示 no)。当询问使用哪个Python库时,按Enter键选择默认选项(应该是Python 3.6 库)。剩下的问题都选择n。随后,这个脚本将会下载大量的依赖项并准备构建TensorFlow所需的一切。


构建静态库

有两种方法构建TensorFlow:1.在Mac上使用bazel工具;2.在IOS上,使用Makefile。我们是在IOS上构建,所以选择第2种方式。不过因为会用到一些工具,也会用到第一种方式。

在TensorFlow的目录中执行以下脚本:

tensorflow/contrib/makefile/build_all_ios.sh

这个脚本首先会下载一些依赖项,然后开始构建。一切顺利的话,它会创建三个链入你的app的静态库:libtensorflow-core.a libprotobuf.a libprotobuf-lite.a

还有另外两个工具需要构建,在终端运行如下两行命令:

bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/python/tools:optimize_for_inference

Note: 这个过程至少需要20分钟,因为它会从头开始构建TensorFlow(本次使用的是bazel)。如果遇到问题,请参考官方指导


在Mac上构建TensorFlow

这一步是可选的,不过因为已经安装了所有需要的包,在Mac上构建TensorFlow就没那么困难了。使用pip包代替官方的TensorFlow包进行安装。

现在你就可以创建一个自定义的TensorFlow版本。例如,当运行train.py脚本时,如果出现“The TensorFlow library wasn’t compiled to use SSE4.1 instructions”提醒,你可以编译一个允许这些指令的TensorFlow版本。

在终端运行如下命令来构建TensorFlow:

bazel build --copt=-march=native -c opt //tensorflow/tools/pip_package:build_pip_package

bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
-march=native选项添加了对SSE,AVX,AVX2,FMA等指令的支持(如果这些指令能够在你的CPU上运行)。然后安装包:
pip3 uninstall tensorflow
sudo -H pip3 install /tmp/tensorflow_pkg/tensorflow-1.0.0-XXXXXX.whl

更多详细指令请参考TensorFlow网站


固化计算图

我们将要创建的app会载入之前训练好的模型,并作出预测。之前在train.py中,我们将图保存到了 /tmp/voice/graph.pb文件中。但是你不能在IOS app中直接载入这个计算图,因为图中的部分操作是TensorFlow C++库并不支持。所以就需要用到上面我们构建的那两个工具。

freeze_graph将包含训练好的wbgraph.pb和检查点文件合成为一个文件,并移除IOS不支持的操作。在终端运行TensorFlow目录下的这个工具:

bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/voice/graph.pb --input_checkpoint=/tmp/voice/model \
--output_node_names=model/y_pred,inference/inference --input_binary \
--output_graph=/tmp/voice/frozen.pb

最终输出/tmp/voice/frozen.pb文件,只包含得到y_predinference的节点,不包括用来训练的节点。freeze_graph也将权重保存进了文件,就不用再单独载入。

optimize_for_inference工具进一步简化了可计算图,它以frozen.pb作为输入,以/tmp/voice/inference.pb作为输出。这就是我们将嵌入IOS app中的文件,按如下方式运行这个工具:

bazel-bin/tensorflow/python/tools/optimize_for_inference \
--input=/tmp/voice/frozen.pb --output=/tmp/voice/inference.pb \
--input_names=inputs/x --output_names=model/y_pred,inference/inference \
--frozen_graph=True
 
  


IOS app

     你可以在VoiceTensorFlow 文件夹下找到这个app。用Xcode打开这个项目,有几处需要注意:

            1. App是用C++写的(源文件后缀名为.mm),因为TensorFlow没有Swift API,只有C++的;

           2.inference.pb文件已经包含在项目中,如果有需要的话,你可以用你自己的inference.pb文件替换掉;

           3.这个app使用了Accelerate框架;

           4.这个app使用了已经编译好的静态库。

     在项目设置界面打开构建参数标签页,在Other Linker Flags,你会看见如下信息:

/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf-lite.a 

/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf.a 

-force_load /Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/lib/
libtensorflow-core.a

      除非你的名字也是“matthijs”,否则需要用你克隆的TensorFlow存放的路径进行替换。(TensorFlow出现了两次,所以文件名为tensorflow/tensorflow/...)。

      Note: 你也可以将这3个文件拷贝到项目文件夹中,就不必担心路径出错了。我之所以没有这样做,是因为libtensorflow-core.a 文件有440MB大。

     再检查Header Search Paths,目前的设置是:

~/tensorflow
~/tensorflow/tensorflow/contrib/makefile/downloads 
~/tensorflow/tensorflow/contrib/makefile/downloads/eigen 
~/tensorflow/tensorflow/contrib/makefile/downloads/protobuf/src 
~/tensorflow/tensorflow/contrib/makefile/gen/proto

     然后你还要将这些路径更新到您克隆仓库的位置,还有些build settings我也做了修改:

        1.Enable Bitcode: No

        2.Warnings / Documentation Comments: No

        3.Warnings / Deprecated Functions: No

     目前TensorFlow并不支持字节码,所以我禁用了这个功能。我也关闭了警告功能,否则你编译app时会遇到很多问题。(虽然你还是会遇到值转换问题的警告,禁止这个警告功能也没毛病)。

     完成Other Linker Flags和 the Header Search Paths的设置之后,就可以构建并运行app了。下面看一下这个使用TensorFlow的IOS app是如何工作的。


使用Tensorflow C++ API

IOS上的TensorFlow使用C++写的,不过需要你写的C++代码有限,通常,你只需要做下面几件事:

1.从.pb文件中载入计算图和权重;

2.使用图创建会话;

3.将数据放入输入张量;

4.在图上运行一个或多个节点;

5.得到输出张量结果。

在演示的APP中,这些都是写在ViewController.mm中。首先载入图:

- (BOOL)loadGraphFromPath:(NSString *)path
{
    auto status = ReadBinaryProto(tensorflow::Env::Default(), 
                                  path.fileSystemRepresentation, &graph);
    if (!status.ok()) {
        NSLog(@"Error reading graph: %s", status.error_message().c_str());
        return NO;
    }
    return YES;
}

Xcode项目包含在 graph.pb上运行freeze_graph 和optimize_for_inference工具得到的inference.pb图。如果你试图载入graph.pb,会报错:

Error adding graph to session: No OpKernel was registered to support Op 
'L2Loss' with these attrs.  Registered devices: [CPU], Registered kernels:
  <no registered kernels>

[[Node: loss-function/L2Loss = L2Loss[T=DT_FLOAT](model/W/read)]]

      这个C++ API 支持的操作要比Python API少。这里他说的是损失函数节点中L2Loss操作在IOS上不支持。这就是为什么我们要使用freeze_graph简化图。

在载入图之后,创建会话:

- (BOOL)createSession
{
    tensorflow::SessionOptions options;
    auto status = tensorflow::NewSession(options, &session);
    if (!status.ok()) {
        NSLog(@"Error creating session: %s", 
                status.error_message().c_str());
        return NO;
    }

    status = session->Create(graph);
    if (!status.ok()) {
        NSLog(@"Error adding graph to session: %s", 
                status.error_message().c_str());
        return NO;
    }
    return YES;
}

会话创建好之后,就可以进行预测了。predict:方法需要一个包含20个浮点数的元组,代表声学特征,然后传入图中,该方法如下所示:

- (void)predict:(float *)example {
    tensorflow::Tensor x(tensorflow::DT_FLOAT, 
                         tensorflow::TensorShape({ 1, 20 }));

    auto input = x.tensor<float, 2>();
    for (int i = 0; i < 20; ++i) {
        input(0, i) = example[i];
    }

首先定义张量x作为输入数据。这个张量维度为{1, 20},因为它一次接收一个样本,每个样本有20个特征。然后从float *数组将数据拷贝至张量中。

接下来运行会话:

    std::vector<std::pair<std::string, tensorflow::Tensor>> inputs = {
        {"inputs/x-input", x}
    };

    std::vector<std::string> nodes = {
        {"model/y_pred"},
        {"inference/inference"}
    };

    std::vector<tensorflow::Tensor> outputs;

    auto status = session->Run(inputs, nodes, {}, &outputs);
    if (!status.ok()) {
        NSLog(@"Error running model: %s", status.error_message().c_str());
        return;
    }

运行如下代码:

pred, inf = sess.run([y_pred, inference], feed_dict={x: example})

这条代码看起来并没有Python版的简洁。我们创建了feed字典,运行的节点列表,以及保存结果的向量。最后,打印结果:

    auto y_pred = outputs[0].tensor<float, 2>();
    NSLog(@"Probability spoken by a male: %f%%", y_pred(0, 0));

    auto isMale = outputs[1].tensor<float, 2>();
    if (isMale(0, 0)) {
        NSLog(@"Prediction: male");
    } else {
        NSLog(@"Prediction: female");
    }
}

本来只需要运行inference节点就可以得到男性/女性的预测结果,但我还想看计算出来的概率,所以后面运行了y_pred节点。


运行app

你可以在iphone模拟器或者设备上运行这个app。在模拟器上,你可能会得到诸如 “The TensorFlow library wasn’t compiled to use SSE4.1 instructions”的消息,但是在设备上则不会报错。

app会做出来两种预测:男性/女性。运行这个app,你会看到下面的输出,它先打印出图中的节点:

Node count: 9
Node 0: Placeholder 'inputs/x-input'
Node 1: Const 'model/W'
Node 2: Const 'model/b'
Node 3: MatMul 'model/MatMul'
Node 4: Add 'model/add'
Node 5: Sigmoid 'model/y_pred'
Node 6: Const 'inference/Greater/y'
Node 7: Greater 'inference/Greater'
Node 8: Cast 'inference/inference'

这个图只包含进行预测的节点,并不需要训练相关的节点。然后就会输出结果:

Probability spoken by a male: 0.970405%
Prediction: male

Probability spoken by a male: 0.005632%
Prediction: female

如果用Python脚本测试同样的数据,会得到相同的答案。


IOS上TensorFlow的优缺点

优点:

1. 一个工具搞定所有事。你可以使用TensorFlow训练模型并进行预测。不需要将计算图移植到其他的API,如BNNS或者Metal。另一方面,你只需要将少量Python代码移植到C++代码;

2.TensorFlow有比BNNSMetal更多的特性;

3.你可以在模拟器上运行。Metal总是要在设备上运行。

缺点:

1.目前不支持GPUTensorFlow使用 Accelerate 框架能够发挥CPU向量指令的优势,原始速度比不上Metal

2.TensorFlow API使用C++写的,所以你不得不写一些C++代码,并不能直接使用Swift编写。

3.相比于Python APIC++ API有限。这意味着你不能在设备上进行训练,因为不支持反向传播中用到的自动梯度计算。

4.TensorFlow静态库增加了app包大概40MB的空间。通过减少支持操作的数量,可以减少这个额外空间,不过这很麻烦。而且,这还不包括模型的大小。

目前,我个人并不提倡在IOS上使用TensorFlow。优点并没有超过缺点,作为一款有潜力的产品,谁知道未来会怎样呢?

Note: 如果决定在你的IOS app中使用TensorFlow,那你必须知道别人很容易从app安装包中拷贝图的.pb文件窃取你的模型。由于固化的图文件包含模型参数和图定义,反编译简直轻而易举。如果你的模型具有竞争优势,你可能需要做出预案防止你的机密被窃取。


使用Metal在GPU上训练

       IOS app上使用TensorFlow的一个弊端是他是运行在CPU上的。对于数据和模型较小的项目,TensorFlow能够满足我们的需求。但是对于更大的数据集,特别是深度学习,你就必须要使用GPU代替CPU,在IOS上就意味着要使用Metal。

  训练后,我们需要将学习到的参数wb保存成Metal能够读取的格式。其实只要以二进制格式保存为浮点数列表就可以了。

下面的Python脚本export_weights.py和之前载入图定义和检查点的test.py很相似,如下:

    W.eval().tofile("W.bin")
    b.eval().tofile("b.bin")

W.eval()计算w目前的值,并以返回Numpy数组(和sess.run(W)作用是一样的)。然后使用tofile()Numpy数组保存为二进制文件。

你可以在源码VoiceMetal文件夹下发现Xcode项目,使用Swift编写的。

之前我们使用下面的公式计算逻辑斯蒂回归:

y_pred = sigmoid((W * x) + b)

这和神经网络中全连接层进行的计算相同,为了实现Metal版分类器,我们只需要使用MPSCNN Fully Connected 层。首先将W.binb.bin载入到Data对象:

let W_url = Bundle.main.url(forResource: "W", withExtension: "bin")
let b_url = Bundle.main.url(forResource: "b", withExtension: "bin")
let W_data = try! Data(contentsOf: W_url!)
let b_data = try! Data(contentsOf: b_url!)

然后创建全连接层:

let sigmoid = MPSCNNNeuronSigmoid(device: device)

let layerDesc = MPSCNNConvolutionDescriptor(
                   kernelWidth: 1, kernelHeight: 1, 
                   inputFeatureChannels: 20, outputFeatureChannels: 1, 
                   neuronFilter: sigmoid)

W_data.withUnsafeBytes { W in
  b_data.withUnsafeBytes { b in
    layer = MPSCNNFullyConnected(device: device, 
               convolutionDescriptor: layerDesc, 
               kernelWeights: W, biasTerms: b, flags: .none)
  }
}

因为输入是20个数字,我设计了作用于一个1x1的有20个输入信道(input channels)的全连接层。预测结果y_pred是一个数字,所以全连接层只有一个输出信道。输入和输出数据放在MPSImage 中:

let inputImgDesc = MPSImageDescriptor(channelFormat: .float16, 
                       width: 1, height: 1, featureChannels: 20)
let outputImgDesc = MPSImageDescriptor(channelFormat: .float16, 
                       width: 1, height: 1, featureChannels: 1)

inputImage = MPSImage(device: device, imageDescriptor: inputImgDesc)
outputImage = MPSImage(device: device, imageDescriptor: outputImgDesc)

app上的TensorFlow一样,这里也有一个predict 方法,这个方法以组成一条样本的20个浮点数作为输入。下面是完整的方法:

func predict(example: [Float]) {
  convert(example: example, to: inputImage)

  let commandBuffer = commandQueue.makeCommandBuffer()
  layer.encode(commandBuffer: commandBuffer, sourceImage: inputImage, 
               destinationImage: outputImage)
  commandBuffer.commit()
  commandBuffer.waitUntilCompleted()

  let y_pred = outputImage.toFloatArray()
  print("Probability spoken by a male: \(y_pred[0])%")

  if y_pred[0] > 0.5 {
    print("Prediction: male")
  } else {
    print("Prediction: female")
  }
}

和运行session的结果是一样的。convert(example:to:)toFloatArray()方法加载和输出MPSImage 对象的辅助函数。

你需要在设备上运行这个app,因为模拟器不支持Metal。输出结果如下:

Probability spoken by a male: 0.970215%
Prediction: male

Probability spoken by a male: 0.00568771%
Prediction: female

注意到这些概率和用TensorFlow预测到的概率不完全相同,这是因为Metal使用16位浮点数,但结果相当接近。


版权许可

本文所用的数据集是 Kory Becker制作的,在 Kaggle.com下载,也参考了Kory的博文源码。其他人也写过IOS上TensorFlow相关的一些东西。从这些文章和代码中我受益匪浅:

1.Getting Started with Deep MNIST and TensorFlow on iOS by Matt Rajca

2.Speeding Up TensorFlow with Metal Performance Shaders also by Matt Rajca

3.tensorflow-cocoa-example by Aaron Hillegass

4.TensorFlow iOS Examples in the TensorFlow repository


以上为译文

本文由北邮@爱可可-爱生活 老师推荐,阿里云云栖社区组织翻译。

文章原标题《Getting started with TensorFlow on iOS》,由Matthijs Hollemans发布。

译者:李烽 ;审校:董昭男

文章为简译,更为详细的内容,请查看原文。中文译制文档见附件。


相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
1月前
|
机器学习/深度学习 TensorFlow API
机器学习实战:TensorFlow在图像识别中的应用探索
【10月更文挑战第28天】随着深度学习技术的发展,图像识别取得了显著进步。TensorFlow作为Google开源的机器学习框架,凭借其强大的功能和灵活的API,在图像识别任务中广泛应用。本文通过实战案例,探讨TensorFlow在图像识别中的优势与挑战,展示如何使用TensorFlow构建和训练卷积神经网络(CNN),并评估模型的性能。尽管面临学习曲线和资源消耗等挑战,TensorFlow仍展现出广阔的应用前景。
67 5
|
4月前
|
自然语言处理 C# 开发者
Uno Platform多语言开发秘籍大公开:轻松驾驭全球用户,一键切换语言,让你的应用成为跨文化交流的桥梁!
【8月更文挑战第31天】Uno Platform 是一个强大的开源框架,允许使用 C# 和 XAML 构建跨平台的原生移动、Web 和桌面应用程序。本文详细介绍如何通过 Uno Platform 创建多语言应用,包括准备工作、设置多语言资源、XAML 中引用资源、C# 中加载资源以及处理语言更改。通过简单的步骤和示例代码,帮助开发者轻松实现应用的国际化。
47 1
|
4月前
|
机器学习/深度学习 存储 前端开发
实战揭秘:如何借助TensorFlow.js的强大力量,轻松将高效能的机器学习模型无缝集成到Web浏览器中,从而打造智能化的前端应用并优化用户体验
【8月更文挑战第31天】将机器学习模型集成到Web应用中,可让用户在浏览器内体验智能化功能。TensorFlow.js作为在客户端浏览器中运行的库,提供了强大支持。本文通过问答形式详细介绍如何使用TensorFlow.js将机器学习模型带入Web浏览器,并通过具体示例代码展示最佳实践。首先,需在HTML文件中引入TensorFlow.js库;接着,可通过加载预训练模型如MobileNet实现图像分类;然后,编写代码处理图像识别并显示结果;此外,还介绍了如何训练自定义模型及优化模型性能的方法,包括模型量化、剪枝和压缩等。
69 1
|
4月前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow和PyTorch的实际应用比较
TensorFlow和PyTorch的实际应用比较
|
4月前
|
持续交付 测试技术 jenkins
JSF 邂逅持续集成,紧跟技术热点潮流,开启高效开发之旅,引发开发者强烈情感共鸣
【8月更文挑战第31天】在快速发展的软件开发领域,JavaServer Faces(JSF)这一强大的Java Web应用框架与持续集成(CI)结合,可显著提升开发效率及软件质量。持续集成通过频繁的代码集成及自动化构建测试,实现快速反馈、高质量代码、加强团队协作及简化部署流程。以Jenkins为例,配合Maven或Gradle,可轻松搭建JSF项目的CI环境,通过JUnit和Selenium编写自动化测试,确保每次构建的稳定性和正确性。
65 0
|
4月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
90 0
|
4月前
|
UED 存储 数据管理
深度解析 Uno Platform 离线状态处理技巧:从网络检测到本地存储同步,全方位提升跨平台应用在无网环境下的用户体验与数据管理策略
【8月更文挑战第31天】处理离线状态下的用户体验是现代应用开发的关键。本文通过在线笔记应用案例,介绍如何使用 Uno Platform 优雅地应对离线状态。首先,利用 `NetworkInformation` 类检测网络状态;其次,使用 SQLite 实现离线存储;然后,在网络恢复时同步数据;最后,通过 UI 反馈提升用户体验。
111 0
|
4月前
|
Java Spring Apache
Spring Boot邂逅Apache Wicket:一次意想不到的完美邂逅,竟让Web开发变得如此简单?
【8月更文挑战第31天】Apache Wicket与Spring Boot的集成提供了近乎无缝的开发体验。Wicket以其简洁的API和强大的组件化设计著称,而Spring Boot则以开箱即用的便捷性赢得开发者青睐。本文将指导你如何在Spring Boot项目中引入Wicket,通过简单的步骤完成集成配置。首先,创建一个新的Spring Boot项目并在`pom.xml`中添加Wicket相关依赖。
135 0
|
4月前
|
机器学习/深度学习 人工智能 算法
深入探索TensorFlow在强化学习中的应用:从理论到实践构建智能游戏AI代理
【8月更文挑战第31天】强化学习作为人工智能的一个重要分支,通过智能体与环境的互动,在不断试错中学习达成目标。本文介绍如何利用TensorFlow构建高效的强化学习模型,并应用于游戏AI。智能体通过执行动作获得奖励或惩罚,旨在最大化长期累积奖励。TensorFlow提供的强大工具简化了复杂模型的搭建与训练,尤其适用于处理高维数据。通过示例代码展示如何创建并训练一个简单的CartPole游戏AI,证明了该方法的有效性。未来,这项技术有望拓展至更复杂的应用场景中。
56 0
|
4月前
|
机器学习/深度学习 TensorFlow 数据处理
分布式训练在TensorFlow中的全面应用指南:掌握多机多卡配置与实践技巧,让大规模数据集训练变得轻而易举,大幅提升模型训练效率与性能
【8月更文挑战第31天】本文详细介绍了如何在Tensorflow中实现多机多卡的分布式训练,涵盖环境配置、模型定义、数据处理及训练执行等关键环节。通过具体示例代码,展示了使用`MultiWorkerMirroredStrategy`进行分布式训练的过程,帮助读者更好地应对大规模数据集与复杂模型带来的挑战,提升训练效率。
113 0