tfrecords 文件读取|学习笔记

简介: 快速学习 tfrecords 文件读取

开发者学堂课程【深度学习框架TensorFlow入门tfrecords 文件读取学习笔记,与课程紧密联系,让用户快速学习知识。

课程地址https://developer.aliyun.com/learning/course/773/detail/13557


tfrecords 文件读取


内容介绍:

一、读取 TFRecords 文件API

二、案例:读取 CIFAR 的 TFRecords 文件


一、读取TFRecords文件API

1)构造文件名队列

2)读取

解析example

tf.parse_single_example(value,features={

“image”:tf.FixedLenFeature([],tf.string),

“label”:tf.FixedLenFeature([],tf.int64)

})

image = feature[“iamge”]

lebel = feature[“lebel”]

解码

tf.decode_raw()

3)构造批处理队列

读取这种文件整个过程与其他问价一样,只不过需要有个解析 Example 的步骤。从 TFRecords 文件中读取数据,可以使用 tf.TFRecordReader 的 tf.parse_single_example 解析器。这个操作可以将 Example 协议内存块 (protocol buffer) 解析为张量。

tf.parse_single_example(serialized,features=None,name=None)

解析一个单一的 Example 原型

serialized :标量字符串 Tensor ,一个序列化的 Example

features : dict 字典数据,键为读取的名字,值为 FixedLenFeature

return :一个键值对组成的字典,键为读取的名字

tf.FixedLenFeature(shape,dtype)

shape :输入数据的形状,一般不指定,为空列表

dtype :输入数据类型,与存储进文件的类型要一致

类型只能是 float32 , int34 , string


二、案例:读取 CIFAR 的 TFRecords 文件

import tensorflow as tf

class Cifar(object):

def __init__(self):

# 初始化操作

self.height = 32

self.width = 32

self.channels = 3

# 设置图像字节数

self.image_bytes = self.height * self.width * self.channels

self.label_bytes = 1

self.all_bytes = self.label_bytes + self.image_bytes

def read_and_decode(self):

"""

读取二进制文件

:return:

"""

# 1、构造文件名队列

file_queue = tf.train.string_input_producer(file_list)

# 2、读取与解码

# 读取阶段

reader = tf.FixedLengthRecordReader(self.all_bytes)

# key 文件名 value 一个样本

key,value = reader,read(file_queue)

print("key:\n",key)

print("value:\n",value)

# 解码阶段

decoded = tf.decode_raw(value,tf.uint8)

print("decoded:\n",decoded)

# 将目标值和特征值切片切开

label=tf.slice(decoded,[0],[self.label_bytes])

tf.slice(decoded,[self.label_bytes],[self.image_bytes])

print("label:\n",label)

print("image:\n",image)

# 调整图片形状

image_reshaped=tf.reshaped(image,shape=[self,channels,self.heighy,self.width])

print("image_reshaped:\n",image_reshaped)

# 转置,将图片的顺序转为 height,width,channels

image_transposed=tf.transpose(image_reshaped,[1,2,0])

print("image_transposed:\n",image_transposed)

# 调整图像类型

image_cast=tf.cast(image_transposed,tf.float32)

# 3、批处理

label_batch,image_batch=tf.train.batch([label,image_cast],batch_size=100,num_threads=1,capacity=100)

# 开启会话

with tf.Session() as sess:

# 开启线程

coord = tf.train.Coordinator()

threads=tf.train.start_queue_runners(sess=sess,coord=coord)

key_new,value_new,decoded_new,label_new,image_new,image_reshaped_new,image_transposed_new=sess.run([key,value,decoded,label,image_reshaped,image_transposed])

label_value,image_value=sess.run(label_batch,image_batch)

print("key_new:\n",key_new)

print("value_new:\n",value_new)

print("decoded_new:\n",decoded_new)

print("label_new:\n",label_new)

print("image_new:\n",image_new)

print("image_reshaped_new:\n",image_reshaped_new)

print("image_transposed_new:\n",image_transposed_new)

print("label_value:\n",label_value)

print("image_value:\n",image_value)

# 回收线程

coord.request_stop()

coord.join(threads)

return image_value,label_value

#写入

def write_to_tfrecords(self,image_batch,label_batch):

"""

将样本的特征值和目标值一起写入 tfrecords 文件

:param image:

:param label:

:return:

"""

with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:

# 循环构造 example 对象,并序列化写入文件

for i in range(100):

image = image_batch[i].tostring()

label = label_batch[i][0]

# print("tfrecords_image:\n",image)

# print("tfrecords_label:\n",label)

example=tf.train.Example(features=tf.train.Features(features={

“image”:tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]),

“label”:tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),

}))

# example.SerializeToString

# 将序列化后的 example 写入文件

writer.write(example.SerializeToString())

return None

#读取

def read_tfrecords(self):

"""

读取 TFRecords 文件

:return:

"""

# 1、构造文件名队列

file_queue = tf.train.string_input_producer(["cifar10.tfrecords])

# 2、读取与解码

# 读取

reader = tf.TFRecordReader()

key,value = reader.read(file_queue)

# 解析 example

tf.parse_single_example(value,features={

“image”:tf.FixedLenFeature([],tf.string),

“label”:tf.FixedLenFeature([],tf.int64)

})

image = feature[“iamge”]

lebel = feature[“lebel”]

print("read_tf_image:\n",image)

print("read_tf_label:\n",label)

# 解码

tf.decode_raw(image,tf.uint8)

print("image_decoded:\n",image_decoded)

# 图像形状调整

image_reshaped=tf.reshape(image_decoded,[self.height,self.width,self.channel])

print("image_reshaped:\n",image_reshaped)

# 3、构造批处理队列

image_batch,label_batch=tf.train.batch([image_reshaped,label],batch_size=100,num_threads=2,capacity=100)

print("image_batch:\n",image_batch)

print("label_batch:\n",label_batch)

# 开启会话

with tf.Session() as sess:

#开启线程

coord = tf.train.Coordinator()

threads=tf.train.start_queue_runners(sess=sess,coord=coord)

image_value,label_value=sess.run([image_batch,label_batch])

print("image_value:\n",image_value)

print("label_value:\n",label_value)

#回收资源

coord.request_stop()

cooed.join(threads)

return None

if __name__ == "__main__":

cifar = Cifar()

# image_value,label_value = cifar.read_binary()

# cifar.write_to_tfrecords(image_value,label_value)

cifar.read_tfrecords()

相关文章
学术规范与论文写作(雨课堂)(研究生)期末考试 正确顺序
学术规范与论文写作(雨课堂)(研究生)期末考试 正确顺序
761 0
学术规范与论文写作(雨课堂)(研究生)期末考试 正确顺序
|
4天前
|
人工智能 运维 安全
|
2天前
|
人工智能 异构计算
敬请锁定《C位面对面》,洞察通用计算如何在AI时代持续赋能企业创新,助力业务发展!
敬请锁定《C位面对面》,洞察通用计算如何在AI时代持续赋能企业创新,助力业务发展!
|
10天前
|
人工智能 JavaScript 测试技术
Qwen3-Coder入门教程|10分钟搞定安装配置
Qwen3-Coder 挑战赛简介:无论你是编程小白还是办公达人,都能通过本教程快速上手 Qwen-Code CLI,利用 AI 轻松实现代码编写、文档处理等任务。内容涵盖 API 配置、CLI 安装及多种实用案例,助你提升效率,体验智能编码的乐趣。
833 109
|
4天前
|
机器学习/深度学习 人工智能 自然语言处理
B站开源IndexTTS2,用极致表现力颠覆听觉体验
在语音合成技术不断演进的背景下,早期版本的IndexTTS虽然在多场景应用中展现出良好的表现,但在情感表达的细腻度与时长控制的精准性方面仍存在提升空间。为了解决这些问题,并进一步推动零样本语音合成在实际场景中的落地能力,B站语音团队对模型架构与训练策略进行了深度优化,推出了全新一代语音合成模型——IndexTTS2 。
435 11
|
3天前
|
人工智能 测试技术 API
智能体(AI Agent)搭建全攻略:从概念到实践的终极指南
在人工智能浪潮中,智能体(AI Agent)正成为变革性技术。它们具备自主决策、环境感知、任务执行等能力,广泛应用于日常任务与商业流程。本文详解智能体概念、架构及七步搭建指南,助你打造专属智能体,迎接智能自动化新时代。