TensorFlow读取数据

简介: TensorFlow读取数据

TensorFlow读取数据


最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】。

本文介绍如何使用TensorFlow来读取图片数据,主要介绍写入TFRecord文件再读取和直接使用队列来读取两种方式。假设我们图片目录结构如下:

|---a
|   |---1.jpg
|   |---2.jpg
|   |---3.jpg
|
|---b
|   |---1.jpg
|   |---2.jpg
|   |---3.jpg
|
|---c
|   |---1.jpg
|   |---2.jpg
|   |---3.jpg

1 使用TFRecoder

思路:思路:使用TFRecod主要是把每张图片及其对应的label写入到一个tfrecode文件中。tfrecode以二进制形式保存,其中内部使用了protobuf定义协议,即定义格式序列化为二进制。我们可以使用tf提供的tf.train.Example来指定序列化格式。将a目录中所有的文件的label指定为a,另外两个目录b、c同理。

代码如下:

def build_data(dir,file_str,map_str):
    '''
    :param dir: 根目录,dir下所有子目录名称为label
    :param file_str: 导出的tfrecorde文件
    :param map_str: 数字序号0~n与label映射关系保存路径
    :return:
    '''
    files=os.listdir(dir);
    writer = tf.python_io.TFRecordWriter(file_str)  # 要生成的文件
    # 由于tf.train.Feature只能取float、int和bytes,因此需要将label映射到int,保存到文件
    map_file = open(map_str,'w')
    for index,label in enumerate(files):     #遍历文件夹
        data_dir = os.path.join(dir,label)
        map_file.write(str(index) + ":" + label + "\n")
        for img_name in os.listdir(data_dir):  #遍历图片
            img_path=os.path.join(data_dir,img_name)
            img = Image.open(img_path)         #读取图片
            img = img.resize((256, 256))       #将图片宽高转为256*256
            img_raw=img.tobytes()              #图片转为字节
            example=tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
            writer.write(example.SerializeToString())  # 序列化为字符串并写入文件
    writer.close()
    map_file.close();

接下来是读取tfrecord文件。注意读取时label、img名称及类型要一致:

def read_data(file_str):
    # 根据文件名生成一个队列
    file_path_queue = tf.train.string_input_producer([file_str])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_path_queue)  # 返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img': tf.FixedLenFeature([], tf.string),
                                       })
    label = tf.cast(features['label'], tf.int64)       # 读取label
    img = tf.decode_raw(features['img'], tf.uint8)
    img = tf.reshape(img, [256, 256, 3])               #将维度转为256*256的3通道
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  #将图片中的数据转为[-0.5,0.5]
    return img, label

接下来看看如何使用:

build_data("D:/test","D:/data/tf.tfrecorde","D:/data/map.txt")
img, label =read_data("D:/data/tf.tfrecorde")
#使用shuffle_batch可以随机打乱输入
img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=30, capacity=2000,
                                                min_after_dequeue=1000)
init = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init)
    threads = tf.train.start_queue_runners(sess=sess)
    for i in range(3):
        imgs, labels= sess.run([img_batch, label_batch])
        #我们也可以根据需要对val, l进行处理 
        print(imgs.shape, labels)

运行结果如下:

(30, 256, 256, 3) [1 2 2 1 1 2 2 1 0 1 0 1 0 0 2 0 0 0 2 1 1 1 1 0 0 1 2 1 2 0]
(30, 256, 256, 3) [2 1 1 0 0 1 1 0 2 2 2 0 0 0 0 2 1 0 0 2 0 0 2 2 2 1 0 1 0 2]
(30, 256, 256, 3) [2 0 2 0 1 2 1 2 2 1 0 2 0 0 2 2 2 1 1 1 1 1 0 0 2 0 2 2 0 0]

从结果可以看出,虽然我们提供的图片只有9张。每一类各3张,但是能读取303030张出来,这主要是通过循环读取得到的。也就是说数量上虽然增加了,但实际上也就是那9张图片。

2 不使用TFRecord

TFRecord适合将标签、图片数据等其他相关的数据一起封装到一个对象,然后逐个读取。有时候,我们并不需要标签,只需要对图片读取。那么可以考虑之间从路径队列中读取,而不需要转到TFRecord文件。

直接上代码:

def read_data(dir ):
    '''
    :param dir: 图片根目录
    '''
    input_paths = glob.glob(os.path.join(dir, "*.jpg"))
    decode = tf.image.decode_jpeg
    if len(input_paths) == 0:    #如果不存在jpg图片,则遍历png图片
        input_paths = glob.glob(os.path.join(dir, "*.png"))
        decode = tf.image.decode_png
    if len(input_paths) == 0:    #如果png图片不存在,抛出异常
        raise Exception("input_dir contains no image files")
    #产生文件路径队列,并且打乱顺序
    path_queue = tf.train.string_input_producer(input_paths, shuffle=True)
    reader = tf.WholeFileReader()   #创建读取文件对象
    paths, contents = reader.read(path_queue) #从队列中读取
    img_raw = decode(contents)
    # 将图片缩小到256*256,如果在此之前对图片预处理(放缩),那么这一步可省略
    img_raw = tf.image.resize_images(img_raw, [256, 256])
    img_raw = tf.image.convert_image_dtype(img_raw, dtype=tf.float32)
    img_raw.set_shape([256, 256, 3])#设置shape
    return img_raw

接下来看看如何使用:

img = read_data("D:/test/*" )
img_batch = tf.train.batch([img], batch_size=30)
init = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init)
    threads = tf.train.start_queue_runners(sess=sess)
    for i in range(3):
        imgs = sess.run( img_batch )
        print(imgs.shape )

看看运行结果:

(30, 256, 256, 3)
(30, 256, 256, 3)
(30, 256, 256, 3)


相关文章
|
6月前
|
机器学习/深度学习 算法 TensorFlow
【Python深度学习】Tensorflow对半环形数据分类、手写数字识别、猫狗识别实战(附源码)
【Python深度学习】Tensorflow对半环形数据分类、手写数字识别、猫狗识别实战(附源码)
124 0
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
PYTHON TENSORFLOW 2二维卷积神经网络CNN对图像物体识别混淆矩阵评估|数据分享
PYTHON TENSORFLOW 2二维卷积神经网络CNN对图像物体识别混淆矩阵评估|数据分享
|
2月前
|
数据挖掘 PyTorch TensorFlow
|
2月前
|
机器学习/深度学习 数据挖掘 TensorFlow
🔍揭秘Python数据分析奥秘,TensorFlow助力解锁数据背后的亿万商机
【9月更文挑战第11天】在信息爆炸的时代,数据如沉睡的宝藏,等待发掘。Python以简洁的语法和丰富的库生态成为数据分析的首选,而TensorFlow则为深度学习赋能,助你洞察数据核心,解锁商机。通过Pandas库,我们可以轻松处理结构化数据,进行统计分析和可视化;TensorFlow则能构建复杂的神经网络模型,捕捉非线性关系,提升预测准确性。两者的结合,让你在商业竞争中脱颖而出,把握市场脉搏,释放数据的无限价值。以下是使用Pandas进行简单数据分析的示例:
43 5
|
2月前
|
机器学习/深度学习 数据挖掘 TensorFlow
从数据小白到AI专家:Python数据分析与TensorFlow/PyTorch深度学习的蜕变之路
【9月更文挑战第10天】从数据新手成长为AI专家,需先掌握Python基础语法,并学会使用NumPy和Pandas进行数据分析。接着,通过Matplotlib和Seaborn实现数据可视化,最后利用TensorFlow或PyTorch探索深度学习。这一过程涉及从数据清洗、可视化到构建神经网络的多个步骤,每一步都需不断实践与学习。借助Python的强大功能及各类库的支持,你能逐步解锁数据的深层价值。
65 0
|
4月前
|
数据挖掘 PyTorch TensorFlow
Python数据分析新纪元:TensorFlow与PyTorch双剑合璧,深度挖掘数据价值
【7月更文挑战第30天】随着大数据时代的发展,数据分析变得至关重要,深度学习作为其前沿技术,正推动数据分析进入新阶段。本文介绍如何结合使用TensorFlow和PyTorch两大深度学习框架,最大化数据价值。
103 8
|
3月前
|
缓存 开发者 测试技术
跨平台应用开发必备秘籍:运用 Uno Platform 打造高性能与优雅设计兼备的多平台应用,全面解析从代码共享到最佳实践的每一个细节
【8月更文挑战第31天】Uno Platform 是一种强大的工具,允许开发者使用 C# 和 XAML 构建跨平台应用。本文探讨了 Uno Platform 中实现跨平台应用的最佳实践,包括代码共享、平台特定功能、性能优化及测试等方面。通过共享代码、采用 MVVM 模式、使用条件编译指令以及优化性能,开发者可以高效构建高质量应用。Uno Platform 支持多种测试方法,确保应用在各平台上的稳定性和可靠性。这使得 Uno Platform 成为个人项目和企业应用的理想选择。
62 0
|
3月前
|
机器学习/深度学习 缓存 TensorFlow
TensorFlow 数据管道优化超重要!掌握这些关键技巧,大幅提升模型训练效率!
【8月更文挑战第31天】在机器学习领域,高效的数据处理对构建优秀模型至关重要。TensorFlow作为深度学习框架,其数据管道优化能显著提升模型训练效率。数据管道如同模型生命线,负责将原始数据转化为可理解形式。低效的数据管道会限制模型性能,即便模型架构先进。优化方法包括:合理利用数据加载与预处理功能,使用`tf.data.Dataset` API并行读取文件;使用`tf.image`进行图像数据增强;缓存数据避免重复读取,使用`cache`和`prefetch`方法提高效率。通过这些方法,可以大幅提升数据管道效率,加快模型训练速度。
47 0
|
4月前
|
机器学习/深度学习 数据挖掘 TensorFlow
🔍揭秘Python数据分析奥秘,TensorFlow助力解锁数据背后的亿万商机
【7月更文挑战第29天】在数据丰富的时代,Python以其简洁和强大的库支持成为数据分析首选。Pandas库简化了数据处理与分析,如读取CSV文件、执行统计分析及可视化销售趋势。TensorFlow则通过深度学习技术挖掘复杂数据模式,提升预测准确性。两者结合助力商业决策,把握市场先机,释放数据巨大价值。
50 4
|
4月前
|
机器学习/深度学习 数据挖掘 TensorFlow
数据界的“福尔摩斯”如何炼成?Python+TensorFlow数据分析实战全攻略
【7月更文挑战第30天】数据界的“福尔摩斯”运用Python与TensorFlow解开数据之谜。
52 2