Tensorflow datasets.shuffle repeat batch方法

简介: 机器学习中数据读取是很重要的一个环节,TensorFlow也提供了很多实用的方法,为了避免以后时间久了又忘记,所以写下笔记以备日后查看。最普通的正常情况首先我们看看最普通的情况:# 创建0-10的数据集,每个batch取个数。

机器学习中数据读取是很重要的一个环节,TensorFlow也提供了很多实用的方法,为了避免以后时间久了又忘记,所以写下笔记以备日后查看。

最普通的正常情况

首先我们看看最普通的情况:

# 创建0-10的数据集,每个batch取个数。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]

由结果我们可以知道TensorFlow能很好地帮我们自动处理最后一个batch的数据。

datasets.batch(batch_size)与迭代次数的关系

但是如果上面for循环次数超过2会怎么样呢?也就是说如果 **循环次数*批数量 > 数据集数量** 会怎么样?我们试试看:

dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    >>==for i in range(3):==<<
        value = sess.run(next_element)
        print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]
---------------------------------------------------------------------------
OutOfRangeError                           Traceback (most recent call last)
D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1277     try:
   
  ...
  ...省略若干信息...
  ...
  
OutOfRangeError (see above for traceback): End of sequence
     [[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]

可以知道超过范围了,所以报错了。

datasets.repeat()

为了解决上述问题,repeat方法登场。还是直接看例子吧:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

可以知道repeat其实就是将数据集重复了指定次数,上面代码将数据集重复了2次,所以这次即使for循环次数是4也依旧能正常读取数据,并且都能完整把数据读取出来。同理,如果把for循环次数设置为大于4,那么也还是会报错,这么一来,我每次还得算repeat的次数,岂不是很心累?所以更简便的办法就是对repeat方法不设置重复次数,效果见如下:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(6):
        value = sess.run(next_element)
        print(value)

输出结果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

此时无论for循环多少次都不怕啦~~

datasets.shuffle(buffer_size)

仔细看可以知道上面所有输出结果都是有序的,这在机器学习中用来训练模型是浪费资源且没有意义的,所以我们需要将数据打乱,这样每批次训练的时候所用到的数据集是不一样的,这样啊可以提高模型训练效果。

另外shuffle前需要设置buffer_size:

  • 不设置会报错,
  • buffer_size=1:不打乱顺序,既保持原序
  • buffer_size越大,打乱程度越大,演示效果见如下代码:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

输出结果:

[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]

注意:shuffle的顺序很重要,一般建议是最开始执行shuffle操作,因为如果是先执行batch操作的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。不信你看:

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

输出结果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]



MARSGGBO原创





2018-8-5



目录
相关文章
|
7月前
|
机器学习/深度学习 搜索推荐 算法
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
445 0
|
机器学习/深度学习 存储 人工智能
Google Earth Engine(GEE)——TensorFlow支持深度学习等高级机器学习方法(非免费项目)
Google Earth Engine(GEE)——TensorFlow支持深度学习等高级机器学习方法(非免费项目)
1394 0
|
4月前
|
CDN 缓存 前端开发
JSF 性能优化:提升应用响应速度
【8月更文挑战第31天】JavaServer Faces (JSF) 是构建企业级 Web 应用的强大框架。但随着应用复杂度增加,性能问题可能显现。本文通过具体案例介绍如何优化 JSF 应用,提升响应速度。首先创建一个名为 “MyJSFOptimizationApp” 的新 JSF 项目,并在 `pom.xml` 中添加必要的依赖。接着,在 `WEB-INF` 目录下配置 `web.xml` 文件,设置 JSF servlet。然后创建一个 Managed Bean 包含简单属性和方法,并使用 Facelets 页面 `index.xhtml` 展示信息。
42 0
|
4月前
|
UED 开发工具 iOS开发
Uno Platform大揭秘:如何在你的跨平台应用中,巧妙融入第三方库与服务,一键解锁无限可能,让应用功能飙升,用户体验爆棚!
【8月更文挑战第31天】Uno Platform 让开发者能用同一代码库打造 Windows、iOS、Android、macOS 甚至 Web 的多彩应用。本文介绍如何在 Uno Platform 中集成第三方库和服务,如 Mapbox 或 Google Maps 的 .NET SDK,以增强应用功能并提升用户体验。通过 NuGet 安装所需库,并在 XAML 页面中添加相应控件,即可实现地图等功能。尽管 Uno 平台减少了平台差异,但仍需关注版本兼容性和性能问题,确保应用在多平台上表现一致。掌握正确方法,让跨平台应用更出色。
62 0
|
4月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】用Tensorflow.keras的方法替代keras.layers.merge
在TensorFlow 2.0和Keras中替代旧版keras.layers.merge函数的方法,使用了新的层如add, multiply, concatenate, average, 和 dot来实现常见的层合并操作。
40 1
|
7月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
【2月更文挑战第22天】本文介绍基于Python的tensorflow库,将tensorflow与keras训练好的SavedModel格式神经网络模型转换为frozen graph格式,从而可以用OpenCV库在C++等其他语言中将其打开的方法~
156 1
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
|
7月前
|
并行计算 TensorFlow 算法框架/工具
Linux Ubuntu配置CPU与GPU版本tensorflow库的方法
Linux Ubuntu配置CPU与GPU版本tensorflow库的方法
165 1
|
7月前
|
机器学习/深度学习 数据可视化 TensorFlow
用TensorBoard可视化tensorflow神经网络模型结构与训练过程的方法
用TensorBoard可视化tensorflow神经网络模型结构与训练过程的方法
375 1
|
7月前
|
并行计算 TensorFlow 算法框架/工具
新版本GPU加速的tensorflow库的配置方法
新版本GPU加速的tensorflow库的配置方法
184 1
|
7月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
Anaconda配置Python新版本tensorflow库(CPU、GPU通用)的方法
Anaconda配置Python新版本tensorflow库(CPU、GPU通用)的方法
162 1