import tensorflow as tf import os def _get_training_data(FLAGS): ''' Buildind the input pipeline for training and inference using TFRecords files. @return data only for the training @return data for the inference ''' filenames=[FLAGS.tf_records_train_path+'/'+f for f in os.listdir(FLAGS.tf_records_train_path)] dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(parse) dataset = dataset.shuffle(buffer_size=1) dataset = dataset.repeat() dataset = dataset.batch(FLAGS.batch_size) dataset = dataset.prefetch(buffer_size=1) dataset2 = tf.data.TFRecordDataset(filenames) dataset2 = dataset2.map(parse) dataset2 = dataset2.shuffle(buffer_size=1) dataset2 = dataset2.repeat() dataset2 = dataset2.batch(1) dataset2 = dataset2.prefetch(buffer_size=1) return dataset, dataset2 def _get_test_data(FLAGS): ''' Buildind the input pipeline for test data.''' filenames=[FLAGS.tf_records_test_path+'/'+f for f in os.listdir(FLAGS.tf_records_test_path)] dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(parse) dataset = dataset.shuffle(buffer_size=1) dataset = dataset.repeat() dataset = dataset.batch(1) dataset = dataset.prefetch(buffer_size=1) return dataset def parse(serialized): ''' Parser fot the TFRecords file.''' features={'movie_ratings':tf.FixedLenFeature([3952], tf.float32), } parsed_example=tf.parse_single_example(serialized, features=features, ) movie_ratings = tf.cast(parsed_example['movie_ratings'], tf.float32) return movie_ratings
这段代码主要定义了三个函数来创建 TensorFlow 数据集对象,这些数据集对象将被用于训练、评估和推断神经网络模型。
函数 _get_training_data
被用于建立训练数据的输入管道,该函数接受一个名为 FLAGS 的对象作为参数,该对象包含训练数据的相关参数,例如 tf_records_train_path
和 batch_size
。在函数内部,首先获取训练数据的文件名列表,然后通过调用 tf.data.TFRecordDataset
方法创建一个 tf.data.Dataset
对象,用于从 TFRecord 文件中读取数据。接下来使用 map
方法调用 parse
函数对每条记录进行解析,并使用 shuffle
方法打乱数据的顺序。然后使用 repeat
方法将数据集重复多次,以便模型可以在训练期间多次使用数据。接着使用 batch
方法将数据分批次,以适应模型的训练。最后使用 prefetch
方法提前加载数据,以使模型在训练过程中不会因为数据加载而发生延迟。
函数 _get_test_data
用于创建测试数据的输入管道,其与 _get_training_data
的逻辑基本相同,只不过这里的数据是用于测试模型的。
函数 parse
被用于解析从 TFRecord 文件中读取的每条记录。它接受一个序列化的字符串,并将其解析为一个包含电影评分的张量。在此函数内部,使用 tf.parse_single_example
方法解析每条记录,其中 features
参数定义了每个属性的名称和类型。在本例中,数据集只包含一个名为 movie_ratings
的属性,它是一个长度为 3952 的一维浮点型张量。解析完成后,将其强制类型转换为 tf.float32
类型,并返回该张量。
在函数 _get_training_data
中,创建了两个数据集对象 dataset
和 dataset2
,分别用于训练和推断。这里需要注意的是,它们使用相同的文件名列表来读取数据,因为我们希望训练和推断数据来自同一个数据集,以确保它们具有相同的分布。使用 shuffle
方法打乱数据集的顺序也是为了保证训练和推断数据的随机性,从而更好地评估模型的泛化能力。
在函数 _get_test_data
中,与训练数据相比,该函数仅使用一个数据集对象来读取测试数据。这是因为测试数据不需要进行批处理,只需要将每个示例一个一个地输入到模型中进行预测即可。
函数 parse
主要用于解析从 TFRecord 文件中读取的数据,并将其转换为可用于训练的张量。在此例中,数据集只包含一个名为 movie_ratings
的属性,它是一个长度为 3952 的一维浮点型张量。在解析过程中,还可以对数据进行一些必要的处理,例如类型转换、归一化或者其他的一些预处理操作。
最后,需要注意的是,使用 TensorFlow 数据集 API 可以更加高效和方便地处理数据,并且可以在不同的环境(例如 CPU 或 GPU)中轻松切换数据处理的实现方式,因此这是训练深度学习模型时常用的工具之一。