1 介绍
下面主要通过卷积神经网络(CNN)来实现垃圾分类。本数据集中的垃圾共分为六类(与上海的分类标准不同):玻璃、纸、硬纸板、塑料、金属和一般垃圾。
数据来源:垃圾分类数据
2 导入数据和包
import numpy as np import matplotlib.pyplot as plt from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img from keras.layers import Conv2D, Flatten, MaxPooling2D, Dense from keras.models import Sequential import glob, os, random
# Root directory of the dataset. Each sub-directory holds the .jpg images of one
# class. Change this to wherever you unpacked the downloaded data.
base_path = '../input/trash_div7612/dataset-resized'

# Sort the paths so their order does not depend on the filesystem. This keeps any
# later random sampling reproducible under a fixed seed.
img_list = sorted(glob.glob(os.path.join(base_path, '*/*.jpg')))
print(len(img_list))
输出结果:
我们总共有2527张图片。我们随机展示其中的6张图片。
# Preview six randomly chosen images in a 2 x 3 grid of subplots.
for position, path in enumerate(random.sample(img_list, 6), start=1):
    pixels = img_to_array(load_img(path), dtype=np.uint8)
    plt.subplot(2, 3, position)
    plt.imshow(pixels.squeeze())
输出结果:
3. 划分训练集和验证集
# Preprocessing settings shared by the training and validation generators.
# Pixel scaling must be identical for both. Otherwise the model is validated
# (and later used) on inputs scaled differently from those it was trained on.
# The original code used 1./225 for training, which was a typo.
RESCALE = 1. / 255          # map 8-bit pixel values into [0, 1]
TARGET_SIZE = (300, 300)    # must match the model's input_shape height/width
BATCH_SIZE = 16
# The 'training' and 'validation' subsets are carved out of the same directory
# using this fraction, so both generators must agree on it.
VALIDATION_SPLIT = 0.1

# Training data: rescaling plus random geometric augmentation.
train_datagen = ImageDataGenerator(
    rescale=RESCALE,
    shear_range=0.1,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    validation_split=VALIDATION_SPLIT,
)

# Validation data: rescaling only, no augmentation.
test_datagen = ImageDataGenerator(
    rescale=RESCALE,
    validation_split=VALIDATION_SPLIT,
)

train_generator = train_datagen.flow_from_directory(
    base_path,
    target_size=TARGET_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',  # one-hot labels, matching a softmax output
    subset='training',
    seed=0,
)

validation_generator = test_datagen.flow_from_directory(
    base_path,
    target_size=TARGET_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    seed=0,
)

# Invert class_indices (name -> index) into index -> name,
# for turning predictions back into readable labels.
labels = {index: name for name, index in train_generator.class_indices.items()}
print(labels)
输出结果:
4.模型的建立和训练
# CNN classifier with four conv + max-pool stages on 300x300 RGB input,
# followed by a small dense head with a 6-way softmax output.
# Layers are created in the same order as listing them explicitly would.
model = Sequential(
    [
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
               input_shape=(300, 300, 3)),
        MaxPooling2D(pool_size=2),
    ]
    + [
        layer
        for n_filters in (64, 32, 32)
        for layer in (
            Conv2D(filters=n_filters, kernel_size=3, padding='same', activation='relu'),
            MaxPooling2D(pool_size=2),
        )
    ]
    + [
        Flatten(),
        Dense(64, activation='relu'),
        Dense(6, activation='softmax'),
    ]
)

# Categorical cross-entropy for one-hot labels, optimized with Adam.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
# Train for 100 epochs.
# Step counts come from the generators themselves: len(generator) is the number
# of batches, so each epoch visits every training and validation batch once.
# The previous hard-coded 2276//32 and 251//32 divided by 32, but the generators
# yield batches of 16, so each epoch covered only about half of the data.
# NOTE: on newer TF/Keras, `model.fit` accepts generators directly and
# `fit_generator` is deprecated.
history = model.fit_generator(
    train_generator,
    epochs=100,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)
部分输出结果:
5.结果展示
下面我们从验证集中随机抽取 16 张图片,展示每张图片的真实标签以及模型的预测结果。
可以看到,模型预测的准确度较高,大部分图片都能被正确识别出类别。
# Take one batch (index 1) from the validation generator. For up to 16 of its
# images, show the predicted class next to the ground-truth class.
test_x, test_y = validation_generator[1]
preds = model.predict(test_x)

# Both preds (softmax scores) and test_y (one-hot, class_mode='categorical')
# have one row per image; argmax over the class axis gives the class index.
pred_ids = np.argmax(preds, axis=1)
true_ids = np.argmax(test_y, axis=1)

# A batch can hold fewer than 16 images (e.g. the last batch), so do not
# index past its end.
n_show = min(16, len(test_x))

plt.figure(figsize=(16, 16))
for i in range(n_show):
    plt.subplot(4, 4, i + 1)
    plt.title('pred:%s / truth:%s' % (labels[pred_ids[i]], labels[true_ids[i]]))
    plt.imshow(test_x[i])