TensorFlow中的那些高级API-阿里云开发者社区

开发者社区> 【方向】> 正文

TensorFlow中的那些高级API

简介: 在这篇文章中,我们将看到一个使用了最新高级构件的例子,包括Estimator(估算器)、Experiment(实验)和Dataset(数据集)。值得注意的是,你可以独立地使用Experiment和Dataset。不妨进来看看作者是如何玩转这些高级API的。
+关注继续查看

1.png

TensorFlow拥有很多库,比如KerasTFLearnSonnet,对于模型训练来说,使用这些库比使用低级功能更简单。尽管Keras的API目前正在添加到TensorFlow中去,但TensorFlow本身就提供了一些高级构件,而且最新的1.3版本中也引入了一些新的构件。

在这篇文章中,我们将看到一个使用了这些最新的高级构件的例子,包括Estimator(估算器)、Experiment(实验)和Dataset(数据集)。值得注意的是,你可以独立地使用Experiment和Dataset。我在这里假设你已经了解TensorFlow的基础知识;如果没有的话,那么TensorFlow官网上提供的教程值得学习。

2.png
Experiment、Estimator和DataSet框架以及它们之间的交互。

我们在本文中将使用MNIST作为数据集。这是一个使用起来很简单的数据集,可以从TensorFlow官网获取到。你可以在这个gist中找到完整的代码示例。使用这些框架的其中一个好处是,我们不需要直接处理会话

Estimator(估算器)类

Estimator类代表了一个模型,以及如何对这个模型进行训练和评估。我们可以像下面这段代码创建一个Estimator:

return tf.estimator.Estimator(
    model_fn=model_fn,  # First-class function
    params=params,  # HParams
    config=run_config  # RunConfig
)

要创建Estimator,需要传入一个模型函数、一组参数和一些配置。

  • 传入的参数应该是模型超参数的一个集合。这可以是一个dictionary,但是我们将在这个例子中把它表示成一个HParams对象,就像namedtuple一样。
  • 传入的配置用于指定如何运行训练和评估,以及在哪里存储结果。这个配置是一个RunConfig对象,该对象会把模型运行环境相关的信息告诉Estimator。
  • 模型函数是一个Python函数,它根据给定的输入构建模型。

模型函数

模型函数是一个Python函数,并作为一级函数传递给Estimator。稍后我们会看到,TensorFlow在其他地方也使用了一级函数。将模型表示为一个函数的好处是可以通过实例化函数来多次创建模型。模型可以在训练过程中用不同的输入重新创建,例如,在训练过程中运行验证测试。

模型函数把输入特征作为参数,将相应的标签作为张量。它也能以某种方式来告知用户模型是在训练、评估或是在执行推理。模型函数的最后一个参数是超参数集合,它们与传递给Estimator的超参数集合相同。模型函数返回一个EstimatorSpec对象,该对象定义了一个完整的模型。

EstimatorSpec对象用于对操作进行预测、损失、训练和评估,因此,它定义了一个用于训练、评估和推理的完整的模型图。由于EstimatorSpec只可用于常规的TensorFlow操作,因此,我们可以使用像TF-Slim这样的框架来定义模型。

Experiment(实验)类

Experiment类定义了如何训练模型,它与Estimator完美地集成在一起。我们可以像如下代码创建一个Experiment对象:

experiment = tf.contrib.learn.Experiment(
    estimator=estimator,  # Estimator
    train_input_fn=train_input_fn,  # First-class function
    eval_input_fn=eval_input_fn,  # First-class function
    train_steps=params.train_steps,  # Minibatch steps
    min_eval_frequency=params.min_eval_frequency,  # Eval frequency
    train_monitors=[train_input_hook],  # Hooks for training
    eval_hooks=[eval_input_hook],  # Hooks for evaluation
    eval_steps=None  # Use evaluation feeder until its empty
)

以下几种情况会把Experiment对象作为输入:

  • 一个estimator(例如我们上面定义的)。
  • 作为一级函数训练和评估数据。这里使用了与前面提到的模型函数相同的概念。如果需要的话,通过传入函数而不是操作,可以重新创建输入图。稍后我们还会谈到这个。
  • 训练和评估hook(钩子)。钩子可用于保存或监视特定的内容,或者在图或会话中设置某些操作。例如,我们将其传入到操作中,帮助初始化数据加载器。
  • 描述需要训练多久以及何时评估的各种参数。

一旦定义了experiment,我们就可以像下面这段代码那样使用learn_runner.run来运行它训练和评估模型:

learn_runner.run(
    experiment_fn=experiment_fn,  # First-class function
    run_config=run_config,  # RunConfig
    schedule="train_and_evaluate",  # What to run
    hparams=params  # HParams
)

与模型函数和数据函数一样,learn_runner将一个创建experiment的函数作为参数传入。

Dataset(数据集)类

我们将使用Dataset类和相应的Iterator来表示数据的训练和评估,以及创建在训练过程中迭代数据的数据馈送器。 在本示例中,我们将使用在Tensorflow中可用的MNIST数据,并为其构建一个Dataset包装。例如,我们将把训练输入数据表示为:

# Define the training inputs
def get_train_inputs(batch_size, mnist_data):
    """Return the input function to get the training data.
    Args:
        batch_size (int): Batch size of training iterator that is returned
                          by the input function.
        mnist_data (Object): Object holding the loaded mnist data.
    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise input iterator.
    """
    iterator_initializer_hook = IteratorInitializerHook()

    def train_inputs():
        """Returns training set as Operations.
        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        with tf.name_scope('Training_data'):
            # Get Mnist data
            images = mnist_data.train.images.reshape([-1, 28, 28, 1])
            labels = mnist_data.train.labels
            # Define placeholders
            images_placeholder = tf.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.placeholder(
                labels.dtype, labels.shape)
            # Build dataset iterator
            dataset = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            dataset = dataset.repeat(None)  # Infinite iterations
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(batch_size)
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
            # Set runhook to initialize iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            # Return batched (features, labels)
            return next_example, next_label

    # Return function and hook
    return train_inputs, iterator_initializer_hook

调用这个get_train_inputs将返回一个一级函数,用于在TensorFlow图中创建数据加载操作,以及返回一个用于初始化迭代器的Hook

本示例中使用的MNIST数据最初是一个Numpy数组。我们创建了一个占位符张量来获取数据;使用占位符的目的是为了避免数据的复制。接下来,我们在from_tensor_slices的帮助下创建一个切片数据集。我们要确保该数据集可以运行无限次数,并且数据被重新洗牌并放入指定大小的批次中。

要迭代数据,就需要从数据集中创建一个迭代器。由于我们正在使用占位符,因此需要使用NumPy数据在相关会话中对占位符进行初始化。可以通过创建一个可初始化的迭代器来实现这个。在创建图的时候,将创建一个自定义的IteratorInitializerHook对象来初始化迭代器:

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook to initialise data iterator after Session is created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Initialise the iterator after the session has been created."""
        self.iterator_initializer_func(session)

IteratorInitializerHook继承自SessionRunHook。这个钩子将在相关会话创建后立即调用after_create_session,并使用正确的数据初始化占位符。这个钩子由我们的get_train_inputs函数返回,并在创建时传递给Experiment对象。

train_inputs函数返回的数据加载操作是TensorFlow的操作,该操作每次评估时都会返回一个新的批处理。

运行代码

现在,我们已经定义了所有内容,可以使用下面这个命令运行代码了:

python mnist_estimator.py --model_dir ./mnist_training --data_dir ./mnist_data

如果不传入参数,它将使用文件开头的默认标志来确定数据和模型保存的位置。

在训练过程中,在终端上会输出这段时间内的全​​局步骤、损失和准确性等信息。除此之外,Experiment和Estimator框架将记录TensorBoard可视化的某些统计信息。如果我们运行这个命令:

tensorboard --logdir='./mnist_training'

那么我们可以看到所有的训练统计数据,如训练损失、评估准确性、每个步骤的时间,以及模型图。

3.png
TensorBoard可视化中的评估准确度

我写这篇文章,是因为我在编写代码示例时,无法找到有关Tensorflow Estimator 、Experiment和Dataset框架太多的信息和示例。我希望这篇文章能向你简要介绍一下这些框架是如何工作的,它们采用了什么样的抽象方法以及如何使用它们。如果你对使用这些框架感兴趣,下面我将介绍一些注意点和其他的文档。

有关Estimator、Experiment和Dataset框架的注意点

文章原标题《Higher-Level APIs in TensorFlow》,作者:Peter Roelants,译者:夏天,审校:主题曲。

文章为简译,更为详细的内容,请查看原文需要爬梯,不方便的同学也可以下载下方的PDF附件,阅读原文内容。

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

附件下载:https://developer.aliyun.com/topic/download?id=453

相关文章
借助ThreadLocal实现多线程安全 | 带你学《Java语言高级特性》之三十二
本节结合消息发送案例简单直白地向读者展现了多线程模式下的线程安全问题,并向读者介绍了ThreadLocal类及其实现线程安全的原理,讲解了其使用方法。
793 0
Comparable助你实现自定义比较 | 带你学《Java语言高级特性》之三十六
上一节中我们已经了解到常规的排序比较功能并不适用于对象,本节将介绍JDK1.2之后Java为开发者提供的Comparable接口的相关内容。
681 0
怎么设置阿里云服务器安全组?阿里云安全组规则详细解说
阿里云服务器安全组设置规则分享,阿里云服务器安全组如何放行端口设置教程
6838 0
阿里云服务器端口号设置
阿里云服务器初级使用者可能面临的问题之一. 使用tomcat或者其他服务器软件设置端口号后,比如 一些不是默认的, mysql的 3306, mssql的1433,有时候打不开网页, 原因是没有在ecs安全组去设置这个端口号. 解决: 点击ecs下网络和安全下的安全组 在弹出的安全组中,如果没有就新建安全组,然后点击配置规则 最后如上图点击添加...或快速创建.   have fun!  将编程看作是一门艺术,而不单单是个技术。
4398 0
使用OpenApi弹性释放和设置云服务器ECS释放
云服务器ECS的一个重要特性就是按需创建资源。您可以在业务高峰期按需弹性的自定义规则进行资源创建,在完成业务计算的时候释放资源。本篇将提供几个Tips帮助您更加容易和自动化的完成云服务器的释放和弹性设置。
7735 0
CDN高级技术专家周哲:深度剖析短视频分发过程中的用户体验优化技术点
深圳云栖大会已经圆满落幕,在3月29日飞天技术汇-弹性计算、网络和CDN专场中,阿里云CDN高级技术专家周哲为我们带来了《海量短视频极速分发》的主题分享,带领我们从视频内容采集、上传、存储和分发的角度介绍整体方案,并且重点讲解短视频加速的注意事项和用户体验优化要点。
5010 0
迅速读懂Java线程的交通规则 | 带你学《Java语言高级特性》之八
线程就像现实世界里复杂的交通线上运行的汽车,有停车让行的,也有一路高歌的,甚至还有特殊情况优先行驶的情况,本节将一一介绍这些内容。
859 0
Python爬虫入门教程 57-100 python爬虫高级技术之验证码篇3-滑动验证码识别技术
滑动验证码介绍 本篇博客涉及到的验证码为滑动验证码,不同于极验证,本验证码难度略低,需要的将滑块拖动到矩形区域右侧即可完成。 这类验证码不常见了,官方介绍地址为:https://promotion.
3441 0
+关注
【方向】
欢迎各位对内容方向及质量提需求,我们尽量满足,将国外优质的内容呈现给大家!
696
文章
5
问答
文章排行榜
最热
最新
相关电子书
更多
文娱运维技术
立即下载
《SaaS模式云原生数据仓库应用场景实践》
立即下载
《看见新力量:二》电子书
立即下载