tf_record_writer.py代码解释

简介: 这段代码是用来将电影评分数据集转换为 TensorFlow 训练所需的二进制 TFRecord 格式的。这里采用的是 MovieLens 数据集,其中包含了 1 百万个电影评分记录,用于推荐系统任务的训练和测试。该代码主要分为几个部分:
import numpy as np
import tensorflow as tf
import sys
from preprocess_data import get_dataset_1M
import os
from pathlib import Path
p = Path(__file__).parents[1]
ROOT_DIR=os.path.abspath(os.path.join(p, '..', 'data/'))
TF_RECORD_TRAIN_PATH='/tf_records/train'
TF_RECORD_TEST_PATH='/tf_records/test'
def _add_to_tfrecord(data_sample,tfrecord_writer):
    data_sample=list(data_sample.astype(dtype=np.float32))
    example = tf.train.Example(features=tf.train.Features(feature={'movie_ratings': float_feature(data_sample)}))                                          
    tfrecord_writer.write(example.SerializeToString())
def _get_output_filename(output_dir, idx, name):
    return '%s/%s_%03d.tfrecord' % (ROOT_DIR+output_dir, name, idx)
def int64_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def float_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def bytes_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def main():
    ''' Writes the .txt training and testing data into binary TF_Records.'''
    SAMPLES_PER_FILES=100
    training_set, test_set=get_dataset_1M()
    for data_set, name, dir_ in zip([training_set, test_set], ['train', 'test'], [TF_RECORD_TRAIN_PATH, TF_RECORD_TEST_PATH]):
        num_samples=len(data_set)
        i = 0
        fidx = 1
        while i < num_samples:
            tf_filename = _get_output_filename(dir_, fidx,  name=name)
            with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
                j = 0
                while i < num_samples and j < SAMPLES_PER_FILES:
                    sys.stdout.write('\r>> Converting sample %d/%d' % (i+1, num_samples))
                    sys.stdout.flush()
                    sample = data_set[i]
                    _add_to_tfrecord(sample, tfrecord_writer)
                    i += 1
                    j += 1
                fidx += 1
    print('\nFinished converting the dataset!')
if __name__ == "__main__":
    main()

这段代码是用来将电影评分数据集转换为 TensorFlow 训练所需的二进制 TFRecord 格式的。这里采用的是 MovieLens 数据集,其中包含了 1 百万个电影评分记录,用于推荐系统任务的训练和测试。

该代码主要分为几个部分:

  1. 引入必要的 Python 模块和工具
  2. 定义一些用于将数据样本添加到 TFRecord 文件中的函数,以及获取输出文件名的函数
  3. 定义一些特征转换函数,如 int64_feature、float_feature 和 bytes_feature,这些函数用于将样本的各个特征值转换为 TFRecord 文件所需的数据类型
  4. 定义主函数 main,其中首先调用 get_dataset_1M 函数,获取训练和测试集数据。然后,针对训练集和测试集,分别循环将数据集写入 TFRecord 文件中。每个文件中包含 100 个数据样本,这是为了方便后续处理和管理。在将数据样本添加到文件中时,调用了 _add_to_tfrecord 函数,并传入当前样本和 TFRecordWriter 对象,以将样本写入文件中。在完成一个文件的写入后,使用 fidx += 1 更新文件编号,以准备开始下一个文件的写入。最后输出完成的信息。
  5. 在主函数末尾,通过 if name == "main" 判断当前文件是否为主文件,如果是则执行 main 函数。

这段代码的主要目的是将原始数据转换为可以用于训练机器学习模型的格式,即 TFRecord 格式。这种格式不仅可以更快地读取和处理数据,而且可以节省存储空间。


这段代码中,main() 函数是主要的执行函数,它将训练和测试数据集写入到二进制的 TF_Records 文件中。在 main() 函数中,首先通过 get_dataset_1M() 函数获取训练和测试数据集。接下来,对于每个数据集,程序会将数据分成若干份,每份的大小为 SAMPLES_PER_FILES=100。对于每一份数据,程序会创建一个对应的 TF_Records 文件,将数据写入该文件中。

在写入数据时,程序使用 _add_to_tfrecord() 函数将每个数据样本转换为一个 TF_Records 格式,并将其写入 TF_Records 文件中。其中,_add_to_tfrecord() 函数将每个数据样本转换为一个 tf.train.Example 对象,该对象包含一个 tf.train.Features 对象,其中包含一个 movie_ratings 特征。movie_ratings 特征的值是一个由浮点数构成的列表,列表中的每个元素对应一个电影的评分。

在转换数据样本时,程序使用 float_feature() 函数将浮点数列表转换为一个 tf.train.Feature 对象,该对象将浮点数列表转换为一个 tf.train.FloatList 对象。在写入数据时,程序使用 tf.python_io.TFRecordWriter() 函数创建一个 TF_Records 文件写入器,并将数据写入该文件中。最后,程序输出转换数据集的信息。

相关文章
|
6月前
|
Docker 容器
求助: 运行模型时报错module 'megatron_util.mpu' has no attribute 'get_model_parallel_rank'
运行ZhipuAI/Multilingual-GLM-Summarization-zh的官方代码范例时,报错AttributeError: MGLMTextSummarizationPipeline: module 'megatron_util.mpu' has no attribute 'get_model_parallel_rank' 环境是基于ModelScope官方docker镜像,尝试了各个版本结果都是一样的。
290 5
|
4月前
|
索引
yolov5--detect.py --v5.0版本-最新代码详细解释-2021-6-29号更新
yolov5--detect.py --v5.0版本-最新代码详细解释-2021-6-29号更新
45 0
yolov5--detect.py --v5.0版本-最新代码详细解释-2021-6-29号更新
|
机器学习/深度学习 测试技术 TensorFlow
dataset.py代码解释
这段代码主要定义了三个函数来创建 TensorFlow 数据集对象,这些数据集对象将被用于训练、评估和推断神经网络模型。
|
9月前
|
机器学习/深度学习 JSON 数据格式
YOLOv5源码逐行超详细注释与解读(4)——验证部分val(test).py
YOLOv5源码逐行超详细注释与解读(4)——验证部分val(test).py
916 0
YOLOv5源码逐行超详细注释与解读(4)——验证部分val(test).py
|
API 数据格式
TensorFlow2._:model.summary() Output Shape为multiple解决方法
TensorFlow2._:model.summary() Output Shape为multiple解决方法
189 0
TensorFlow2._:model.summary() Output Shape为multiple解决方法
|
Serverless
train_test_split.py代码解释
这段代码用于将MovieLens 1M数据集的评分数据划分为训练集和测试集。 • 首先,使用Path库获取当前文件的父级目录,也就是项目根目录。 • 接着,定义输出训练集和测试集文件的路径。
118 0
|
存储 搜索推荐 Java
preprocess_data.py代码解释
循环遍历每个用户,对于每个用户,提取其对电影的评分。 创建一个与所有电影数量相同的评分数组,将相应的评分放置在数组的正确位置。 如果该用户没有评分电影,则跳过该用户。 返回所有用户的评分数组列表。
168 0
|
机器学习/深度学习 数据采集 搜索推荐
training.py的代码解释
labels、test_loss_op 和 mae_ops 计算模型的性能指标。最后,我们输出当前 epoch 的训练损失、测试损失和平均绝对误差(MAE),并保存模型参数(如果 MAE 小于 0.9)。 整个代码的目的是使用协同过滤算法建立电影推荐系统的模型,训练模型并计算模型的性能指标。
|
机器学习/深度学习 搜索推荐 TensorFlow
inference.py的代码解释
这是一个 Python 脚本,它用于导出经过训练的模型,使其可以在生产环境中进行推理。该脚本首先使用 TensorFlow 的 flags 定义了一些参数,如模型版本号、模型路径、输出目录等等。然后,它创建了一个名为 inference_graph 的 TensorFlow 图,并定义了一个 InferenceModel,该模型用于从输入数据中推断评级。
352 0
Jmeter组件-Random CSV Data Set Config参数化CSV随机读取文件
Jmeter组件-Random CSV Data Set Config参数化CSV随机读取文件
Jmeter组件-Random CSV Data Set Config参数化CSV随机读取文件