使用 TensorFlow 和 Keras 构建图像分类器

简介: 【10月更文挑战第2天】使用 TensorFlow 和 Keras 构建图像分类器

引言

深度学习技术的进步使得计算机视觉领域发生了翻天覆地的变化,特别是图像分类任务。TensorFlow 是目前最流行的人工智能框架之一,而 Keras 则是建立在其之上的一种高级 API,旨在简化神经网络的设计与实现。本文将介绍如何使用 TensorFlow 和 Keras 构建一个简单的图像分类器,以识别 CIFAR-10 数据集中的图像类别。

环境准备

首先,确保你的环境中已安装了 Python,并且安装了 TensorFlow。可以使用以下命令安装 TensorFlow:

pip install tensorflow

如果你还没有安装 Jupyter Notebook,也可以通过以下命令来安装:

pip install jupyter

导入必要的库

我们将使用 TensorFlow 和 Keras API 来构建模型,并利用 NumPy 和 Matplotlib 来辅助数据处理和结果可视化。

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np
import matplotlib.pyplot as plt

数据预处理

CIFAR-10 数据集是一个著名的图像分类基准数据集,包含了 50000 个训练图像和 10000 个测试图像,每个图像都是 32x32 像素的彩色图像,共分为 10 个类别。

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Display some images
def plot_images(images, labels, class_names, to_predict=False):
    plt.figure(figsize=(10,10))
    for i in range(25):
        plt.subplot(5,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i])
        if to_predict:
            plt.xlabel(class_names[np.argmax(labels[i])])
        else:
            plt.xlabel(class_names[labels[i][0]])
    plt.show()

plot_images(train_images, train_labels, class_names)

构建模型

接下来,我们将构建一个卷积神经网络(Convolutional Neural Network,简称 CNN)来处理图像分类任务。

model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.summary()

训练模型

使用训练数据来训练模型,并设置验证集来监控模型在未见数据上的表现。

history = model.fit(train_images, train_labels, epochs=10, 
                    validation_data=(test_images, test_labels))

评估模型

使用测试数据来评估模型的性能。

plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')

test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)
print(f'\nTest accuracy: {test_acc}')

预测图像

使用训练好的模型对一些测试图像进行预测。

probability_model = tf.keras.Sequential([model, 
                                         tf.keras.layers.Softmax()])
predictions = probability_model.predict(test_images)

plot_images(test_images, predictions, class_names, to_predict=True)

结论

本文通过使用 TensorFlow 和 Keras API 构建了一个用于 CIFAR-10 数据集的图像分类器。通过加载数据集、构建模型、训练和评估模型,我们展示了完整的机器学习流程。尽管这里使用的模型相对简单,但在实际应用中,可能需要进一步优化网络结构或使用更先进的技术来提升模型的性能。此外,还可以探索使用数据增强技术来进一步提升模型的泛化能力。

相关文章
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
104 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【大作业-04】手把手教你构建垃圾分类系统-基于tensorflow2.3
本文介绍了基于TensorFlow 2.3的垃圾分类系统,通过B站视频和博客详细讲解了系统的构建过程。系统使用了包含8万张图片、245个类别的数据集,训练了LeNet和MobileNet两个卷积神经网络模型,并通过PyQt5构建了图形化界面,用户上传图片后,系统能识别垃圾的具体种类。此外,还提供了模型和数据集的下载链接,方便读者复现实验。垃圾分类对于提高资源利用率、减少环境污染具有重要意义。
71 0
【大作业-04】手把手教你构建垃圾分类系统-基于tensorflow2.3
|
23天前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
53 5
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
56 3
|
29天前
|
机器学习/深度学习 数据采集 TensorFlow
利用TensorFlow实现简单的图像分类模型
利用TensorFlow实现简单的图像分类模型
29 0
|
2月前
|
机器学习/深度学习 数据可视化 TensorFlow
使用TensorFlow构建一个简单的图像分类模型
【10月更文挑战第18天】使用TensorFlow构建一个简单的图像分类模型
71 1
|
2月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
116 0
|
5天前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
109 55
|
1月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
垃圾识别分类系统。本系统采用Python作为主要编程语言,通过收集了5种常见的垃圾数据集('塑料', '玻璃', '纸张', '纸板', '金属'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对图像数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。然后使用Django搭建Web网页端可视化操作界面,实现用户在网页端上传一张垃圾图片识别其名称。
78 0
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
|
1月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
81 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型