【Python-Tensorflow】tf.data.Dataset的解析与使用

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 本文详细介绍了TensorFlow中`tf.data.Dataset`类的使用,包括创建数据集的方法(如`from_generator()`、`from_tensor_slices()`、`from_tensors()`)、数据集函数(如`apply()`、`as_numpy_iterator()`、`batch()`、`cache()`等),以及如何通过这些函数进行高效的数据预处理和操作。

参考资料

1 作用

dataset = tf.data.Dataset…()

构建和处理数据集。包括三种类型的操作。

  • 根据输入数据创建源数据集。
  • 应用数据集转换以预处理数据。
  • 遍历数据集并处理元素。

2 tf.data.Dataset的函数

2.1 from_generator()

通过生成器去创建dataset,该函数的参数用于传生成器

# 定义生成器
def gen():
  ragged_tensor = tf.ragged.constant([[1, 2], [3]])
  yield 42, ragged_tensor
# 创建数据集
dataset = tf.data.Dataset.from_generator(
     gen,
     # 定义输出形状和输出类型
     output_signature=(
          # 定义输出形状
         tf.TensorSpec(shape=(), dtype=tf.int32),
         # 定义输出类型
         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))

print(list(dataset.take(1)))

2.2 from_tensor_slices()

对给定张量进行切片
给定的张量沿其第一维被切片。此操作将保留输入张量的结构,删除每个张量的第一维并将其用作数据集维。所有输入张量的第一个维度必须具有相同的大小。

# Slicing a 1D tensor produces scalar tensor elements.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
list(dataset.as_numpy_iterator())
# Slicing a tuple of 1D tensors produces tuple elements containing
# scalar tensors.
dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
list(dataset.as_numpy_iterator())
[(1,3,5),(2,4,6)]
# Dictionary structure is also preserved.
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
                                      {'a': 2, 'b': 4}]
True

2.3 from_tensors()

创建一个Dataset包含给定张量的单个元素的。
from_tensors产生仅包含单个元素的数据集。要将输入张量切成多个元素,请from_tensor_slices改用

dataset = tf.data.Dataset.from_tensors([1, 2, 3])
list(dataset.as_numpy_iterator())
[array([1,2,3],dtype=int32)]
dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
list(dataset.as_numpy_iterator())
[(array([1,2,3],dtype=int32),b'A')]

3 dataset 的函数

3.1 apply()

apply启用自定义Dataset转换的链接,这些转换表示为采用一个Dataset参数并返回transformd的函数Dataset。

dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())

3.2 as_numpy_iterator()

返回一个迭代器,该迭代器将数据集的所有元素转换为numpy。

使用as_numpy_iterator检查你的数据集的内容。要查看元素的形状和类型,请直接打印数据集元素,而不要使用 as_numpy_iterator。不建议使用

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
  print(element)

建议如下用法

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)

3.3 batch()

将此数据集的连续元素合并为批

dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())

分成三批,分别为【1 2 3】【4 5 6】【7 8】


3.4 cache()

在此数据集中缓存元素。
第一次迭代数据集时,其元素将缓存在指定的文件或内存中。随后的迭代将使用缓存的数据。

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache()
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())
[0,1,4,9,16]
# Subsequent iterations read from the cache.
list(dataset.as_numpy_iterator())
[0,1,4,9,16]

缓存到文件时,缓存的数据将在运行期间保持不变。即使是第一次遍历数据,也将从缓存文件中读取。.cache()直到删除缓存文件或更改文件名,在调用之前更改输入管道才有效。

dataset = tf.data.Dataset.range(5)
dataset = dataset.cache("/path/to/file")  # doctest: +SKIP
list(dataset.as_numpy_iterator())  # doctest: +SKIP
[0,1,2,3,4]
dataset = tf.data.Dataset.range(10)
dataset = dataset.cache("/path/to/file")  # Same file! # doctest: +SKIP
list(dataset.as_numpy_iterator())  # doctest: +SKIP
[0,1,2,3,4]

3.5 cardinality()

返回数据集的大小

  • 数量确定返回数字
  • 无限量,返回tf.data.INFINITE_CARDINALITY
  • 未知,返回tf.data.UNKNOWN_CARDINALITY
dataset = tf.data.Dataset.range(42)
print(dataset.cardinality().numpy())
42
dataset = dataset.repeat()
cardinality = dataset.cardinality()
print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
True
dataset = dataset.filter(lambda x: True)
cardinality = dataset.cardinality()
print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
True

3.6 concatenate()

将给定数据集与此数据集连接来创建一个新的dataset

a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
ds = a.concatenate(b)
list(ds.as_numpy_iterator())
[1,2,3,4,5,6,7]
# The input dataset and dataset to be concatenated should have the same
# nested structures and output types.
c = tf.data.Dataset.zip((a, b))
a.concatenate(c)
错误,a、c类型不同,c是tf.int64类型,a是int64类型
d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
a.concatenate(d)
错误,a、d类型不同,a是int64类型,d是string类型

3.7 enumerate()

枚举此数据集的元素。
它类似于python的enumerate

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.enumerate(start=5)
for element in dataset.as_numpy_iterator():
  print(element)
(5,1)
(6,2)
(7,4)

3.8 filter()

根据自定义过滤函数去过滤此数据集

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.filter(lambda x: x < 3)
list(dataset.as_numpy_iterator())
[1,2]
# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
  return tf.math.equal(x, 1)
dataset = dataset.filter(filter_fn)
list(dataset.as_numpy_iterator())
[1]

3.9 flat_map()

跨此数据集映射并展平结果。
使用flat_map,如果你想确保你的数据集保持不变的顺序。例如,要将批次的数据集展平为其元素的数据集:

dataset = tf.data.Dataset.from_tensor_slices(
               [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x))
list(dataset.as_numpy_iterator())
[1,2,3,4,5,6,7,8,9]

3.10 zip()

将给定的数据集压缩在一起来创建一个。
此方法的语义与zip()Python的内置函数相似,主要区别在于datasets 参数可以是Dataset对象的任意嵌套结构。

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
ds = tf.data.Dataset.zip((a, b))
list(ds.as_numpy_iterator())
[(1,4),(2,5),(3,6)]
ds = tf.data.Dataset.zip((b, a))
list(ds.as_numpy_iterator())
[(4,1),(5,2),(6,3)]

3.11 window()

window(size, shift=None, stride=1, drop_remainder=False)

将输入元素(嵌套)组合到窗口(嵌套)的数据集中。说白了就是按窗口大小划分数据集。
“窗口”是大小为平面元素的有限数据集size(如果没有足够的输入元素来填充窗口并drop_remainder计算为,则可能会更少 False)。
该shift参数确定窗口在每次迭代中移动的输入元素的数量。如果窗口和元素都从0开始编号,则窗口中的第一个元素k将是k * shift 输入数据集的元素。特别是,第一个窗口的第一个元素将始终是输入数据集的第一个元素。
所述stride参数确定输入元件的步幅,并且 shift参数确定窗口的移位。

dataset = tf.data.Dataset.range(7).window(2)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,1]
[2,3]
[4,5]
[6]
dataset = tf.data.Dataset.range(7).window(3, 2, 1, True)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,1,2]
[2,3,4]
[4,5,6]
dataset = tf.data.Dataset.range(7).window(3, 1, 2, True)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,2,4]
[1,3,5]
[2,4,6]

请注意,将window转换应用于嵌套元素的数据集时,它将生成嵌套窗口的数据集。

nested = ([1, 2, 3, 4], [5, 6, 7, 8])
dataset = tf.data.Dataset.from_tensor_slices(nested).window(2)
for window in dataset:
  def to_numpy(ds):
    return list(ds.as_numpy_iterator())
  print(tuple(to_numpy(component) for component in window))
([1,2],[5,6])
([3,4],[7,8])
dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4]})
dataset = dataset.window(2)
for window in dataset:
  def to_numpy(ds):
    return list(ds.as_numpy_iterator())
  print({'a': to_numpy(window['a'])})
  {'a':[1,2]}
  {'a':[3,4]}

3.12 unbatch()

将数据集的元素拆分为多个元素。
例如,如果数据集的元素是shape [B, a0, a1, …],其中B每个输入元素的位置可能有所不同,那么对于数据集中的每个元素,未批处理的数据集将包含Bshape的连续元素[a0, a1, …]。

elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
dataset = dataset.unbatch()
list(dataset.as_numpy_iterator())
[1,2,3,1,2,1,2,3,4]

3.13 take()

从此数据集中Dataset最多创建一个count元素

dataset = tf.data.Dataset.range(10)
dataset = dataset.take(3)
list(dataset.as_numpy_iterator())
[0,1,2]

3.14 skip()

创建一个Dataset跳过count此数据集中的元素的。

dataset = tf.data.Dataset.range(10)
dataset = dataset.skip(7)
list(dataset.as_numpy_iterator())
python
[7,8,9]

3.15 shuffle()

shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)

随机重新排列此数据集的元素。
该数据集用buffer_size元素填充缓冲区,然后从该缓冲区中随机采样元素,用新元素替换所选元素。为了实现完美的改组,需要缓冲区大小大于或等于数据集的完整大小。

例如,如果您的数据集包含10,000个元素但buffer_size设置为1,000,则shuffle最初将仅从缓冲区的前1,000个元素中选择一个随机元素。选择一个元素后,其缓冲区中的空间将被下一个(即1,001个)元素替换,并保留1,000个元素缓冲区。

reshuffle_each_iteration控制随机播放顺序对于每个时期是否应该不同。在TF 1.X中,创建历元的惯用方式是通过repeat转换:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
dataset = dataset.repeat(2)  # doctest: +SKIP
[1,0,2,1,2,0]

3.16 shard()

shard( num_shards, index)

返回dataset指定索引开始,一定步长下的所有数据
num_shards步长,index索引

A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
list(B.as_numpy_iterator())
[0,3,6,9]
C = A.shard(num_shards=3, index=1)
list(C.as_numpy_iterator())
[1,4,7]
D = A.shard(num_shards=3, index=2)
list(D.as_numpy_iterator())
[2,5,8]

3.17 repeat()

重复此数据集

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.repeat(3)
list(dataset.as_numpy_iterator())
[1,2,3,1,2,3,1,2,3]

3.18 reduce()

reduce( initial_state, reduce_func)

将输入数据集简化为单个元素。
转换将reduce_func依次调用输入数据集的每个元素,直到数据集用完为止,以其内部状态聚合信息。该initial_state参数用于初始状态,并返回最终状态作为结果。

tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
5
tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
10

3.19 prefetch()

创建一个Dataset从该数据集中预提取元素的。

大多数数据集输入管道应以调用结束prefetch。这允许在处理当前元素时准备以后的元素。这通常会提高延迟和吞吐量,但以使用额外的内存存储预取元素为代价。

dataset = tf.data.Dataset.range(3)
dataset = dataset.prefetch(2)
list(dataset.as_numpy_iterator())
[0,1,2]

3.20 map()

map(map_func, num_parallel_calls=None, deterministic=None)

此转换将应用于map_func此数据集的每个元素,并以与输入中出现的顺序相同的顺序返回包含转换后的元素的新数据集。map_func可用于更改值和数据集元素的结构。例如,向每个元素加1或投影元素组件的子集。

dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
list(dataset.as_numpy_iterator())
[2,3,4,5,6]
dataset = tf.data.Dataset.range(3)
# `map_func` returns two `tf.Tensor` objects.
def g(x):
  return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
result = dataset.map(g)
result.element_spec
(TensorSpec(shape=(),dtype=tf.float32,name=None),
 TensorSpec(shape=(3,),dtype=tf.int32,name=None))

# `map_func` can return nested structures.
def i(x):
  return (37.0, [42, 16]), "foo"
result = dataset.map(i)
result.element_spec
(TensorSpec(shape=(),dtype=tf.float32,name=None),
 TensorSpec(shape=(2,),dtype=tf.int32,name=None),
 TensorSpec(shape=(),dtype=tf.string,name=None))

3.21 interleave()

interleave(
map_func, cycle_length=None, block_length=None, num_parallel_calls=None,
deterministic=None
)

map_func跨此数据集映射,并交织结果。
例如,您可以用来Dataset.interleave()同时处理许多输入文件:

  • cycle_length和block_length参数控制在其中的元件所产生的顺序。cycle_length控制并发处理的输入元素的数量。
  • 如果设置cycle_length为1,则此转换将一次处理一个输入元素,并将产生与相同的结果tf.data.Dataset.flat_map。
  • 一般来说,这种转换将适用map_func于cycle_length输入元件,开放迭代对返回的Dataset对象,并循环通过它们产生block_length从每个迭代连续元素,并且每个其到达一个迭代的结束时间消耗下一个输入元件。
dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
dataset = dataset.interleave(
    lambda x: Dataset.from_tensors(x).repeat(6),
    cycle_length=2, block_length=4)
list(dataset.as_numpy_iterator())
[1,1,1,1,
2,2,2,2,
1,1,
2,2,
3,3,3,3,
4,4,4,4,
3,3,
4,4,
5,5,5,5,
5,5]
目录
相关文章
|
1月前
|
算法 前端开发 数据处理
小白学python-深入解析一位字符判定算法
小白学python-深入解析一位字符判定算法
47 0
|
12天前
|
算法 Python
Python 大神修炼手册:图的深度优先&广度优先遍历,深入骨髓的解析
在 Python 编程中,掌握图的深度优先遍历(DFS)和广度优先遍历(BFS)是进阶的关键。这两种算法不仅理论重要,还能解决实际问题。本文介绍了图的基本概念、邻接表表示方法,并给出了 DFS 和 BFS 的 Python 实现代码示例,帮助读者深入理解并应用这些算法。
25 2
|
21天前
|
测试技术 开发者 Python
深入浅出:Python中的装饰器解析与应用###
【10月更文挑战第22天】 本文将带你走进Python装饰器的世界,揭示其背后的魔法。我们将一起探索装饰器的定义、工作原理、常见用法以及如何自定义装饰器,让你的代码更加简洁高效。无论你是Python新手还是有一定经验的开发者,相信这篇文章都能为你带来新的启发和收获。 ###
12 1
|
21天前
|
设计模式 测试技术 开发者
Python中的装饰器深度解析
【10月更文挑战第24天】在Python的世界中,装饰器是那些能够为函数或类“添彩”的魔法工具。本文将带你深入理解装饰器的概念、工作原理以及如何自定义装饰器,让你的代码更加优雅和高效。
|
1月前
|
XML 前端开发 数据格式
Beautiful Soup 解析html | python小知识
在数据驱动的时代,网页数据是非常宝贵的资源。很多时候我们需要从网页上提取数据,进行分析和处理。Beautiful Soup 是一个非常流行的 Python 库,可以帮助我们轻松地解析和提取网页中的数据。本文将详细介绍 Beautiful Soup 的基础知识和常用操作,帮助初学者快速入门和精通这一强大的工具。【10月更文挑战第11天】
56 2
|
1月前
|
数据安全/隐私保护 流计算 开发者
python知识点100篇系列(18)-解析m3u8文件的下载视频
【10月更文挑战第6天】m3u8是苹果公司推出的一种视频播放标准,采用UTF-8编码,主要用于记录视频的网络地址。HLS(Http Live Streaming)是苹果公司提出的一种基于HTTP的流媒体传输协议,通过m3u8索引文件按序访问ts文件,实现音视频播放。本文介绍了如何通过浏览器找到m3u8文件,解析m3u8文件获取ts文件地址,下载ts文件并解密(如有必要),最后使用ffmpeg合并ts文件为mp4文件。
|
1月前
|
Web App开发 SQL 数据库
使用 Python 解析火狐浏览器的 SQLite3 数据库
本文介绍如何使用 Python 解析火狐浏览器的 SQLite3 数据库,包括书签、历史记录和下载记录等。通过安装 Python 和 SQLite3,定位火狐数据库文件路径,编写 Python 脚本连接数据库并执行 SQL 查询,最终输出最近访问的网站历史记录。
|
1月前
|
机器学习/深度学习 算法 Python
深度解析机器学习中过拟合与欠拟合现象:理解模型偏差背后的原因及其解决方案,附带Python示例代码助你轻松掌握平衡技巧
【10月更文挑战第10天】机器学习模型旨在从数据中学习规律并预测新数据。训练过程中常遇过拟合和欠拟合问题。过拟合指模型在训练集上表现优异但泛化能力差,欠拟合则指模型未能充分学习数据规律,两者均影响模型效果。解决方法包括正则化、增加训练数据和特征选择等。示例代码展示了如何使用Python和Scikit-learn进行线性回归建模,并观察不同情况下的表现。
282 3
|
1月前
|
网络协议 Python
IP地址探秘:识别与解析的Python之旅 🚀
《IP地址探秘:识别与解析的Python之旅》通过Python的`ipaddress`模块,轻松实现IP地址的分类(如单播、多播、私有、环回或公有)及子网内所有IP的生成,使网络管理更加便捷高效。示例代码直观展示了功能实现过程。
|
1月前
|
运维 安全 网络协议
Python 网络编程:端口检测与IP解析
本文介绍了使用Python进行网络编程的两个重要技能:检查端口状态和根据IP地址解析主机名。通过`socket`库实现端口扫描和主机名解析的功能,并提供了详细的示例代码。文章最后还展示了如何整合这两部分代码,实现一个简单的命令行端口扫描器,适用于网络故障排查和安全审计。