第一步 解压数据集
读取AI识虫数据集标注信息
AI识虫数据集结构如下:
提供了2183张图片,其中训练集1693张,验证集245,测试集245张。
包含7种昆虫,分别是Boerner、Leconte、Linnaeus、acuminatus、armandi、coleoptera和linnaeus。
!unzip -oq /home/aistudio/data/data96442/insects.zip -d work
#将文件移动至work下,并删除无用文件 !mv work/data/insects work/ !ls /home/aistudio/work !rmdir --ignore-fail-on-non-empty work/data !rm -rf work/data
# 进入工作目录 /home/aistudio/work %cd /home/aistudio/work # 查看工作目录下的文件列表 !ls
/home/aistudio/work anchor_lables.py insects train.py box_utils.py insects_reader.py yolo_epoch0.pdparams calculate_map.py map_utils.py yolo_epoch100.pdparams darknet.py multinms.py yolo_epoch150.pdparams draw_anchors.py output_pic.png yolo_epoch199.pdparams draw_results.py predict.py yolo_epoch50.pdparams eval.py __pycache__ yolov3.py image_utils.py reader.py
第二步 启动训练
通过运行train.py 文件启动训练,训练好的模型参数会保存在/home/aistudio/work目录下。
# 运行时长: 12小时37分钟50秒582毫秒 !python train.py
第三步 启动评估
通过运行eval.py启动评估,需要制定待评估的图片文件存放路径和需要使用到的模型参数。评估结果会被保存在pred_results.json文件中。
为了演示计算过程,下面使用的是验证集下的图片./insects/val/images,在提交比赛结果的时候,请使用测试集图片./insects/test/images
这里提供的yolo_epoch50.pdparams 是未充分训练好的权重参数,请在比赛时换成自己训练好的权重参数
# 在测试集test上评估训练模型,image_dir指向测试集集路径,weight_file指向要使用的权重路径。 # 参加比赛时需要在测试集上运行这段代码,并把生成的pred_results.json提交上去 # 特别注意,这个weight_file文件提供了199的版本在数据集里面 !python eval.py --image_dir=insects/test/images --weight_file=yolo_epoch199.pdparams
# 在验证集val上评估训练模型,image_dir指向验证集路径,weight_file指向要使用的权重路径。 !python eval.py --image_dir=insects/val/images --weight_file=yolo_epoch199.pdparams
第四步 算精度指标
通过运行calculate_map.py计算最终精度指标mAP
· 同学们训练完之后,可以在val数据集上计算mAP查看结果,所以下面用到的是val标注数据./insects/val/annotations/xmls
· 提交比赛成绩的话需要在测试集上计算mAP,本地没有测试集的标注,只能提交json文件到比赛服务器上查看成绩
第五步 预测单张图片并可视化预测结果
!python predict.py --image_name=./insects/test/images/3157.jpeg --weight_file=./yolo_epoch50.pdparams # 预测结果保存在“/home/aistudio/work/output_pic.png"图像中,运行下面的代码进行可视化
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import MutableMapping /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import Iterable, Mapping /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import Sized W0627 09:44:59.653769 875 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1 W0627 09:44:59.658743 875 device_context.cc:422] device: 0, cuDNN Version: 7.6. /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working if isinstance(obj, collections.Iterator): /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working return list(data) if isinstance(data, collections.MappingView) else data
# 可视化检测结果 from PIL import Image import matplotlib.pyplot as plt %matplotlib inline img = Image.open("/home/aistudio/work/output_pic.png") plt.figure("Object Detection", figsize=(15, 15)) # 图像窗口名称 plt.imshow(img) plt.axis('off') # 关掉坐标轴为 off plt.title('Bugs Detestion') # 图像题目 plt.show()
总结 提升方案
这里给出的是一份基础版本的代码,可以在上面继续改进提升,可以使用的改进方案有:
1、使用其它模型如faster rcnn等 (难度系数5)
2、使用数据增多,可以对原图进行翻转、裁剪等操作 (难度系数3)
3、修改anchor参数的设置,教案中的anchor参数设置直接使用原作者在coco数据集上的设置,针对此模型是否要调整 (难度系数3)
4、调整优化器、学习率策略、正则化系数等是否能提升模型精度 (难度系数1)