CNN鲜花分类

简介: CNN鲜花分类

@toc

1、数据集介绍

image-20220731201217893

总共5种花,按照文件夹区分花朵的类别。

image-20220731201256905

下载下来的是个压缩包,需要将其解压。

数据集下载地址:https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

2、代码实战

2.1 导入依赖

import PIL
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import lib

import tensorflow as tf
from tensorflow.keras import layers,models

2.2 下载数据

# 下载数据集到本地
data_url='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'
data_dir=tf.keras.utils.get_file('flower_photos',origin=data_url,untar=True)#untar=True 下载后解压
data_dir=pathlib.Path(data_dir)

2.3 统计数据集

# 统计数据集大小
dataset_size=len(list(data_dir.glob('*/*.jpg')))
dataset_size

image-20220731201534382

总共3670张照片,比上次小狗分类那个少多了。
# 显示部分图片
imgs=list(data_dir.glob('*/*.jpg'))
imgs

image-20220731201624819

查看下第1张图片

img1=imgs[0] #第一张图片
img1

image-20220731201645820

str(img1)
PIL.Image.open(str(img1)) #读取并显示

image-20220731201712941

再查看下第2张图片

img2=imgs[1] #第2张图片
PIL.Image.open(str(img2))

image-20220731201810987

2.4 创建dataset

训练集:

# 3 创建dataset
BATCH_SIZE=32 
HEIGHT=180
WIDTH=180

#80%是训练集,20%是验证集
train_ds=tf.keras.preprocessing.image_dataset_from_directory(directory=data_dir,
                                                            batch_size=BATCH_SIZE,
                                                            validation_split=0.2,
                                                            subset='training',
                                                            seed=666,
                                                            image_size=(HEIGHT,WIDTH))
train_ds

image-20220731201905489

class_names=train_ds.class_names #数据集类别
class_names

image-20220731201929563

验证集:

val_ds=tf.keras.preprocessing.image_dataset_from_directory(directory=data_dir,
                                                            batch_size=BATCH_SIZE,
                                                            validation_split=0.2,
                                                            subset='validation',
                                                            seed=666,
                                                            image_size=(HEIGHT,WIDTH))
val_ds

image-20220731201943016

2.5 可视化一个batch_size

# 可视化一个batch_size的数据
for images,labels in train_ds.take(1):
    for i in range(9): # 一个batch_size有32张,这里只显示9张
        plt.subplot(3,3,i+1)
        plt.imshow(images[i].numpy().astype('uint8'))
        plt.title(class_names[labels[i]])
        plt.axis('off')

image-20220731202114050

2.6 将数据集缓存到内存中,加速读取

#将数据集缓存到内存中,加速读取
AUTOTUNE=tf.data.AUTOTUNE
train_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)

2.7 搭建模型

这里仅作测试,并没有使用预训练模型
#搭建模型
model=models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255,input_shape=(HEIGHT,WIDTH,3)),# 数据归一化
    layers.Conv2D(16,3,padding='same',activation='relu'),
    layers.MaxPool2D(),
    layers.Conv2D(32,3,padding='same',activation='relu'),
    layers.MaxPool2D(),
    layers.Conv2D(64,3,padding='same',activation='relu'),
    layers.MaxPool2D(),
    layers.Flatten(),
    layers.Dense(128,activation='relu'),
    layers.Dense(5)
])
model.summary()

image-20220731202226475

2.8 编译模型

#编译模型
model.compile(optimizer='adam',
             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
             metrics=['accuracy'])
这里使用的SparseCategoricalCrossentropy会自动帮我们

2.9 模型训练

#模型训练
EPOCHS=10
history=model.fit(train_ds,validation_data=val_ds,epochs=EPOCHS)
这里由于设备太拉跨,略微出手已是显卡极限,所以就只设置了10个epoch

image-20220731202356629

2.10 可视化训练结果

# 可视化训练结果
ranges=range(EPOCHS)
train_acc=history.history['accuracy']
val_acc=history.history['val_accuracy']

train_loss=history.history['loss']
val_loss=history.history['val_loss']

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.plot(ranges,train_acc,label='train_acc')
plt.plot(ranges,val_acc,label='val_acc')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')

plt.subplot(1,2,2)
plt.plot(ranges,train_loss,label='train_loss')
plt.plot(ranges,val_loss,label='val_loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')

plt.show()

image-20220731202720010

==过拟合非常严重,下面对模型进行优化==

3、模型优化

3.1 数据增强设置

# 数据增强参数设置
data_argumentation=tf.keras.Sequential([
    # 随机水平翻转
    layers.experimental.preprocessing.RandomFlip('horizontal',input_shape=(HEIGHT,WIDTH,3)),
    # 随机旋转
    layers.experimental.preprocessing.RandomRotation(0.1), # 旋转
    # 随机缩放
    layers.experimental.preprocessing.RandomZoom(0.1),  # 
])
这块的API太多了,多去查查官网。

3.2 显示数据增强后的效果

# 显示数据增强后的效果
for images,labels in train_ds.take(1):
    for i in range(9): # 一个batch_size有32张,这里只显示9张
        plt.subplot(3,3,i+1)
        argumeng_images=data_argumentation(images) #数据增强
        plt.imshow(argumeng_images[i].numpy().astype('uint8')) # 显示
        plt.title(class_names[labels[i]])
        plt.axis('off')

image-20220731203516183

3.3 搭建新的模型

#搭建新的模型
model_2=models.Sequential([
    data_argumentation, # 数据增强
    layers.experimental.preprocessing.Rescaling(1./255),# 数据归一化
    layers.Conv2D(16,3,padding='same',activation='relu'),
    layers.MaxPool2D(),
    layers.Conv2D(32,3,padding='same',activation='relu'),
    layers.MaxPool2D(),
    layers.Conv2D(64,3,padding='same',activation='relu'),
    layers.MaxPool2D(),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(128,activation='relu'),
    layers.Dense(5)
])
model_2.summary()

image-20220731203554763

3.4 编译模型

#编译模型
model_2.compile(optimizer='adam',
             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
             metrics=['accuracy'])

3.5 模型训练

#模型训练
history=model_2.fit(train_ds,validation_data=val_ds,epochs=EPOCHS)

image-20220731203641691

3.6 可视化训练结果

# 可视化训练结果
ranges=range(EPOCHS)
train_acc=history.history['accuracy']
val_acc=history.history['val_accuracy']

train_loss=history.history['loss']
val_loss=history.history['val_loss']

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.plot(ranges,train_acc,label='train_acc')
plt.plot(ranges,val_acc,label='val_acc')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')

plt.subplot(1,2,2)
plt.plot(ranges,train_loss,label='train_loss')
plt.plot(ranges,val_loss,label='val_loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')

plt.show()

image-20220731203701217

现在这个效果比优化之前的好多了。

3.7 模型预测

# 模型预测
test_img=tf.keras.preprocessing.image.load_img('sunfloor.jpg',target_size=(HEIGHT,WIDTH))
test_img

image-20220731203738583

这里我们自己在网上下载一张向日葵的图片进行预测
test_img=tf.keras.preprocessing.image.img_to_array(test_img) # 类型变换
test_img.shape

image-20220731203834288

将数据扩充一维,因为第一个维度是batchsize

test_img=tf.expand_dims(test_img,0) #扩充一维
test_img.shape

image-20220731203938104

预测:

preds=model_2.predict(test_img) #预测
preds.shape

image-20220731204055097

得分:

preds #得分

image-20220731204114249

得分转换成概率:

scores=tf.nn.softmax(preds[0])# 得分转换成概率
scores

image-20220731204134881

print('模型预测可能性最大的类别是:{},概率值为:{}'.format(class_names[np.argmax(scores)],np.max(scores)))

image-20220731204159161

这里最后一个全连接层可以直接加上个softmax激活函数,这样预测后就不用再转化了。
目录
相关文章
|
21天前
|
机器学习/深度学习 算法 TensorFlow
【视频】神经网络正则化方法防过拟合和R语言CNN分类手写数字图像数据MNIST|数据分享
【视频】神经网络正则化方法防过拟合和R语言CNN分类手写数字图像数据MNIST|数据分享
|
21天前
|
机器学习/深度学习 数据可视化 数据挖掘
R语言深度学习卷积神经网络 (CNN)对 CIFAR 图像进行分类:训练与结果评估可视化
R语言深度学习卷积神经网络 (CNN)对 CIFAR 图像进行分类:训练与结果评估可视化
|
21天前
|
机器学习/深度学习 数据采集 TensorFlow
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
|
8月前
|
机器学习/深度学习 传感器 算法
NGO-CNN-SVM分类预测 | Matlab 北方苍鹰算法优化卷积神经网络-支持向量机分类预测
NGO-CNN-SVM分类预测 | Matlab 北方苍鹰算法优化卷积神经网络-支持向量机分类预测
|
21天前
|
机器学习/深度学习 数据采集 传感器
基于CNN和双向gru的心跳分类系统
CNN and Bidirectional GRU-Based Heartbeat Sound Classification Architecture for Elderly People是发布在2023 MDPI Mathematics上的论文,提出了基于卷积神经网络和双向门控循环单元(CNN + BiGRU)注意力的心跳声分类,论文不仅显示了模型还构建了完整的系统。
49 6
|
21天前
|
机器学习/深度学习 并行计算 算法
【计算机视觉+CNN】keras+ResNet残差网络实现图像识别分类实战(附源码和数据集 超详细)
【计算机视觉+CNN】keras+ResNet残差网络实现图像识别分类实战(附源码和数据集 超详细)
83 0
|
机器学习/深度学习
【文本分类】基于预训练语言模型的BERT-CNN多层级专利分类研究
【文本分类】基于预训练语言模型的BERT-CNN多层级专利分类研究
251 0
【文本分类】基于预训练语言模型的BERT-CNN多层级专利分类研究
|
9月前
|
机器学习/深度学习 传感器 自然语言处理
多元分类预测 | Matlab 基于卷积长短期记忆网络(CNN-LSTM)分类预测
多元分类预测 | Matlab 基于卷积长短期记忆网络(CNN-LSTM)分类预测
|
9月前
|
机器学习/深度学习 传感器 算法
多元分类预测 | Matlab 基于基于卷积双向长短期记忆网络(CNN-BILSTM)分类预测
多元分类预测 | Matlab 基于基于卷积双向长短期记忆网络(CNN-BILSTM)分类预测
|
9月前
|
机器学习/深度学习 传感器 算法
多元分类预测 | Matlab 基于卷积支持向量机(CNN-SVM)分类预测
多元分类预测 | Matlab 基于卷积支持向量机(CNN-SVM)分类预测

热门文章

最新文章