1. mmfewshot environment setup
Configuration and installation. The following setup has been tested and works:
Installation commands:
# install mmcv mmclassification mmdetection
pip install mmcv-full==1.6.1 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
pip install mmcls==0.23.2
pip install mmdet==2.25.0
pip install mmfewshot
# install mmfewshot
git clone https://github.com/open-mmlab/mmfewshot.git
cd mmfewshot
pip install -r requirements/build.txt
pip install -v -e .  # or "python setup.py develop"
Note: be sure to run the subsequent pip install -r requirements/build.txt and pip install -v -e . steps. Without them the installation is incomplete and mmfewshot may not be usable.
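As a quick sanity check after installation, the pinned versions can be compared against what is actually installed. This is only a minimal sketch: the helper names (installed_version, find_mismatches) are made up for illustration, and it uses only the standard library importlib.metadata, so it does not confirm that the CUDA build of mmcv-full matches your PyTorch.

```python
from importlib import metadata

# Versions pinned in the install commands above.
EXPECTED = {"mmcv-full": "1.6.1", "mmcls": "0.23.2", "mmdet": "2.25.0"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(expected, lookup=installed_version):
    """Map each package whose installed version differs to (expected, found)."""
    mismatches = {}
    for package, wanted in expected.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(EXPECTED).items():
        print(f"{name}: expected {wanted}, found {found or 'not installed'}")
```

If the script prints nothing, the three pinned packages are present at the expected versions.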
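2. Model training
Training procedure: this example trains the TFA model on the VOC dataset.
Step1: Base training
- Train the base model using all images and annotations of the base classes.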
python tools/detection/train.py configs/detection/tfa/voc/split1/tfa_r101_fpn_voc-split1_base-training.py --gpu-id 0
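Step2: Re-initialize the bbox head of the base model
- Use the provided script to create a new bbox head for fine-tuning on all classes (base classes + novel classes).
- In the new bbox head, the weights for the base classes are initialized directly from the original base model.
- In the new bbox head, the weights for the novel classes are randomly initialized.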
python tools/detection/misc/initialize_bbox_head.py \
    --src1 work_dirs/tfa_r101_fpn_voc-split1_base-training/latest.pth \
    --method random_init \
    --save-dir work_dirs/tfa_r101_fpn_voc-split1_base-training/
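The idea behind the random_init step can be sketched in a few lines of plain Python. This is a conceptual illustration only, not the code of initialize_bbox_head.py: the real script works on tensors inside the checkpoint, and the function name, the std value and the list-of-rows representation here are all assumptions. It also ignores details such as the background class, bias terms and the regression branch.

```python
import random


def init_new_head(base_rows, base_ids, num_all_classes, std=0.01, seed=0):
    """Build weight rows for a head covering base + novel classes.

    base_rows[i] is the weight row the base model learned for class
    base_ids[i]; every other class index keeps small random values.
    """
    rng = random.Random(seed)
    width = len(base_rows[0])
    # Start with random rows for every class (what novel classes will keep).
    new_rows = [[rng.gauss(0.0, std) for _ in range(width)]
                for _ in range(num_all_classes)]
    # Overwrite the base-class rows with the trained weights.
    for row, class_id in zip(base_rows, base_ids):
        new_rows[class_id] = list(row)
    return new_rows


if __name__ == "__main__":
    # Toy example: 2 trained base classes (ids 0 and 2) in a 4-class head.
    head = init_new_head([[1.0, 2.0], [3.0, 4.0]], base_ids=[0, 2],
                         num_all_classes=4)
    for class_id, row in enumerate(head):
        print(class_id, row)
```

For VOC split1 the same idea would apply with 15 base classes and 5 novel classes, giving a 20-class head.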
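Step3: Few-shot fine-tuning
- Use the base model from step 2 as the initialization and further fine-tune the bbox head on the few-shot dataset.
Few-shot fine-tuning requires specially prepared datasets: download the few-shot annotation files and extract each archive into data/few_shot_ann/.
COCO dataset: https://download.openmmlab.com/mmfewshot/few_shot_ann/coco.tar.gz
- Directory structure: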
mmfewshot
├── mmfewshot
├── tools
├── configs
├── data
│   ├── coco
│   │   ├── annotations
│   │   ├── train2014
│   │   ├── val2014
│   │   ├── train2017 (optional)
│   │   ├── val2017 (optional)
│   ├── few_shot_ann
│   │   ├── coco
│   │   │   ├── annotations
│   │   │   │   ├── train.json
│   │   │   │   ├── val.json
│   │   │   ├── attention_rpn_10shot (for coco17)
│   │   │   ├── benchmark_10shot
│   │   │   ├── benchmark_30shot
VOC dataset: https://download.openmmlab.com/mmfewshot/few_shot_ann/voc.tar.gz
- Directory structure:
mmfewshot
├── mmfewshot
├── tools
├── configs
├── data
│   ├── VOCdevkit
│   │   ├── VOC2007
│   │   ├── VOC2012
│   ├── few_shot_ann
│   │   ├── voc
│   │   │   ├── benchmark_1shot
│   │   │   ├── benchmark_2shot
│   │   │   ├── benchmark_3shot
│   │   │   ├── benchmark_5shot
│   │   │   ├── benchmark_10shot
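Before starting fine-tuning it can help to confirm that the VOC layout above is actually in place. The following is a minimal sketch using pathlib; the helper name missing_dirs is hypothetical, and the directory list is copied from the tree above (it checks only that the directories exist, not their contents).

```python
from pathlib import Path

# Directories taken from the VOC tree above, relative to the mmfewshot root.
VOC_DIRS = (
    ["data/VOCdevkit/VOC2007", "data/VOCdevkit/VOC2012"]
    + [f"data/few_shot_ann/voc/benchmark_{k}shot" for k in (1, 2, 3, 5, 10)]
)


def missing_dirs(root, required=VOC_DIRS):
    """Return the required directories that do not exist under root."""
    root = Path(root)
    return [d for d in required if not (root / d).is_dir()]


if __name__ == "__main__":
    # Run from the mmfewshot repository root.
    for d in missing_dirs("."):
        print("missing:", d)
```

An empty result means every directory in the list was found.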
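After processing, the datasets should match the directory structures listed above.
Once the datasets are prepared, training can begin: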
CUDA_VISIBLE_DEVICES=0,1,3 bash tools/detection/dist_train.sh \
    configs/detection/tfa/voc/split1/tfa_r101_fpn_voc-split1_5shot-fine-tuning.py 3
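PS: the config file automatically loads the base-training model from the corresponding path, for example work_dirs/tfa_r101_fpn_voc-split1_base-training/base_model_random_init_bbox_head.pth. This is set in the config file as follows: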
# base model needs to be initialized with following script:
#   tools/detection/misc/initialize_bbox_head.py
# please refer to configs/detection/tfa/README.md for more details.
load_from = ('work_dirs/tfa_r101_fpn_voc-split1_base-training/'
             'base_model_random_init_bbox_head.pth')
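One detail of this setting that is easy to misread: the parentheses do not create a tuple. Python joins adjacent string literals, so load_from is a single path string. The short snippet below shows this and, as an optional extra that is not part of the original config, checks whether the file exists relative to the current directory.

```python
from pathlib import Path

# Adjacent string literals inside parentheses are concatenated by Python,
# so this is one path string, not a tuple of two strings.
load_from = ('work_dirs/tfa_r101_fpn_voc-split1_base-training/'
             'base_model_random_init_bbox_head.pth')

print(type(load_from).__name__)
print(load_from)
# Optional check (an addition for illustration): is the checkpoint there yet?
print("checkpoint found:", Path(load_from).is_file())
```

If the last line reports False, the Step2 script has probably not been run yet, or the command is being run from a different directory.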
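Once fine-tuning finishes, the resulting weights are saved in the corresponding directory, tfa_r101_fpn_voc-split1_5shot-fine-tuning.
The evaluation output printed at the end of training is also recorded in a log file, as shown below: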
+-------------+------+-------+--------+-------+
| class       | gts  | dets  | recall | ap    |
+-------------+------+-------+--------+-------+
| aeroplane   | 285  | 2296  | 0.877  | 0.615 |
| bicycle     | 337  | 2700  | 0.840  | 0.457 |
| boat        | 263  | 2400  | 0.768  | 0.372 |
| bottle      | 469  | 5199  | 0.795  | 0.484 |
| car         | 1201 | 5936  | 0.938  | 0.780 |
| cat         | 358  | 2728  | 0.925  | 0.495 |
| chair       | 756  | 7951  | 0.799  | 0.400 |
| diningtable | 206  | 5288  | 0.859  | 0.367 |
| dog         | 489  | 3675  | 0.894  | 0.405 |
| horse       | 348  | 7908  | 0.937  | 0.462 |
| person      | 4528 | 14036 | 0.917  | 0.691 |
| pottedplant | 480  | 4764  | 0.729  | 0.295 |
| sheep       | 242  | 2452  | 0.884  | 0.467 |
| train       | 282  | 2988  | 0.837  | 0.398 |
| tvmonitor   | 308  | 5185  | 0.880  | 0.549 |
| bird        | 459  | 4232  | 0.693  | 0.247 |
| bus         | 213  | 4890  | 0.864  | 0.267 |
| cow         | 244  | 3942  | 0.947  | 0.380 |
| motorbike   | 325  | 6876  | 0.871  | 0.461 |
| sofa        | 239  | 6499  | 0.787  | 0.195 |
+-------------+------+-------+--------+-------+
| mAP         |      |       |        | 0.439 |
+-------------+------+-------+--------+-------+
2022-09-08 18:46:13,426 - mmfewshot - INFO - BASE_CLASSES_SPLIT1 mAP: 0.4825480580329895
2022-09-08 18:46:13,426 - mmfewshot - INFO - NOVEL_CLASSES_SPLIT1 mAP: 0.3099679946899414
2022-09-08 18:46:13,432 - mmfewshot - INFO - Exp name: tfa_r101_fpn_voc-split1_5shot-fine-tuning.py
2022-09-08 18:46:13,432 - mmfewshot - INFO - Iter(val) [1651] AP50: 0.4390, BASE_CLASSES_SPLIT1: AP50: 0.4830, NOVEL_CLASSES_SPLIT1: AP50: 0.3100, mAP: 0.4394
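The three summary numbers in the log are averages of the per-class AP column. As a worked check, the sketch below recomputes them from the table, assuming that the last five rows (bird, bus, cow, motorbike, sofa) are the novel classes of split 1 and the first fifteen are the base classes. The AP values are the rounded ones from the table, so the results agree with the log only to about three decimal places.

```python
# Per-class AP values copied from the table above.
BASE_AP = {
    "aeroplane": 0.615, "bicycle": 0.457, "boat": 0.372, "bottle": 0.484,
    "car": 0.780, "cat": 0.495, "chair": 0.400, "diningtable": 0.367,
    "dog": 0.405, "horse": 0.462, "person": 0.691, "pottedplant": 0.295,
    "sheep": 0.467, "train": 0.398, "tvmonitor": 0.549,
}
NOVEL_AP = {
    "bird": 0.247, "bus": 0.267, "cow": 0.380, "motorbike": 0.461,
    "sofa": 0.195,
}


def mean(values):
    """Arithmetic mean of a non-empty collection of numbers."""
    values = list(values)
    return sum(values) / len(values)


base_map = mean(BASE_AP.values())
novel_map = mean(NOVEL_AP.values())
overall_map = mean(list(BASE_AP.values()) + list(NOVEL_AP.values()))

print(f"base mAP:    {base_map:.3f}")
print(f"novel mAP:   {novel_map:.3f}")
print(f"overall mAP: {overall_map:.3f}")
```

The gap between the base and novel averages is what the few-shot benchmark is mainly about: the overall mAP is dominated by the 15 base classes.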
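3. Model inference
Evaluate the model that was just trained to check that the results match. The checkpoint evaluated here is the most recent one, latest.pth.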
python tools/detection/test.py \
    configs/detection/tfa/voc/split1/tfa_r101_fpn_voc-split1_5shot-fine-tuning.py \
    work_dirs/tfa_r101_fpn_voc-split1_5shot-fine-tuning/latest.pth \
    --eval mAP --gpu-id 2
Validation results:
As can be seen, the final test result is consistent with the validation result obtained during training.
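To make this comparison less manual, the base and novel summary lines can be pulled out of the two logs and compared numerically. This is a rough sketch: it assumes the test run prints summary lines in the same "INFO - NAME mAP: value" format as the training log in section 2, and the function names and tolerance are chosen here for illustration.

```python
import re

# Matches lines such as:
#   "... - mmfewshot - INFO - BASE_CLASSES_SPLIT1 mAP: 0.4825480580329895"
SUMMARY = re.compile(r"INFO - (\w+) mAP: ([0-9]+\.[0-9]+)")


def parse_split_map(log_text):
    """Return {split_name: mAP} for every summary line found in the log text."""
    return {name: float(value) for name, value in SUMMARY.findall(log_text)}


def results_match(a, b, tol=1e-4):
    """True if both runs report the same splits with mAP within tol."""
    return a.keys() == b.keys() and all(abs(a[k] - b[k]) <= tol for k in a)


if __name__ == "__main__":
    train_log = (
        "2022-09-08 18:46:13,426 - mmfewshot - INFO - "
        "BASE_CLASSES_SPLIT1 mAP: 0.4825480580329895\n"
        "2022-09-08 18:46:13,426 - mmfewshot - INFO - "
        "NOVEL_CLASSES_SPLIT1 mAP: 0.3099679946899414\n"
    )
    print(parse_split_map(train_log))
```

In practice the text passed to parse_split_map would be read from the training log file and from the captured output of the test command.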
References:
https://mmfewshot.readthedocs.io/en/latest/index.html