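Study notes for the Alibaba Cloud Developer Academy course "Getting Started with the TensorFlow Deep Learning Framework: Binary File Reading Case Study". The notes follow the course closely to help readers pick up the material quickly.
Course URL: https://developer.aliyun.com/learning/course/773/detail/13555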
Binary File Reading Case Study
Contents:
1. NHWC and NCHW
2. Reading example
1. NHWC and NCHW
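When setting the shape of the images being read, there are two layout formats:
With "NHWC", the dimensions are ordered [batch, height, width, channels];
With "NCHW", the dimensions are ordered [batch, channels, height, width].
Here N is the number of images in the batch, H is the number of pixels in the vertical direction, W is the number of pixels in the horizontal direction, and C is the number of channels.
TensorFlow uses the [height, width, channels] layout (that is, NHWC) by default.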
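Switching a whole batch between the two layouts is just a permutation of axes. The sketch below is a minimal illustration that assumes NumPy as a stand-in for TensorFlow (tf.transpose takes the same kind of permutation list) and uses a made-up batch of 100 CIFAR-sized images filled with consecutive integers.

```python
import numpy as np

# Made-up batch of 100 images, 3 channels, 32x32 pixels, stored as NCHW
batch_nchw = np.arange(100 * 3 * 32 * 32).reshape(100, 3, 32, 32)

# NCHW -> NHWC: move the channel axis (1) to the end
batch_nhwc = np.transpose(batch_nchw, (0, 2, 3, 1))
# NHWC -> NCHW: move the channel axis (3) back to position 1
back_to_nchw = np.transpose(batch_nhwc, (0, 3, 1, 2))

print(batch_nhwc.shape)    # (100, 32, 32, 3)
print(back_to_nchw.shape)  # (100, 3, 32, 32)
# The same pixel value is addressed as [n, h, w, c] in NHWC and [n, c, h, w] in NCHW
print(batch_nhwc[5, 10, 20, 2] == batch_nchw[5, 2, 10, 20])  # True
```

Only the index order changes; no pixel value is modified, so converting back recovers the original batch exactly.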
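For a three-channel RGB image, the difference between the two formats can be illustrated as follows.
Understanding:
Suppose the values 1, 2, 3, 4 are red, 5, 6, 7, 8 are green, and 9, 10, 11, 12 are blue (four pixels per channel).
In the [channels, height, width] layout, channels is the outermost dimension (dimension 0), so the values are split into three groups by color and the R, G and B planes are indexed along the first dimension; in memory they read 1, 2, 3, ..., 12.
In the [height, width, channels] layout, the three RGB colors are indexed along the third (last) dimension, so the R, G and B values of each pixel sit next to each other: 1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12.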
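The 12-value example can be checked directly. This sketch assumes NumPy in place of TensorFlow and assumes the four pixels per channel form a 2x2 image; it builds the [channels, height, width] array, transposes it with the same permutation [1, 2, 0] used later in the reading example, and prints the memory order of each layout.

```python
import numpy as np

# 1-4 red, 5-8 green, 9-12 blue; four pixels per channel assumed to be 2x2
chw = np.arange(1, 13).reshape(3, 2, 2)   # [channels, height, width]
hwc = chw.transpose(1, 2, 0)              # [height, width, channels]

print(chw.ravel())   # [ 1  2  3  4  5  6  7  8  9 10 11 12]
print(hwc.ravel())   # [ 1  5  9  2  6 10  3  7 11  4  8 12]
print(hwc[0, 1])     # R, G, B of the pixel at row 0, column 1: [ 2  6 10]
```

In CHW each color plane is contiguous, while in HWC each pixel's three color values are contiguous.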
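2. Reading example
The code below reads the binary version of CIFAR-10 from ./cifar-10-batches-bin. Each record is 1 label byte followed by 32 × 32 × 3 = 3072 image bytes stored channel by channel (all red values, then all green, then all blue), which is why the image is first reshaped to [channels, height, width] and then transposed to [height, width, channels].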
import os

import tensorflow as tf  # TensorFlow 1.x API (queue-based input pipeline)


class Cifar(object):

    def __init__(self):
        # Initialization: CIFAR-10 image dimensions
        self.height = 32
        self.width = 32
        self.channels = 3
        # Number of bytes per image, per label, and per record
        self.image_bytes = self.height * self.width * self.channels
        self.label_bytes = 1
        self.all_bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self, file_list):
        """
        Read the CIFAR-10 binary files
        :param file_list: list of paths to the .bin files
        :return: None
        """
        # 1. Build the filename queue
        file_queue = tf.train.string_input_producer(file_list)

        # 2. Read and decode
        # Reading stage: every record is exactly all_bytes long
        reader = tf.FixedLengthRecordReader(self.all_bytes)
        # key: file name (with the record index appended); value: one sample's raw bytes
        key, value = reader.read(file_queue)
        print("key:\n", key)
        print("value:\n", value)

        # Decoding stage: turn the raw string into a uint8 vector
        decoded = tf.decode_raw(value, tf.uint8)
        print("decoded:\n", decoded)

        # Slice off the target value (label) and the feature values (image)
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
        print("label:\n", label)
        print("image:\n", image)

        # Reshape: the bytes are stored as [channels, height, width]
        image_reshaped = tf.reshape(image, shape=[self.channels, self.height, self.width])
        print("image_reshaped:\n", image_reshaped)

        # Transpose so the image is ordered [height, width, channels]
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)

        # Convert the image type
        image_cast = tf.cast(image_transposed, tf.float32)

        # 3. Batching
        label_batch, image_batch = tf.train.batch([label, image_cast], batch_size=100,
                                                  num_threads=1, capacity=100)

        # Start a session
        with tf.Session() as sess:
            # Start the queue-runner threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            (key_new, value_new, decoded_new, label_new, image_new,
             image_reshaped_new, image_transposed_new) = sess.run(
                [key, value, decoded, label, image, image_reshaped, image_transposed])
            label_value, image_value = sess.run([label_batch, image_batch])
            print("key_new:\n", key_new)
            print("value_new:\n", value_new)
            print("decoded_new:\n", decoded_new)
            print("label_new:\n", label_new)
            print("image_new:\n", image_new)
            print("image_reshaped_new:\n", image_reshaped_new)
            print("image_transposed_new:\n", image_transposed_new)
            print("label_value:\n", label_value)
            print("image_value:\n", image_value)

            # Stop and reclaim the threads
            coord.request_stop()
            coord.join(threads)

        return None


if __name__ == "__main__":
    file_name = os.listdir("./cifar-10-batches-bin")
    print("file_name:\n", file_name)
    # Build the list of file paths, keeping only the .bin files
    file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in file_name if file[-3:] == "bin"]
    print("file_list:\n", file_list)
    # Instantiate Cifar (a lowercase name avoids shadowing the class)
    cifar = Cifar()
    cifar.read_and_decode(file_list)
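To see what each step of read_and_decode does to the bytes without a TensorFlow 1.x session, the same pipeline can be sketched in NumPy. This is a minimal sketch under stated assumptions: decode_records and fake_record are hypothetical helpers, NumPy stands in for tf.decode_raw, tf.slice, tf.reshape, tf.transpose and tf.cast, and the two records are made-up data laid out like CIFAR-10 (1 label byte, then 1024 red, 1024 green and 1024 blue bytes).

```python
import numpy as np

# CIFAR-10 binary record layout: 1 label byte followed by 32*32*3 pixel bytes
HEIGHT, WIDTH, CHANNELS = 32, 32, 3
LABEL_BYTES = 1
IMAGE_BYTES = HEIGHT * WIDTH * CHANNELS   # 3072
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES  # 3073


def decode_records(raw: bytes):
    """Hypothetical helper: split raw bytes into fixed-length records and decode them."""
    if len(raw) % RECORD_BYTES != 0:
        raise ValueError("data length is not a multiple of the record size")
    # decode_raw equivalent: view the bytes as uint8, one row per record
    records = np.frombuffer(raw, dtype=np.uint8).reshape(-1, RECORD_BYTES)
    # slice equivalent: first byte is the label, the rest is the image
    labels = records[:, :LABEL_BYTES]
    images = records[:, LABEL_BYTES:]
    # reshape to [batch, channels, height, width], the order stored on disk
    images = images.reshape(-1, CHANNELS, HEIGHT, WIDTH)
    # transpose to [batch, height, width, channels], then cast to float32
    images = images.transpose(0, 2, 3, 1).astype(np.float32)
    return labels, images


def fake_record(label, r, g, b):
    """Hypothetical helper: one made-up record with constant R, G and B planes."""
    planes = [np.full(HEIGHT * WIDTH, v, dtype=np.uint8) for v in (r, g, b)]
    return bytes([label]) + np.concatenate(planes).tobytes()


raw = fake_record(3, 10, 20, 30) + fake_record(7, 200, 100, 0)
labels, images = decode_records(raw)
print(labels.ravel())   # [3 7]
print(images.shape)     # (2, 32, 32, 3)
print(images[0, 0, 0])  # [10. 20. 30.]
```

Because every record has the same length (3073 bytes), splitting the data needs no delimiter, which is the property tf.FixedLengthRecordReader relies on.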