Load and preprocess images

This tutorial shows how to load and preprocess an image dataset in three ways. First, you will use high-level Keras preprocessing utilities and layers to read a directory of images on disk. Next, you will write your own input pipeline from scratch using tf.data. Finally, you will download a dataset from the large catalog available in TensorFlow Datasets.

import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds
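
This tutorial assumes a TensorFlow 2.x installation; if you want to confirm which version you are running, you can print it after the imports:

print(tf.__version__)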

Download the flowers dataset

This tutorial uses a dataset of several thousand photos of flowers. The flowers dataset contains 5 sub-directories, one per class:

flowers_photos/
  daisy/
  dandelion/
  roses/
  sunflowers/
  tulips/
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url, 
                                   fname='flower_photos', 
                                   untar=True)
data_dir = pathlib.Path(data_dir)

# After downloading (218MB), you should now have a copy of the flower photos available. There are 3,670 total images:
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)
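
The PIL import above can be used for a quick visual sanity check. As a minimal sketch, the snippet below opens the first photo in the roses sub-directory from the listing above:

# Open one image from the roses class to confirm the download looks right.
roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[0]))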

Load using tf.keras.preprocessing

Let's load these images off disk using tf.keras.preprocessing.image_dataset_from_directory.

Create a dataset

Define some parameters for the loader:

batch_size = 32
img_height = 180
img_width = 180

It's good practice to use a validation split when developing your model. You will use 80% of the images for training and 20% for validation.

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

You can find the class names in the class_names attribute on these datasets.

class_names = train_ds.class_names
print(class_names)
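
If you would like to visualize a few training images together with their labels, a minimal sketch using matplotlib (an extra dependency not imported above) could look like this:

import matplotlib.pyplot as plt

# Show the first 9 images of one training batch with their class names.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")
plt.show()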

Standardize the data

The RGB channel values are in the [0, 255] range. This is not ideal for a neural network; in general you should seek to make your input values small. Here, you will standardize values to be in the [0, 1] range by using the tf.keras.layers.experimental.preprocessing.Rescaling layer.

normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)

There are two ways to use this layer. You can apply it to the dataset by calling Dataset.map, or you can include the layer inside your model definition to simplify deployment. Here, you apply it with map:

normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixels values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))
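
The second option, including the Rescaling layer inside the model itself, is what the model in the "Train a model" section below does. A minimal sketch of that approach:

# Rescaling happens inside the model, so raw [0, 255] images can be fed directly.
model = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.Rescaling(1./255),
  # ... the rest of the model's layers go here, operating on [0, 1] inputs
])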

Configure the dataset for performance

Let's make sure to use buffered prefetching, so you can yield data from disk without having I/O become blocking. These are two important methods you should use when loading data.

.cache() keeps the images in memory after they're loaded off disk during the first epoch. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache.

.prefetch() overlaps data preprocessing and model execution while training.

Interested readers can learn more about both methods, as well as how to cache data to disk in the data performance guide.

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
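
As noted above, if the dataset is too large to fit in memory, Dataset.cache can also write to a file on disk instead. A minimal sketch, where the cache file path is only an illustrative example:

# Alternative to the in-memory cache above: cache to a file on disk.
train_ds = train_ds.cache('/tmp/flowers_train.cache').prefetch(buffer_size=AUTOTUNE)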

Train a model

For completeness, this section shows how to train a simple model using the datasets you have just prepared. The model has not been tuned in any way; the goal is only to show you the mechanics using the datasets you just created. To learn more about image classification, visit the image classification tutorial.

num_classes = 5

model = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.Rescaling(1./255),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(num_classes)
])

model.compile(
  optimizer='adam',
  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['accuracy'])

model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=3
)

You may notice the validation accuracy is low compared to the training accuracy, indicating your model is overfitting. You can learn more about overfitting and how to reduce it in this tutorial.

Using tf.data for finer control

The above tf.keras.preprocessing utilities are a convenient way to create a tf.data.Dataset from a directory of images. For finer grain control, you can write your own input pipeline using tf.data. This section shows how to do just that, beginning with the file paths from the TGZ file you downloaded earlier.

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)

class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"]))
print(class_names)

Split the dataset into train and validation:

val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)

You can see the length of each dataset as follows:

print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())

Write a short function that converts a file path to an (img, label) pair:

def get_label(file_path):
  # Convert the path to a list of path components.
  parts = tf.strings.split(file_path, os.path.sep)
  # The second to last component is the class directory.
  one_hot = parts[-2] == class_names
  # Integer encode the label.
  return tf.argmax(one_hot)

def decode_img(img):
  # Convert the compressed string to a 3D uint8 tensor.
  img = tf.io.decode_jpeg(img, channels=3)
  # Resize the image to the desired size.
  return tf.image.resize(img, [img_height, img_width])

def process_path(file_path):
  label = get_label(file_path)
  # Load the raw data from the file as a string.
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

Use Dataset.map to create a dataset of image, label pairs:

# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)

for image, label in train_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

Configure dataset for performance

To train a model with this dataset you will want the data:

  • To be well shuffled.
  • To be batched.
  • Batches to be available as soon as possible.

These features can be added using the tf.data API. For more details, see the Input Pipeline Performance guide.

def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=AUTOTUNE)
  return ds

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
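
To verify that shuffling and batching are in place, you can pull a single batch and inspect its shapes; a quick check, with the expected shapes following from the batch_size and image size defined earlier:

image_batch, label_batch = next(iter(train_ds))
print(image_batch.shape)  # (32, 180, 180, 3) with the settings above
print(label_batch.shape)  # (32,)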

Continue training the model

You have now manually built a tf.data.Dataset similar to the one created by keras.preprocessing above. You can continue training the model with it. As before, you will train for just a few epochs to keep the running time short.

model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=3
)

Using TensorFlow Datasets

So far, this tutorial has focused on loading data off disk. You can also find a dataset to use by exploring the large catalog of easy-to-download datasets at TensorFlow Datasets. As you have previously loaded the Flowers dataset off disk, let's see how to import it with TensorFlow Datasets.

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

The flowers dataset has five classes.

num_classes = metadata.features['label'].num_classes
print(num_classes)
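
To spot-check a sample before batching, you can pull a single (image, label) pair and map the integer label back to a class name using the dataset metadata. A minimal sketch:

# `int2str` converts an integer label to its human-readable class name.
get_label_name = metadata.features['label'].int2str

image, label = next(iter(train_ds))
print(image.shape, get_label_name(label))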

As before, remember to batch, shuffle, and configure each dataset for performance.

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
test_ds = configure_for_performance(test_ds)

You can find a complete example of working with the flowers dataset and TensorFlow Datasets by visiting the Data augmentation tutorial.

Next steps

This tutorial showed two ways of loading images off disk. First, you learned how to load and preprocess an image dataset using Keras preprocessing layers and utilities. Next, you learned how to write an input pipeline from scratch using tf.data. Finally, you learned how to download a dataset from TensorFlow Datasets. As a next step, you can learn how to add data augmentation by visiting this tutorial. To learn more about tf.data, you can visit the tf.data: Build TensorFlow input pipelines guide.

Code link: https://codechina.csdn.net/csdn_codechina/enterprise_technology/-/blob/master/load_preprocess_images.ipynb
