【学习笔记】使用Keras构建CNN网络完成猫狗分类(适合初学者,简单易上手)

简介: 【学习笔记】使用Keras构建CNN网络完成猫狗分类(适合初学者,简单易上手)

【学习笔记】使用Keras构建CNN网络完成猫狗分类(适合初学者,简单易上手)

首先准备好猫和狗的图片数据集,在pycharm中新建一个项目cat_dog recognition,把数据集文件放在该文件夹下。

1ecd1b2606ed46e9956a89f231c9802c.png

2020062310470442.png

训练集和测试集都有猫和狗的图片。

1.图像数据预处理、

在项目中新建impreprocess.py文件:

from keras.preprocessing.image import ImageDataGenerator,array_to_img,img_to_array,load_img
import numpy as np
#数据图像生成
datagen=ImageDataGenerator(
    rotation_range=40,#随机旋转度数
    width_shift_range=0.2,#随机水平平移
    height_shift_range=0.2,#随机竖直平移
    rescale=1/255,#数据归一化
    shear_range=0.2,#随机裁剪
    zoom_range=0.2,#随机放大
    horizontal_flip=True,#水平翻转
    fill_mode='nearest',#填充方式
)
#载入图片
image=load_img('images/training_set/cats/cat.1.jpg')
x=img_to_array(image)#图像数据是一维的,把它转成数组形式
print(x.shape)
x=np.expand_dims(x,0)#在图片的0维增加一个维度,因为Keras处理图片时是4维,第一维代表图片数量
print(x.shape

运行结果:

1ecd1b2606ed46e9956a89f231c9802c.png

数据图像生成的参数说明:

1ecd1b2606ed46e9956a89f231c9802c.png

生成20张图片数据(特征图):在项目文件夹下新建temp文件夹

#生成20张图片数据
i=0
for batch in datagen.flow(x,batch_size=1,save_to_dir='temp',save_prefix='new_cat',save_format='jpeg'):
    i+=1
    if i==20:
        break
print('finshed!'

运行结果:

1ecd1b2606ed46e9956a89f231c9802c.png

在temp文件夹下随机生成20 张旋转、放大、缩小、平移等处理过的图片。这样做的目的是,因为一张图片有许多不同的状态,这样图片的信息量会增多,训练样本数据量就会变大,使得训练模型效果更好。

2.构造简单训练模型

在项目中新建CNNclassification.py文件

from keras.models import Sequential
from keras.layers import Conv2D,MaxPool2D,Activation,Dropout,Flatten,Dense
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
from keras.models import  load_model
import numpy as np
from matplotlib import pyplot as plt
#定义模型
#将输入数据大小改为150*150*3再加入模型
model=Sequential()
model.add(Conv2D(input_shape= (150,150,3),filters=32,kernel_size=3,padding='same',activation='relu'))
model.add(Conv2D(filters=32,kernel_size=3,padding='same',activation='relu'))
model.add(MaxPool2D(pool_size=2,strides=2))
model.add(Conv2D(filters=64,kernel_size=3,padding='same',activation='relu'))
model.add(Conv2D(filters=64,kernel_size=3,padding='same',activation='relu'))
model.add(MaxPool2D(pool_size=2,strides=2))
model.add(Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
model.add(Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
model.add(MaxPool2D(pool_size=2,strides=2))
#卷积池化完原始图像是一个二维的特征图
model.add(Flatten())#把二维数据转换为一维
model.add(Dense(64,activation='relu'))#Dense代表全连接层,64是最后卷积池化后输出的神经元个数
model.add(Dropout(0.5))#防止过拟合
model.add(Dense(2,activation='softmax'))#softmax是把训练结果用概率形式表示的函数,2代表二分类
#定义优化器
adam=Adam(lr=1e-4)
#定义优化器、代价函数、训练过程中计算准确率
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
#数据增强
train_datagen=ImageDataGenerator(
    rotation_range=40,#随机旋转度数
    width_shift_range=0.2,#随机水平平移
    height_shift_range=0.2,#随机竖直平移
    rescale=1/255,#数据归一化
    shear_range=0.2,#随机裁剪
    zoom_range=0.2,#随机放大
    horizontal_flip=True,#水平翻转
    fill_mode='nearest',#填充方式
)
test_datagen=ImageDataGenerator(
    rescale=1/255,#数据归一化
)
batch=32#每次训练传入32张照片
#生成训练数据
train_generator=train_datagen.flow_from_directory(
    'images/training_set',#从训练集这个目录生成数据
    target_size=(150,150),#把生成数据大小定位150*150
    batch_size=batch,
)
#测试数据
test_generator=test_datagen.flow_from_directory(
    'images/test_set',#从训练集这个目录生成数据
    target_size=(150,150),#把生成数据大小定位150*150
    batch_size=batch,
)
#查看定义类别分类
print(train_generator.class_indices)
#定义训练模型
#传入生成的训练数据、每张图片训练1次,验证数据为生成的测试数据
model.fit_generator(train_generator,epochs=1,validation_data=test_generator

其中CNN卷积神经网络的搭建过程就是卷积层、池化层的组合,这部分内容,可以通过学习吴恩达深度学习课程卷积网络那部分来掌握,很好理解。

运行结果:3873张训练集图像分成两类,1080张测试集图像也分成两类。通过查看类别定义,0赋值给‘cats’,1赋值给‘dogs’。因为数据集太大,而本身电脑硬件条件有限,我把数据集和测试集的图片均删掉一半,并把迭代次数设置为一,方便更快看到效果。可以看到,这个模型训练完成后,训练集的精确度只有51%,测试集是56%,并不高。

1ecd1b2606ed46e9956a89f231c9802c.png

2020062310470442.png

3.保存模型

#保存模型
#pip install h5py
model.save('model_cnn.h5'

然后可以在项目文件夹中看到H5文件

1ecd1b2606ed46e9956a89f231c9802c.png

4.测试模型

还是在该文件中

label=np.array(['cat','dog'])#0、1赋值给标签
#载入模型
model=load_model('model_cnn.h5')
#导入图片
image=load_img('images/test_set/cats/cat.4001.jpg')
plt.imshow(image)
plt.show()
image=image.resize((150,150))
image=img_to_array(image)
image=image/255#数值归一化,转为0-1
image=np.expand_dims(image,0)
print(image.shape)
print(label[model.predict_classes(image)]

运行结果:

1ecd1b2606ed46e9956a89f231c9802c.png

1ecd1b2606ed46e9956a89f231c9802c.png

导入的是测试集中这张照片,但因为模型精确度只有50%,所以给出的答案是dog。

再换一张识别度高的猫咪图片试试:

1ecd1b2606ed46e9956a89f231c9802c.png

1ecd1b2606ed46e9956a89f231c9802c.png

结果就是cat了。

相关文章
|
20天前
|
机器学习/深度学习 数据采集 算法
基于GA遗传优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
本项目基于MATLAB2022a实现时间序列预测,采用CNN-GRU-SAM网络结构。卷积层提取局部特征,GRU层处理长期依赖,自注意力机制捕捉全局特征。完整代码含中文注释和操作视频,运行效果无水印展示。算法通过数据归一化、种群初始化、适应度计算、个体更新等步骤优化网络参数,最终输出预测结果。适用于金融市场、气象预报等领域。
基于GA遗传优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
|
16天前
|
机器学习/深度学习 算法 计算机视觉
基于CNN卷积神经网络的金融数据预测matlab仿真,对比BP,RBF,LSTM
本项目基于MATLAB2022A,利用CNN卷积神经网络对金融数据进行预测,并与BP、RBF和LSTM网络对比。核心程序通过处理历史价格数据,训练并测试各模型,展示预测结果及误差分析。CNN通过卷积层捕捉局部特征,BP网络学习非线性映射,RBF网络进行局部逼近,LSTM解决长序列预测中的梯度问题。实验结果表明各模型在金融数据预测中的表现差异。
|
28天前
|
机器学习/深度学习 数据采集 算法
基于PSO粒子群优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
本项目展示了基于PSO优化的CNN-GRU-SAM网络在时间序列预测中的应用。算法通过卷积层、GRU层、自注意力机制层提取特征,结合粒子群优化提升预测准确性。完整程序运行效果无水印,提供Matlab2022a版本代码,含详细中文注释和操作视频。适用于金融市场、气象预报等领域,有效处理非线性数据,提高预测稳定性和效率。
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
深入理解深度学习中的卷积神经网络(CNN)##
在当今的人工智能领域,深度学习已成为推动技术革新的核心力量之一。其中,卷积神经网络(CNN)作为深度学习的一个重要分支,因其在图像和视频处理方面的卓越性能而备受关注。本文旨在深入探讨CNN的基本原理、结构及其在实际应用中的表现,为读者提供一个全面了解CNN的窗口。 ##
|
2月前
|
机器学习/深度学习 人工智能 自然语言处理
深入理解深度学习中的卷积神经网络(CNN)
深入理解深度学习中的卷积神经网络(CNN)
|
1月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-GRU网络的数据分类识别算法matlab仿真
本项目展示了使用MATLAB2022a实现的贝叶斯优化、CNN和GRU算法优化效果。优化前后对比显著,完整代码附带中文注释及操作视频。贝叶斯优化适用于黑盒函数,CNN用于时间序列特征提取,GRU改进了RNN的长序列处理能力。
|
2月前
|
机器学习/深度学习 自然语言处理 算法
深入理解深度学习中的卷积神经网络(CNN)
深入理解深度学习中的卷积神经网络(CNN)
102 1
|
1月前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
77 17
|
1月前
|
存储 SQL 安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。
|
1月前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
59 10

热门文章

最新文章