开发者学堂课程【深度学习框架TensorFlow入门:tfrecords 文件读取】学习笔记,与课程紧密联系,让用户快速学习知识。
课程地址:https://developer.aliyun.com/learning/course/773/detail/13557
tfrecords 文件读取
内容介绍:
一、读取 TFRecords 文件API
二、案例:读取 CIFAR 的 TFRecords 文件
一、读取TFRecords文件API
1)构造文件名队列
2)读取
解析example
feature = tf.parse_single_example(value,features={
"image":tf.FixedLenFeature([],tf.string),
"label":tf.FixedLenFeature([],tf.int64)
})
image = feature["image"]
label = feature["label"]
解码
tf.decode_raw()
3)构造批处理队列
读取这种文件的整个过程与读取其他文件一样,只不过需要多一个解析 Example 的步骤。从 TFRecords 文件中读取数据,可以先使用 tf.TFRecordReader 读取,再使用 tf.parse_single_example 解析器解析。这个操作可以将 Example 协议内存块 (protocol buffer) 解析为张量。
tf.parse_single_example(serialized,features=None,name=None)
解析一个单一的 Example 原型
serialized :标量字符串 Tensor ,一个序列化的 Example
features : dict 字典数据,键为读取的名字,值为 FixedLenFeature
return :一个键值对组成的字典,键为读取的名字
tf.FixedLenFeature(shape,dtype)
shape :输入数据的形状,一般不指定,为空列表
dtype :输入数据类型,与存储进文件的类型要一致
类型只能是 float32 , int64 , string
二、案例:读取 CIFAR 的 TFRecords 文件
import tensorflow as tf
class Cifar(object):
    """Read CIFAR binary batches and round-trip them through a TFRecords file.

    The class uses the TensorFlow 1.x queue-based input pipeline
    (``tf.train.string_input_producer`` / ``tf.train.batch``) and offers three
    steps:

    * ``read_and_decode``    -- read fixed-length binary records into a batch.
    * ``write_to_tfrecords`` -- serialize a batch into ``cifar10.tfrecords``.
    * ``read_tfrecords``     -- read that file back and print one batch.
    """

    # Directory searched for ``*.bin`` files when ``read_and_decode`` is called
    # without an explicit file list.
    # NOTE(review): assumed to be the extracted CIFAR-10 binary archive
    # directory -- confirm against the local data layout.
    DEFAULT_BIN_DIR = "./cifar-10-batches-bin"

    # File written by ``write_to_tfrecords`` and read by ``read_tfrecords``.
    TFRECORDS_PATH = "cifar10.tfrecords"

    def __init__(self):
        """Set the image geometry and the derived per-record byte counts."""
        self.height = 32
        self.width = 32
        self.channels = 3

        # Each binary record is one label byte followed by the image bytes.
        self.image_bytes = self.height * self.width * self.channels
        self.label_bytes = 1
        self.all_bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self, file_list=None):
        """Read binary records and return one batch of images and labels.

        :param file_list: iterable of paths to the binary files to read. When
            ``None``, every ``*.bin`` file in ``DEFAULT_BIN_DIR`` is used.
        :return: ``(image_value, label_value)`` -- the evaluated image batch
            (float32, 100 x height x width x channels) and label batch
            (uint8, 100 x 1).
        :raises ValueError: if there are no files to read.
        """
        if file_list is None:
            # Imported locally so the block does not depend on file-level
            # imports beyond TensorFlow.
            import glob
            import os

            file_list = sorted(
                glob.glob(os.path.join(self.DEFAULT_BIN_DIR, "*.bin"))
            )
        file_list = list(file_list)
        if not file_list:
            # An empty queue would make the reader block forever.
            raise ValueError("no CIFAR binary files to read")

        # 1. Build the file-name queue.
        file_queue = tf.train.string_input_producer(file_list)

        # 2. Read and decode.
        # Each read yields one record of ``all_bytes`` bytes.
        reader = tf.FixedLengthRecordReader(self.all_bytes)
        # key: record identifier, value: one raw sample.
        key, value = reader.read(file_queue)
        print("key:\n", key)
        print("value:\n", value)

        decoded = tf.decode_raw(value, tf.uint8)
        print("decoded:\n", decoded)

        # Split the record into the label (first byte) and the image bytes.
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
        print("label:\n", label)
        print("image:\n", image)

        # The bytes are interpreted as channels-first, hence this reshape ...
        image_reshaped = tf.reshape(
            image, shape=[self.channels, self.height, self.width]
        )
        print("image_reshaped:\n", image_reshaped)

        # ... followed by a transpose to height, width, channels.
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)

        image_cast = tf.cast(image_transposed, tf.float32)

        # 3. Batch.
        label_batch, image_batch = tf.train.batch(
            [label, image_cast], batch_size=100, num_threads=1, capacity=100
        )

        with tf.Session() as sess:
            # Start the queue-runner threads that feed the pipeline.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                (
                    key_new,
                    value_new,
                    decoded_new,
                    label_new,
                    image_new,
                    image_reshaped_new,
                    image_transposed_new,
                ) = sess.run(
                    [
                        key,
                        value,
                        decoded,
                        label,
                        image,
                        image_reshaped,
                        image_transposed,
                    ]
                )
                # Fetches must be passed as one list; a second positional
                # argument would be treated as ``feed_dict``.
                label_value, image_value = sess.run([label_batch, image_batch])

                print("key_new:\n", key_new)
                print("value_new:\n", value_new)
                print("decoded_new:\n", decoded_new)
                print("label_new:\n", label_new)
                print("image_new:\n", image_new)
                print("image_reshaped_new:\n", image_reshaped_new)
                print("image_transposed_new:\n", image_transposed_new)
                print("label_value:\n", label_value)
                print("image_value:\n", image_value)
            finally:
                # Always stop and join the threads, even if a run fails.
                coord.request_stop()
                coord.join(threads)

        return image_value, label_value

    def write_to_tfrecords(self, image_batch, label_batch):
        """Write image/label pairs to the TFRecords file as Example protos.

        :param image_batch: sequence of NumPy image arrays (one per sample).
        :param label_batch: sequence of label rows; element ``[0]`` of each
            row is the integer label.
        :return: None
        """
        with tf.python_io.TFRecordWriter(self.TFRECORDS_PATH) as writer:
            # Build one Example per sample, serialize it and write it out.
            for image_array, label_row in zip(image_batch, label_batch):
                # Store raw uint8 bytes: ``read_tfrecords`` decodes the image
                # with ``tf.uint8``, so writing float32 bytes would yield four
                # times too many elements on read-back.
                image = image_array.astype("uint8").tobytes()
                # Int64List expects plain Python integers.
                label = int(label_row[0])

                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[image])
                    ),
                    "label": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[label])
                    ),
                }))
                writer.write(example.SerializeToString())

        return None

    def read_tfrecords(self):
        """Read the TFRecords file, decode one batch and print it.

        :return: None
        """
        # 1. Build the file-name queue.
        file_queue = tf.train.string_input_producer([self.TFRECORDS_PATH])

        # 2. Read and decode.
        reader = tf.TFRecordReader()
        _, value = reader.read(file_queue)

        # Parse the serialized Example into a dict of tensors.
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)

        # Decode the raw image bytes.
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:\n", image_decoded)

        # Give the image a static shape so it can be batched.
        image_reshaped = tf.reshape(
            image_decoded, [self.height, self.width, self.channels]
        )
        print("image_reshaped:\n", image_reshaped)

        # 3. Batch.
        image_batch, label_batch = tf.train.batch(
            [image_reshaped, label], batch_size=100, num_threads=2, capacity=100
        )
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)

        with tf.Session() as sess:
            # Start the queue-runner threads that feed the pipeline.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                image_value, label_value = sess.run([image_batch, label_batch])
                print("image_value:\n", image_value)
                print("label_value:\n", label_value)
            finally:
                # Always stop and join the threads, even if the run fails.
                coord.request_stop()
                coord.join(threads)

        return None
if __name__ == "__main__":
    # Script entry point: read "cifar10.tfrecords" back and print one batch.
    cifar = Cifar()
    # To (re)generate the TFRecords file from the binary batches first, run:
    # image_value, label_value = cifar.read_and_decode()
    # cifar.write_to_tfrecords(image_value, label_value)
    cifar.read_tfrecords()