基于 Tensorflow 的蘑菇分类

简介: 基于 Tensorflow 的蘑菇分类

引言


当我们在大自然中行走的时候,经常会碰到各种各样的菌子,这时候我们就有了疑问:我们可以触碰它们吗?它们可以吃吗?如果有一个可以识别菌子的app就很棒了,so,现在让我们来实现吧~


在我们开始之前,让我们理解一些概念。计算机视觉是人工智能的一个有趣分支之一,是教模型在图像中查找信息从而理解视觉内容的艺术。当对人类(猫、狗、汽车……)进行图像分类非常简单时,机器总是很难具有竞争力,这是我们人类从小就学习的东西。计算机视觉已经走过了漫长的道路,现在有了深度学习,它的识别和人类一样好,在特定领域甚至更好。例如,在医学放射学中,可以训练人工智能来检测和分类肿瘤,并且通常比人类有更好的结果。


计算机视觉的第一步是图像检测。图像检测是在给定的图像中找到图像中的特定对象,并返回其坐标或包围盒。


图像分类是当你给出一个物体的图像时,你的模型以概率和置信率返回一个类。因此,我们的模型应该首先检测对象,然后根据它所训练的类型对它们进行分类。为此,我们通常使用 CNN(卷积神经网络)。


图像识别是当您给模型一个图像与多个对象。该模型为图像中的每个物体提供了它的边界框(目标检测)和类的预测,并给出了置信率。


现在我们遇到的是多目标的图像分类问题。


收集数据


为了训练一个模型,你需要好的标记数据,如果这一步出现了错误,后面所有的步骤都将徒劳无功。现在我们用的是 Kaggle 的真菌数据集,这是一个非常好的数据集,有1394个类可以在这里使用。数据集的链接如下:https://www.kaggle.com/c/fungi-challenge-fgvc-2018


数据处理


Tensorflow 为我们提供了一个很便利的API,即 tf.data.dataset。我们可以很方便的用一行代码创建一个有效的数据集,让我们来看看吧~

    data_dir = '/Mydirectory/images/'
        img_height = 256
        img_width = 256
        batch_size = 32
        train_ds = tf.keras.preprocessing.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="training",
            seed=123,
            image_size=(img_height, img_width),
            batch_size=batch_size)
        val_ds = tf.keras.preprocessing.image_dataset_from_directory(
            data_dir,
            validation_split=0.2,
            subset="validation",
            seed=123,
            image_size=(img_height, img_width),
            batch_size=batch_size)
        class_names = train_ds.class_names

    以下是我们将使用的 10 个类:


    [‘11082_Xerocomellus_chrysenteron’, ‘12919_Cylindrobasidium_laeve’, ‘14064_Fomitopsis_pinicola’, ‘14160_Ganoderma_pfeifferi’, ‘17233_Mycena_galericulata’, ‘20983_Trametes_versicolor’, ‘21143_Tricholoma_scalpturatum’, ‘40392_Armillaria_lutea’, ‘40985_Byssomerulius_corium’, ‘61207_Coprinellus_micaceus’]

    640.png

    让我们设置数据集性能

          #################################################
          # Dataset Performance
          ##################################################
          AUTOTUNE = tf.data.experimental.AUTOTUNE
          train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
          val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

      迁移学习模式


      开始创建自己的 CNN,但效果不佳。我不是那么有耐心去改进它,我选择进行迁移学习。迁移学习是重用在更大数据集上训练的模型的能力,这些模型已经学习了多个特征。为此,我们冻结顶层并使用新类重新训练,权重可重复使用。所以让我们用这些预先训练好的模型来帮助自己。我使用了 MobileNetV2 模型,因为它非常轻巧,在我的 GPU 上运行只需几秒钟。


      为了提高准确性,我增加了一个独特的步骤,那就是数据增强。数据增强是: 对于一个标记为图像的输入,您可以缩放或翻转它,并将其作为模型的输入添加。这有助于模型继续识别对象,即使它并不总是处于相同的位置。

        #################################################
            # Data Augmentation
            ##################################################
            data_augmentation = tf.keras.Sequential([
                tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
                tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
                tf.keras.layers.experimental.preprocessing.RandomZoom(0.1),
            ])
            #################################################
            # CREATE THE MODEL
            ##################################################
            num_classes = 10
            preprocess_input_mobilenet_v2 = tf.keras.applications.mobilenet_v2.preprocess_input
            base_model = tf.keras.applications.MobileNetV2(input_shape=(256, 256, 3),
                                                           include_top=False,
                                                           weights='imagenet')
            base_model.trainable = False
            image_batch, label_batch = next(iter(train_ds))
            feature_batch = base_model(image_batch)
            global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
            feature_batch_average = global_average_layer(feature_batch)
            prediction_layer = tf.keras.layers.Dense(num_classes, kernel_regularizer=tf.keras.regularizers.l2(0.001))
            prediction_batch = prediction_layer(feature_batch_average)
            inputs = tf.keras.Input(shape=(256, 256, 3))
            x = data_augmentation(inputs)
            x = preprocess_input_mobilenet_v2(x)
            x = base_model(x, training=False)
            x = global_average_layer(x)
            x = tf.keras.layers.Dropout(0.2)(x)
            outputs = prediction_layer(x)
            model = tf.keras.Model(inputs, outputs)
            #################################################
            # COMPILE THE MODEL
            ##################################################
            #
            model.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])
            #################################################
            # TRAIN THE MODEL
            ##################################################
            epochs = 10
            history = model.fit(
                train_ds,
                validation_data=val_ds,
                epochs=epochs
            )

        结果如下

        640.png

        现在让我们来预测下,代码如下所示:

          # #################################################
          # # LOAD THE MODEL
          # ##################################################
          model = tf.keras.models.load_model('MobileNetV2_Ep20')
          # #################################################
          # # Predictions
          # ##################################################
          img_url = "https://www.mycodb.fr/photos/Xerocomellus_chrysenteron_2014_rp_1.jpg"
          img_path = tf.keras.utils.get_file('mushroom_image', origin=img_url)
          img = tf.keras.preprocessing.image.load_img(
              img_path, target_size=(256, 256, 3)
          )
          img_array = tf.keras.preprocessing.image.img_to_array(img)
          img_array = tf.expand_dims(img_array, 0)  # Create a batch
          predictions = model.predict(img_array)
          predictions_sigmoid = tf.nn.sigmoid(predictions)
          score = tf.nn.softmax(predictions[0])
          print(
              "This image most likely belongs to {} with a {:.2f} percent confidence."
              .format(class_names[np.argmax(score)], 100 * np.max(score))
          )
          

          预测结果:

          640.png

          结果还是相当不错的吧~


          总结


          在本文中,我们了解了如何使用 tensorflow 训练一个用于分类菌子的模型,下一步我们就可以将它移植到移动端,想想还是很兴奋的呢~

          相关文章
          |
          机器学习/深度学习 算法 TensorFlow
          树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
          树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
          227 1
          |
          3月前
          |
          数据采集 TensorFlow 算法框架/工具
          【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
          本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
          73 0
          【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
          |
          2月前
          |
          机器学习/深度学习 TensorFlow 算法框架/工具
          利用Python和TensorFlow构建简单神经网络进行图像分类
          利用Python和TensorFlow构建简单神经网络进行图像分类
          70 3
          |
          8月前
          |
          机器学习/深度学习 数据可视化 TensorFlow
          基于tensorflow深度学习的猫狗分类识别
          基于tensorflow深度学习的猫狗分类识别
          279 1
          |
          8月前
          |
          机器学习/深度学习 编译器 TensorFlow
          基于Python TensorFlow Estimator的深度学习回归与分类代码——DNNRegressor
          基于Python TensorFlow Estimator的深度学习回归与分类代码——DNNRegressor
          |
          机器学习/深度学习 存储 TensorFlow
          Azure 机器学习 - 使用 Visual Studio Code训练图像分类 TensorFlow 模型
          Azure 机器学习 - 使用 Visual Studio Code训练图像分类 TensorFlow 模型
          145 0
          |
          机器学习/深度学习 TensorFlow 算法框架/工具
          TensorFlow HOWTO 4.1 多层感知机(分类)
          TensorFlow HOWTO 4.1 多层感知机(分类)
          83 0
          |
          TensorFlow 算法框架/工具
          TensorFlow HOWTO 2.3 支持向量分类(高斯核)
          TensorFlow HOWTO 2.3 支持向量分类(高斯核)
          89 0
          |
          机器学习/深度学习 TensorFlow 算法框架/工具
          TensorFlow HOWTO 2.1 支持向量分类(软间隔)
          TensorFlow HOWTO 2.1 支持向量分类(软间隔)
          71 0
          【图像分类】TensorFlow2.7版本搭建NIN网络
          【图像分类】TensorFlow2.7版本搭建NIN网络
          112 0
          【图像分类】TensorFlow2.7版本搭建NIN网络