【Python-Tensorflow】tf.data.Dataset的解析与使用

本文涉及的产品
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
简介: 本文详细介绍了TensorFlow中`tf.data.Dataset`类的使用,包括创建数据集的方法(如`from_generator()`、`from_tensor_slices()`、`from_tensors()`)、数据集函数(如`apply()`、`as_numpy_iterator()`、`batch()`、`cache()`等),以及如何通过这些函数进行高效的数据预处理和操作。

参考资料

1 作用

dataset = tf.data.Dataset…()

构建和处理数据集。包括三种类型的操作。

  • 根据输入数据创建源数据集。
  • 应用数据集转换以预处理数据。
  • 遍历数据集并处理元素。

2 tf.data.Dataset的函数

2.1 from_generator()

通过生成器去创建dataset,该函数的参数用于传生成器

# 定义生成器
def gen():
  ragged_tensor = tf.ragged.constant([[1, 2], [3]])
  yield 42, ragged_tensor
# 创建数据集
dataset = tf.data.Dataset.from_generator(
     gen,
     # 定义输出形状和输出类型
     output_signature=(
          # 定义输出形状
         tf.TensorSpec(shape=(), dtype=tf.int32),
         # 定义输出类型
         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))

print(list(dataset.take(1)))

2.2 from_tensor_slices()

对给定张量进行切片
给定的张量沿其第一维被切片。此操作将保留输入张量的结构,删除每个张量的第一维并将其用作数据集维。所有输入张量的第一个维度必须具有相同的大小。

# Slicing a 1D tensor produces scalar tensor elements.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
list(dataset.as_numpy_iterator())
# Slicing a tuple of 1D tensors produces tuple elements containing
# scalar tensors.
dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
list(dataset.as_numpy_iterator())
[(1,3,5),(2,4,6)]
# Dictionary structure is also preserved.
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
                                      {'a': 2, 'b': 4}]
True

2.3 from_tensors()

创建一个Dataset包含给定张量的单个元素的。
from_tensors产生仅包含单个元素的数据集。要将输入张量切成多个元素,请from_tensor_slices改用

dataset = tf.data.Dataset.from_tensors([1, 2, 3])
list(dataset.as_numpy_iterator())
[array([1,2,3],dtype=int32)]
dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
list(dataset.as_numpy_iterator())
[(array([1,2,3],dtype=int32),b'A')]

3 dataset 的函数

3.1 apply()

apply启用自定义Dataset转换的链接,这些转换表示为采用一个Dataset参数并返回transformd的函数Dataset。

dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())

3.2 as_numpy_iterator()

返回一个迭代器,该迭代器将数据集的所有元素转换为numpy。

使用as_numpy_iterator检查你的数据集的内容。要查看元素的形状和类型,请直接打印数据集元素,而不要使用 as_numpy_iterator。不建议使用

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
  print(element)

建议如下用法

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)

3.3 batch()

将此数据集的连续元素合并为批

dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())

分成三批,分别为【1 2 3】【4 5 6】【7 8】


3.4 cache()

在此数据集中缓存元素。
第一次迭代数据集时,其元素将缓存在指定的文件或内存中。随后的迭代将使用缓存的数据。

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache()
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())
[0,1,4,9,16]
# Subsequent iterations read from the cache.
list(dataset.as_numpy_iterator())
[0,1,4,9,16]

缓存到文件时,缓存的数据将在运行期间保持不变。即使是第一次遍历数据,也将从缓存文件中读取。.cache()直到删除缓存文件或更改文件名,在调用之前更改输入管道才有效。

dataset = tf.data.Dataset.range(5)
dataset = dataset.cache("/path/to/file")  # doctest: +SKIP
list(dataset.as_numpy_iterator())  # doctest: +SKIP
[0,1,2,3,4]
dataset = tf.data.Dataset.range(10)
dataset = dataset.cache("/path/to/file")  # Same file! # doctest: +SKIP
list(dataset.as_numpy_iterator())  # doctest: +SKIP
[0,1,2,3,4]

3.5 cardinality()

返回数据集的大小

  • 数量确定返回数字
  • 无限量,返回tf.data.INFINITE_CARDINALITY
  • 未知,返回tf.data.UNKNOWN_CARDINALITY
dataset = tf.data.Dataset.range(42)
print(dataset.cardinality().numpy())
42
dataset = dataset.repeat()
cardinality = dataset.cardinality()
print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
True
dataset = dataset.filter(lambda x: True)
cardinality = dataset.cardinality()
print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
True

3.6 concatenate()

将给定数据集与此数据集连接来创建一个新的dataset

a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
ds = a.concatenate(b)
list(ds.as_numpy_iterator())
[1,2,3,4,5,6,7]
# The input dataset and dataset to be concatenated should have the same
# nested structures and output types.
c = tf.data.Dataset.zip((a, b))
a.concatenate(c)
错误,a、c类型不同,c是tf.int64类型,a是int64类型
d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
a.concatenate(d)
错误,a、d类型不同,a是int64类型,d是string类型

3.7 enumerate()

枚举此数据集的元素。
它类似于python的enumerate

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.enumerate(start=5)
for element in dataset.as_numpy_iterator():
  print(element)
(5,1)
(6,2)
(7,4)

3.8 filter()

根据自定义过滤函数去过滤此数据集

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.filter(lambda x: x < 3)
list(dataset.as_numpy_iterator())
[1,2]
# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
  return tf.math.equal(x, 1)
dataset = dataset.filter(filter_fn)
list(dataset.as_numpy_iterator())
[1]

3.9 flat_map()

跨此数据集映射并展平结果。
使用flat_map,如果你想确保你的数据集保持不变的顺序。例如,要将批次的数据集展平为其元素的数据集:

dataset = tf.data.Dataset.from_tensor_slices(
               [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x))
list(dataset.as_numpy_iterator())
[1,2,3,4,5,6,7,8,9]

3.10 zip()

将给定的数据集压缩在一起来创建一个。
此方法的语义与zip()Python的内置函数相似,主要区别在于datasets 参数可以是Dataset对象的任意嵌套结构。

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
ds = tf.data.Dataset.zip((a, b))
list(ds.as_numpy_iterator())
[(1,4),(2,5),(3,6)]
ds = tf.data.Dataset.zip((b, a))
list(ds.as_numpy_iterator())
[(4,1),(5,2),(6,3)]

3.11 window()

window(size, shift=None, stride=1, drop_remainder=False)

将输入元素(嵌套)组合到窗口(嵌套)的数据集中。说白了就是按窗口大小划分数据集。
“窗口”是大小为平面元素的有限数据集size(如果没有足够的输入元素来填充窗口并drop_remainder计算为,则可能会更少 False)。
该shift参数确定窗口在每次迭代中移动的输入元素的数量。如果窗口和元素都从0开始编号,则窗口中的第一个元素k将是k * shift 输入数据集的元素。特别是,第一个窗口的第一个元素将始终是输入数据集的第一个元素。
所述stride参数确定输入元件的步幅,并且 shift参数确定窗口的移位。

dataset = tf.data.Dataset.range(7).window(2)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,1]
[2,3]
[4,5]
[6]
dataset = tf.data.Dataset.range(7).window(3, 2, 1, True)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,1,2]
[2,3,4]
[4,5,6]
dataset = tf.data.Dataset.range(7).window(3, 1, 2, True)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,2,4]
[1,3,5]
[2,4,6]

请注意,将window转换应用于嵌套元素的数据集时,它将生成嵌套窗口的数据集。

nested = ([1, 2, 3, 4], [5, 6, 7, 8])
dataset = tf.data.Dataset.from_tensor_slices(nested).window(2)
for window in dataset:
  def to_numpy(ds):
    return list(ds.as_numpy_iterator())
  print(tuple(to_numpy(component) for component in window))
([1,2],[5,6])
([3,4],[7,8])
dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4]})
dataset = dataset.window(2)
for window in dataset:
  def to_numpy(ds):
    return list(ds.as_numpy_iterator())
  print({'a': to_numpy(window['a'])})
  {'a':[1,2]}
  {'a':[3,4]}

3.12 unbatch()

将数据集的元素拆分为多个元素。
例如,如果数据集的元素是shape [B, a0, a1, …],其中B每个输入元素的位置可能有所不同,那么对于数据集中的每个元素,未批处理的数据集将包含Bshape的连续元素[a0, a1, …]。

elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
dataset = dataset.unbatch()
list(dataset.as_numpy_iterator())
[1,2,3,1,2,1,2,3,4]

3.13 take()

从此数据集中Dataset最多创建一个count元素

dataset = tf.data.Dataset.range(10)
dataset = dataset.take(3)
list(dataset.as_numpy_iterator())
[0,1,2]

3.14 skip()

创建一个Dataset跳过count此数据集中的元素的。

dataset = tf.data.Dataset.range(10)
dataset = dataset.skip(7)
list(dataset.as_numpy_iterator())
python
[7,8,9]

3.15 shuffle()

shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)

随机重新排列此数据集的元素。
该数据集用buffer_size元素填充缓冲区,然后从该缓冲区中随机采样元素,用新元素替换所选元素。为了实现完美的改组,需要缓冲区大小大于或等于数据集的完整大小。

例如,如果您的数据集包含10,000个元素但buffer_size设置为1,000,则shuffle最初将仅从缓冲区的前1,000个元素中选择一个随机元素。选择一个元素后,其缓冲区中的空间将被下一个(即1,001个)元素替换,并保留1,000个元素缓冲区。

reshuffle_each_iteration控制随机播放顺序对于每个时期是否应该不同。在TF 1.X中,创建历元的惯用方式是通过repeat转换:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
dataset = dataset.repeat(2)  # doctest: +SKIP
[1,0,2,1,2,0]

3.16 shard()

shard( num_shards, index)

返回dataset指定索引开始,一定步长下的所有数据
num_shards步长,index索引

A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
list(B.as_numpy_iterator())
[0,3,6,9]
C = A.shard(num_shards=3, index=1)
list(C.as_numpy_iterator())
[1,4,7]
D = A.shard(num_shards=3, index=2)
list(D.as_numpy_iterator())
[2,5,8]

3.17 repeat()

重复此数据集

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.repeat(3)
list(dataset.as_numpy_iterator())
[1,2,3,1,2,3,1,2,3]

3.18 reduce()

reduce( initial_state, reduce_func)

将输入数据集简化为单个元素。
转换将reduce_func依次调用输入数据集的每个元素,直到数据集用完为止,以其内部状态聚合信息。该initial_state参数用于初始状态,并返回最终状态作为结果。

tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
5
tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
10

3.19 prefetch()

创建一个Dataset从该数据集中预提取元素的。

大多数数据集输入管道应以调用结束prefetch。这允许在处理当前元素时准备以后的元素。这通常会提高延迟和吞吐量,但以使用额外的内存存储预取元素为代价。

dataset = tf.data.Dataset.range(3)
dataset = dataset.prefetch(2)
list(dataset.as_numpy_iterator())
[0,1,2]

3.20 map()

map(map_func, num_parallel_calls=None, deterministic=None)

此转换将应用于map_func此数据集的每个元素,并以与输入中出现的顺序相同的顺序返回包含转换后的元素的新数据集。map_func可用于更改值和数据集元素的结构。例如,向每个元素加1或投影元素组件的子集。

dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
list(dataset.as_numpy_iterator())
[2,3,4,5,6]
dataset = tf.data.Dataset.range(3)
# `map_func` returns two `tf.Tensor` objects.
def g(x):
  return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
result = dataset.map(g)
result.element_spec
(TensorSpec(shape=(),dtype=tf.float32,name=None),
 TensorSpec(shape=(3,),dtype=tf.int32,name=None))

# `map_func` can return nested structures.
def i(x):
  return (37.0, [42, 16]), "foo"
result = dataset.map(i)
result.element_spec
(TensorSpec(shape=(),dtype=tf.float32,name=None),
 TensorSpec(shape=(2,),dtype=tf.int32,name=None),
 TensorSpec(shape=(),dtype=tf.string,name=None))

3.21 interleave()

interleave(
map_func, cycle_length=None, block_length=None, num_parallel_calls=None,
deterministic=None
)

map_func跨此数据集映射,并交织结果。
例如,您可以用来Dataset.interleave()同时处理许多输入文件:

  • cycle_length和block_length参数控制在其中的元件所产生的顺序。cycle_length控制并发处理的输入元素的数量。
  • 如果设置cycle_length为1,则此转换将一次处理一个输入元素,并将产生与相同的结果tf.data.Dataset.flat_map。
  • 一般来说,这种转换将适用map_func于cycle_length输入元件,开放迭代对返回的Dataset对象,并循环通过它们产生block_length从每个迭代连续元素,并且每个其到达一个迭代的结束时间消耗下一个输入元件。
dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
dataset = dataset.interleave(
    lambda x: Dataset.from_tensors(x).repeat(6),
    cycle_length=2, block_length=4)
list(dataset.as_numpy_iterator())
[1,1,1,1,
2,2,2,2,
1,1,
2,2,
3,3,3,3,
4,4,4,4,
3,3,
4,4,
5,5,5,5,
5,5]
目录
相关文章
|
1天前
|
存储 索引 Python
Python入门:6.深入解析Python中的序列
在 Python 中,**序列**是一种有序的数据结构,广泛应用于数据存储、操作和处理。序列的一个显著特点是支持通过**索引**访问数据。常见的序列类型包括字符串(`str`)、列表(`list`)和元组(`tuple`)。这些序列各有特点,既可以存储简单的字符,也可以存储复杂的对象。 为了帮助初学者掌握 Python 中的序列操作,本文将围绕**字符串**、**列表**和**元组**这三种序列类型,详细介绍其定义、常用方法和具体示例。
Python入门:6.深入解析Python中的序列
|
1天前
|
存储 Linux iOS开发
Python入门:2.注释与变量的全面解析
在学习Python编程的过程中,注释和变量是必须掌握的两个基础概念。注释帮助我们理解代码的意图,而变量则是用于存储和操作数据的核心工具。熟练掌握这两者,不仅能提高代码的可读性和维护性,还能为后续学习复杂编程概念打下坚实的基础。
Python入门:2.注释与变量的全面解析
|
7天前
|
监控 算法 安全
内网桌面监控软件深度解析:基于 Python 实现的 K-Means 算法研究
内网桌面监控软件通过实时监测员工操作,保障企业信息安全并提升效率。本文深入探讨K-Means聚类算法在该软件中的应用,解析其原理与实现。K-Means通过迭代更新簇中心,将数据划分为K个簇类,适用于行为分析、异常检测、资源优化及安全威胁识别等场景。文中提供了Python代码示例,展示如何实现K-Means算法,并模拟内网监控数据进行聚类分析。
28 10
|
25天前
|
存储 算法 安全
控制局域网上网软件之 Python 字典树算法解析
控制局域网上网软件在现代网络管理中至关重要,用于控制设备的上网行为和访问权限。本文聚焦于字典树(Trie Tree)算法的应用,详细阐述其原理、优势及实现。通过字典树,软件能高效进行关键词匹配和过滤,提升系统性能。文中还提供了Python代码示例,展示了字典树在网址过滤和关键词屏蔽中的具体应用,为局域网的安全和管理提供有力支持。
50 17
|
28天前
|
运维 Shell 数据库
Python执行Shell命令并获取结果:深入解析与实战
通过以上内容,开发者可以在实际项目中灵活应用Python执行Shell命令,实现各种自动化任务,提高开发和运维效率。
56 20
|
1月前
|
数据采集 供应链 API
Python爬虫与1688图片搜索API接口:深度解析与显著收益
在电子商务领域,数据是驱动业务决策的核心。阿里巴巴旗下的1688平台作为全球领先的B2B市场,提供了丰富的API接口,特别是图片搜索API(`item_search_img`),允许开发者通过上传图片搜索相似商品。本文介绍如何结合Python爬虫技术高效利用该接口,提升搜索效率和用户体验,助力企业实现自动化商品搜索、库存管理优化、竞品监控与定价策略调整等,显著提高运营效率和市场竞争力。
89 3
|
2月前
|
数据挖掘 vr&ar C++
让UE自动运行Python脚本:实现与实例解析
本文介绍如何配置Unreal Engine(UE)以自动运行Python脚本,提高开发效率。通过安装Python、配置UE环境及使用第三方插件,实现Python与UE的集成。结合蓝图和C++示例,展示自动化任务处理、关卡生成及数据分析等应用场景。
178 5
|
2月前
|
数据采集 JSON API
如何利用Python爬虫淘宝商品详情高级版(item_get_pro)API接口及返回值解析说明
本文介绍了如何利用Python爬虫技术调用淘宝商品详情高级版API接口(item_get_pro),获取商品的详细信息,包括标题、价格、销量等。文章涵盖了环境准备、API权限申请、请求构建和返回值解析等内容,强调了数据获取的合规性和安全性。
|
2月前
|
存储 缓存 Python
Python中的装饰器深度解析与实践
在Python的世界里,装饰器如同一位神秘的魔法师,它拥有改变函数行为的能力。本文将揭开装饰器的神秘面纱,通过直观的代码示例,引导你理解其工作原理,并掌握如何在实际项目中灵活运用这一强大的工具。从基础到进阶,我们将一起探索装饰器的魅力所在。
|
2月前
|
Android开发 开发者 Python
通过标签清理微信好友:Python自动化脚本解析
微信已成为日常生活中的重要社交工具,但随着使用时间增长,好友列表可能变得臃肿。本文介绍了一个基于 Python 的自动化脚本,利用 `uiautomator2` 库,通过模拟用户操作实现根据标签批量清理微信好友的功能。脚本包括环境准备、类定义、方法实现等部分,详细解析了如何通过标签筛选并删除好友,适合需要批量管理微信好友的用户。
108 7

热门文章

最新文章