一、基于PaddleX的【稻田医生】稻田病害分类
比赛地址:www.kaggle.com/competition…
1.问题描述
大米是世界范围内的主食之一。稻谷是去壳前的粗粮,主要在亚洲国家在热带气候中种植。水稻种植需要持续监督,因为多种疾病和害虫可能会影响水稻作物,导致高达 70% 的产量损失。通常需要专家监督来减轻这些疾病并防止作物损失。由于作物保护专家的可用性有限,人工疾病诊断既繁琐又昂贵。因此,通过利用在各个领域取得可喜成果的基于计算机视觉的技术来自动化疾病识别过程变得越来越重要。
2.比赛简介
本次比赛的主要目标是开发一种基于机器或深度学习的模型来准确分类给定的稻叶图像。我们提供了一个包含 10,407 个 (75%) 标记图像的训练数据集,涵盖 10 个类别(9 个疾病类别和正常叶片)。此外,我们还为每个图像提供额外的元数据,例如稻谷品种和年龄。您的任务是将给定测试数据集中的 3,469 个 (25%) 图像中的每个水稻图像分类为九种疾病类别之一或正常叶子。
二、数据分析
1.数据介绍
我们提供了一个包含 10,407 个(75%)标记的水稻叶片图像的训练数据集,涵盖 10 个类别(9 个疾病和正常叶片)。我们还为每个图像提供额外的元数据,例如稻谷品种和年龄。您的任务是使用训练数据集开发一个准确的疾病分类模型,然后将测试数据集中的 3,469 个(25%)水稻叶片图像中的每个样本分类为九种疾病或正常叶片之一。
train.csv - 训练集
- image_id- 唯一图像标识符对应于train_images目录中的图像文件名 (.jpg)。
- label- 水稻病害类型,也是目标类别。有十类,包括正常的叶子。
- variety- 水稻品种的名称。
- age- 以天为单位的稻谷年龄。
sample_submission.csv - 样本提交文件。
train_images - 该目录包含 10,407 张训练图像,存储在对应于 10 个目标类的不同子目录下。文件名对应image_id于train.csv.
test_images - 此目录包含 3,469 个测试集图像。
2.数据解压缩
!unzip -qoa data/data148690/paddy-disease-classification.zip -d data
3.训练集统计
# 训练集统计 import pathlib train_data_dir = pathlib.Path('data/train_images/') print(train_data_dir) # 带目录 train_image_count = len(list(train_data_dir.glob('*/*.jpg'))) print(train_image_count)
data/train_images 10407
# 测试集统计 test_data_dir = pathlib.Path('data/test_images/') print(test_data_dir) # 不带目录,直接图片 test_image_count = len(list(test_data_dir.glob('*.jpg'))) print(test_image_count)
data/test_images 3469
# 训练集查看 import PIL import PIL.Image image_sample = list(train_data_dir.glob('hispa/*')) PIL.Image.open(str(image_sample[0]))
import pandas as pd train_df=pd.read_csv("data/train.csv") train_df.head() .dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; }
image_id | label | variety | age | |
0 | 100330.jpg | bacterial_leaf_blight | ADT45 | 45 |
1 | 100365.jpg | bacterial_leaf_blight | ADT45 | 45 |
2 | 100382.jpg | bacterial_leaf_blight | ADT45 | 45 |
3 | 100632.jpg | bacterial_leaf_blight | ADT45 | 45 |
4 | 101918.jpg | bacterial_leaf_blight | ADT45 | 45 |
# 查看叶片分类 class_names = train_df.label.unique() print(class_names)
['bacterial_leaf_blight' 'bacterial_leaf_streak' 'bacterial_panicle_blight' 'blast' 'brown_spot' 'dead_heart' 'downy_mildew' 'hispa' 'normal' 'tungro']
4.每种病虫害图片
import os import numpy as np import matplotlib.pyplot as plt from PIL import Image, ImageDraw, ImageFont data_dir = 'data' train_file_path = os.path.join(data_dir, 'train.csv') train_info = pd.read_csv(train_file_path) for disease in np.unique(train_info["label"]): disease_path = os.path.join(data_dir, "train_images", disease) img_names = os.listdir(disease_path) fig, axes = plt.subplots(1, 7, figsize=(20,12)) for idx in range(7): img_path = os.path.join(disease_path, img_names[idx]) image = Image.open(img_path) axes[idx].imshow(image) axes[idx].set_title(disease) axes[idx].axis('off') fig.show()
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/figure.py:457: UserWarning: matplotlib is currently using a non-GUI backend, so cannot show the figure "matplotlib is currently using a non-GUI backend, "
三、PaddleX环境准备
!pip install paddlex 复制代码
!pip list|grep paddlex
paddlex 2.1.0 [33mWARNING: You are using pip version 22.0.4; however, version 22.1.1 is available. You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.[0m[33m [0m
四、模型训练
1.数据集划分
!paddlex --split_dataset --format ImageNet --dataset_dir data/train_images --val_value 0.2
2.导入PaddleX库
# 环境变量配置,用于控制是否使用GPU # 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu import os os.environ['CUDA_VISIBLE_DEVICES'] = '0' import paddle import paddlex as pdx from paddlex import transforms as T
3.数据增强
# 定义训练和验证时的transforms # API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md train_transforms = T.Compose( [T.RandomCrop(crop_size=224), T.RandomHorizontalFlip(), T.Normalize()]) eval_transforms = T.Compose([ T.ResizeByShort(short_size=256), T.CenterCrop(crop_size=224), T.Normalize() ])
4.定义数据集
# 定义训练和验证所用的数据集 # API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md train_dataset = pdx.datasets.ImageNet( data_dir='data/train_images', file_list='data/train_images/train_list.txt', label_list='data/train_images/labels.txt', transforms=train_transforms, shuffle=True) eval_dataset = pdx.datasets.ImageNet( data_dir='data/train_images', file_list='data/train_images/val_list.txt', label_list='data/train_images/labels.txt', transforms=eval_transforms)
5.模型初始化
# 初始化模型,并进行训练 # 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/visualdl.md num_classes = len(train_dataset.labels) model = pdx.cls.MobileNetV3_large(num_classes=num_classes) # 自定义优化器:使用CosineAnnealingDecay train_batch_size = 400 num_steps_each_epoch = len(train_dataset) // train_batch_size num_epochs = 100 scheduler = paddle.optimizer.lr.CosineAnnealingDecay( learning_rate=.001, T_max=num_steps_each_epoch * num_epochs) warmup_epoch = 5 warmup_steps = warmup_epoch * num_steps_each_epoch scheduler = paddle.optimizer.lr.LinearWarmup( learning_rate=scheduler, warmup_steps=warmup_steps, start_lr=0.0, end_lr=.001) custom_optimizer = paddle.optimizer.Momentum( learning_rate=scheduler, momentum=.9, weight_decay=paddle.regularizer.L2Decay(coeff=.00002), parameters=model.net.parameters())
6.模型训练
# API说明:https://github.com/PaddlePaddle/PaddleX/blob/95c53dec89ab0f3769330fa445c6d9213986ca5f/paddlex/cv/models/classifier.py#L153 # 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html model.train( num_epochs=num_epochs, train_dataset=train_dataset, train_batch_size=train_batch_size, eval_dataset=eval_dataset, optimizer=custom_optimizer, save_dir='output/mobilenetv3_large', use_vdl=True)
VisualDL训练过程如下:
五、模型预测
重启环境,进行预测。
# 单图预测 import paddlex as pdx model = pdx.load_model('output/mobilenetv3_large/best_model') result = model.predict('data/test_images/200001.jpg') print("Predict Result: ", result) print("Predict Result: ", result[0]['category_id'])
2022-05-25 00:14:24 [INFO] Model[MobileNetV3_large] loaded. Predict Result: [{'category_id': 7, 'category': 'hispa', 'score': 0.87708503}] Predict Result: 7
# test_images文件夹批量预测 import pathlib import os import paddlex as pdx # 载入模型 model = pdx.load_model('output/mobilenetv3_large/best_model') # 结果文件 f=open("result.csv","w") f.write('image_id,label\n') # 遍历文件夹 test_data_dir = pathlib.Path('data/test_images/') # 不带目录,直接图片 test_files=list(test_data_dir.glob('*.jpg')) for myfile in test_files: result = model.predict(str(myfile)) filename=os.path.basename(myfile) # 写入文件 f.write(f"{filename},{result[0]['category_id']}\n") f.close()
2022-05-25 00:20:23 [INFO] Model[MobileNetV3_large] loaded.
!head result.csv
image_id,label 203240.jpg,5 202134.jpg,2 203075.jpg,5 203450.jpg,8 201291.jpg,9 202727.jpg,2 201387.jpg,8 203014.jpg,6 202823.jpg,5
如上所示,下载result.csv即可提交。