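Run it directly
Open the Top-Down keypoint detection example and click "Open in DSW" in the upper-right corner.
EasyCV Keypoint Detection: Top-Down
Keypoint detection is one of the fundamental tasks in computer vision. It covers facial keypoints, human body keypoints, and keypoints of specific objects such as hands, and it is widely used in pose estimation, action recognition, human-computer interaction, virtual reality, and autonomous driving.
This tutorial shows how to use EasyCV on PAI-DSW to quickly train a human keypoint detection model and run inference with it.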
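Runtime requirements
PAI-PyTorch 1.7/1.8 image, P100 or V100 GPU, 32 GB memory
Install dependencies
Note: these dependencies are already installed in the PAI-DSW Docker image, so you can skip steps 1 and 2 there. Run steps 1 and 2 only when setting up a local notebook environment.
1. Get the torch and CUDA versions, adjust the mmcv install command to match them, and install the corresponding versions of mmcv and nvidia-dali.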
import torch
import os

os.environ['CUDA'] = 'cu' + torch.version.cuda.replace('.', '')
os.environ['Torch'] = 'torch' + torch.version.__version__.replace('+PAI', '')
!echo $CUDA
!echo $Torch
cu101 torch1.8.2
# install some python deps
! pip install --upgrade tqdm
# replace cu101 and torch1.8.0 in the URL below with your own CUDA/torch versions
! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
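The mmcv-full index URL above hardcodes cu101 and torch1.8.0. As a minimal sketch, the helper below (a hypothetical function `mmcv_index_url`, not part of mmcv or EasyCV) builds that URL from the version strings obtained in step 1. It assumes OpenMMLab's convention of publishing mmcv-full wheels under a `torch<major>.<minor>.0` directory that serves every patch release of the same minor version; if the URL returns 404, check the index page by hand.

```python
def mmcv_index_url(cuda_version: str, torch_version: str) -> str:
    """Build the mmcv-full wheel index URL for a CUDA/torch version pair.

    cuda_version:  e.g. torch.version.cuda, such as '10.1'
    torch_version: e.g. torch.__version__, such as '1.8.2+PAI'
    """
    # '10.1' -> 'cu101'
    cuda_tag = 'cu' + cuda_version.replace('.', '')
    # Drop a local suffix such as '+PAI', then keep only major.minor
    major, minor = torch_version.split('+')[0].split('.')[:2]
    # Assumption: wheels for all 1.8.x releases live under 'torch1.8.0'
    torch_tag = 'torch{}.{}.0'.format(major, minor)
    return 'https://download.openmmlab.com/mmcv/dist/{}/{}/index.html'.format(
        cuda_tag, torch_tag)


if __name__ == '__main__':
    url = mmcv_index_url('10.1', '1.8.2+PAI')
    print('pip install mmcv-full==1.4.4 -f ' + url)
```

In the notebook you would pass the live values, for example `mmcv_index_url(torch.version.cuda, torch.__version__)`.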
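2. Install the EasyCV package. Note: pai-easycv is preinstalled in the PAI-DSW Docker image, so you can skip this step. If you hit errors during training or testing, try updating easycv with the commands below.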
# pip install pai-easycv
! pip uninstall -y pai-easycv easycv
! pip install pai-easycv
from easycv.apis import *
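Getting started
Data preparation
For this example we provide a small keypoint detection dataset so that you can run the whole pipeline quickly; download it with the command below.
The folder structure is shown below; the folder path is ./pose: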
pose/
├── images
│   ├── 0001.jpg
│   ├── 0002.jpg
│   ├── 0003.jpg
│   └── ...
├── train_200.json
└── val_20.json
Run the following command to download and extract the data:
! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/configs/keypoint/pose_coco.tar.gz && tar -xpf pose_coco.tar.gz
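Judging by the archive name pose_coco, the annotation files train_200.json and val_20.json are presumably in COCO keypoint format; that is an assumption, so verify it against your own files. Under that assumption, the sketch below groups each annotation's flat `keypoints` list into (x, y, visibility) triplets and counts labeled keypoints, a quick sanity check before training. `keypoint_triplets` and `summarize_coco_keypoints` are hypothetical helpers.

```python
import json
import os
from typing import Dict, List, Tuple


def keypoint_triplets(flat: List[float]) -> List[Tuple[float, float, int]]:
    """Group a COCO-style flat list [x1, y1, v1, x2, y2, v2, ...] into triplets."""
    if len(flat) % 3 != 0:
        raise ValueError('keypoints length must be a multiple of 3')
    return [(flat[i], flat[i + 1], int(flat[i + 2])) for i in range(0, len(flat), 3)]


def summarize_coco_keypoints(coco: Dict) -> Dict[str, int]:
    """Count images, person annotations and labeled (visibility > 0) keypoints."""
    labeled = 0
    for ann in coco.get('annotations', []):
        triplets = keypoint_triplets(ann.get('keypoints', []))
        labeled += sum(1 for _, _, v in triplets if v > 0)
    return {
        'images': len(coco.get('images', [])),
        'annotations': len(coco.get('annotations', [])),
        'labeled_keypoints': labeled,
    }


if __name__ == '__main__':
    path = 'pose/train_200.json'  # assumes the archive extracts to ./pose
    if os.path.exists(path):
        with open(path) as f:
            print(summarize_coco_keypoints(json.load(f)))
```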
Model training
In this demo we train with LiteHRNet as the backbone network.
# download the config file
! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/configs/keypoint/litehrnet_30_coco_384x288.py
To shorten training, open the config file litehrnet_30_coco_384x288.py, set total_epochs to 10, and print a log line every iteration:
# runtime settings
total_epochs = 10
# log config
log_config = dict(interval=1)
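If you prefer not to edit the file by hand, the same two changes can be scripted. The sketch below rewrites the config text with regular expressions; it assumes the file contains a line such as `total_epochs = 210` and a `log_config = dict(interval=...` statement in the usual OpenMMLab style, and `patch_config` is a hypothetical helper rather than an EasyCV utility. It only replaces the numbers, so anything else inside `log_config` (for example a `hooks` list) stays as it was.

```python
import os
import re


def patch_config(text: str, total_epochs: int, log_interval: int) -> str:
    """Set total_epochs and the log_config interval in an OpenMMLab-style config."""
    text, n_epochs = re.subn(r'^(total_epochs\s*=\s*)\d+',
                             r'\g<1>%d' % total_epochs, text, flags=re.M)
    text, n_log = re.subn(r'(log_config\s*=\s*dict\(\s*interval\s*=\s*)\d+',
                          r'\g<1>%d' % log_interval, text)
    # Refuse to write a file whose layout differs from what we assumed
    if n_epochs != 1 or n_log != 1:
        raise ValueError('expected exactly one total_epochs and one log_config interval')
    return text


if __name__ == '__main__':
    cfg_path = 'litehrnet_30_coco_384x288.py'  # downloaded in the previous step
    if os.path.exists(cfg_path):
        with open(cfg_path) as f:
            patched = patch_config(f.read(), total_epochs=10, log_interval=1)
        with open(cfg_path, 'w') as f:
            f.write(patched)
```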
# check where easycv is installed
import easycv
print(easycv.__file__)
/home/pai/lib/python3.6/site-packages/easycv/__init__.py
!python -m easycv.tools.train litehrnet_30_coco_384x288.py --work_dir work_dir/pose/litehrnet_30_coco
Model export
# list the checkpoint (.pth) files produced by training
! ls work_dir/pose/litehrnet_30_coco/*.pth
work_dir/pose/litehrnet_30_coco/CoCoPoseTopDownEvaluator_AP_best.pth work_dir/pose/litehrnet_30_coco/epoch_10.pth
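Training leaves both a best-by-AP checkpoint (`CoCoPoseTopDownEvaluator_AP_best.pth`) and per-epoch checkpoints such as `epoch_10.pth`. A small hypothetical helper like the one below can choose the file to export: it prefers a `*_best.pth` file and otherwise falls back to the highest-numbered `epoch_N.pth`. The file-name patterns are inferred from the listing above, not from EasyCV documentation.

```python
import glob
import os
import re
from typing import Iterable, Optional


def pick_checkpoint(paths: Iterable[str]) -> Optional[str]:
    """Return a *_best.pth checkpoint if present, else the latest epoch_N.pth."""
    paths = list(paths)
    best = sorted(p for p in paths if os.path.basename(p).endswith('_best.pth'))
    if best:
        return best[0]
    epochs = []
    for p in paths:
        m = re.fullmatch(r'epoch_(\d+)\.pth', os.path.basename(p))
        if m:
            # Compare epochs numerically so epoch_10 beats epoch_2
            epochs.append((int(m.group(1)), p))
    return max(epochs)[1] if epochs else None


if __name__ == '__main__':
    print(pick_checkpoint(glob.glob('work_dir/pose/litehrnet_30_coco/*.pth')))
```

The numeric comparison matters: sorting file names as strings would rank `epoch_9.pth` above `epoch_10.pth`.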
! python -m easycv.tools.export litehrnet_30_coco_384x288.py work_dir/pose/litehrnet_30_coco/CoCoPoseTopDownEvaluator_AP_best.pth work_dir/pose/litehrnet_30_coco/export_best.pth
pose/litehrnet_30_coco_384x288.py
load checkpoint from local path: work_dir/pose/litehrnet_30_coco/CoCoPoseTopDownEvaluator_AP_best.pth
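Model prediction
Before running prediction, we also need to download an object detection model. A top-down method predicts keypoints for one person at a time, so every person in the input image must first be detected, and the keypoint model is then run on each detected person individually.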
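The two-stage top-down flow just described can be sketched in plain Python. `detect_people` and `estimate_pose` are hypothetical stand-ins for the detector (the YOLOX checkpoint below) and the LiteHRNet pose model; the sketch only shows the control flow, not EasyCV's implementation: keep person boxes whose score reaches a threshold (in the spirit of the `bbox_thr` setting used later), convert each box from corners (x1, y1, x2, y2) to (x, y, w, h), and run the pose model once per kept person.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]


def xyxy_to_xywh(box: Box) -> Box:
    """Convert (x1, y1, x2, y2) corners to (x, y, width, height)."""
    x1, y1, x2, y2 = box
    return (x1, y1, x2 - x1, y2 - y1)


def top_down_pose(image: Any,
                  detect_people: Callable[[Any], Sequence[Tuple[Box, float]]],
                  estimate_pose: Callable[[Any, Box], List[Tuple[float, float, float]]],
                  bbox_thr: float = 0.3) -> List[Dict]:
    """Stage 1: detect people. Stage 2: estimate keypoints inside each kept box."""
    results = []
    for box, score in detect_people(image):
        if score < bbox_thr:  # drop low-confidence person boxes
            continue
        bbox = xyxy_to_xywh(box)
        results.append({'bbox': bbox, 'score': score,
                        'keypoints': estimate_pose(image, bbox)})
    return results


def fake_detector(image: Any) -> List[Tuple[Box, float]]:
    """Toy detector: one confident person and one low-score box."""
    return [((10, 20, 110, 220), 0.9), ((0, 0, 5, 5), 0.1)]


def fake_pose(image: Any, bbox: Box) -> List[Tuple[float, float, float]]:
    """Toy pose model: a single keypoint at the box center."""
    x, y, w, h = bbox
    return [(x + w / 2, y + h / 2, 1.0)]


if __name__ == '__main__':
    print(top_down_pose(None, fake_detector, fake_pose))
```

The cost of this design grows with the number of people in the image, since the pose model runs once per detected person; in exchange, each person is processed at the pose model's full input resolution.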
!wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/configs/pose/epoch_300.pt
If the following error appears during prediction, uninstall mmdet manually by running pip uninstall mmdet in a terminal:
KeyError: 'YOLOXLrUpdaterHook is already registered in hook'
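Before uninstalling anything, you can check from the notebook whether mmdet is present at all. This minimal sketch uses only the standard library's `importlib.util.find_spec`, which locates a top-level module without executing it, so the check itself should not trigger the hook-registration clash; the helper name `is_installed` is hypothetical.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found on the current sys.path."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == '__main__':
    if is_installed('mmdet'):
        print('mmdet is installed; run "pip uninstall mmdet" if you hit the KeyError above')
    else:
        print('mmdet is not installed')
```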
from PIL import Image
import numpy as np
from easycv.predictors.pose_predictor import TorchPoseTopDownPredictorWithDetector

# Point these paths at the exported pose checkpoint and the downloaded detector
pose_model_path = 'work_dir/pose/litehrnet_30_coco/export_best.pth'
detection_model_path = 'epoch_300.pt'
# Both checkpoints are passed as one comma-separated string: 'pose,detection'
model_path = ','.join((pose_model_path, detection_model_path))

model_config = {
    'pose': {
        'bbox_thr': 0.3,  # score threshold for person boxes passed to the pose model
        'format': 'xywh'
    },
    'detection': {
        'model_type': 'TorchYoloXPredictor'
    }
}

fe = TorchPoseTopDownPredictorWithDetector(model_path=model_path, model_config=model_config)

# Replace with the path of an image that exists locally
input_img = 'small_coco/images/000000012754.jpg'
input_data_list = [np.asarray(Image.open(input_img))]
results = fe.predict(input_data_list)[0]
print(results['pose_results'])
print(results['pose_results'][0]['keypoints'].shape)
/home/pai/lib/python3.6/site-packages/easycv/predictors/pose_predictor.py:437: DeprecationWarning: Call to deprecated class TorchYoloXPredictor (Please use YoloXPredictor.). detection_model_path, model_config=model_config['detection'])
reparam: 0 load checkpoint from local path: epoch_300.pt
/home/pai/lib/python3.6/site-packages/easycv/datasets/detection/pipelines/mm_transforms.py:1447: DeprecationWarning: pad_val of float type is deprecated now, please use pad_val=dict(img=(114.0, 114.0, 114.0), masks=(114.0, 114.0, 114.0), seg=255) instead. f'masks={pad_val}, seg=255) instead.', DeprecationWarning)
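The prediction code prints `pose_results` and the shape of the first person's `keypoints` array. Assuming that array holds 17 rows of (x, y, score) in the standard COCO person-keypoint order (an assumption based on the COCO-trained config, not confirmed by the output shown here), the sketch below maps rows to keypoint names and keeps those at or above a score threshold. `name_keypoints` and the 0.3 default threshold are illustrative choices.

```python
from typing import Dict, Sequence, Tuple

# Standard COCO person keypoint order (17 keypoints)
COCO_KEYPOINTS = (
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle',
)


def name_keypoints(keypoints: Sequence[Sequence[float]],
                   score_thr: float = 0.3) -> Dict[str, Tuple[float, float]]:
    """Map (x, y, score) rows to COCO names, keeping rows with score >= score_thr."""
    if len(keypoints) != len(COCO_KEYPOINTS):
        raise ValueError('expected %d keypoints, got %d'
                         % (len(COCO_KEYPOINTS), len(keypoints)))
    return {name: (float(x), float(y))
            for name, (x, y, s) in zip(COCO_KEYPOINTS, keypoints)
            if s >= score_thr}


if __name__ == '__main__':
    # With the predictor: name_keypoints(results['pose_results'][0]['keypoints'])
    demo = [(float(i), float(2 * i), 0.9 if i % 2 == 0 else 0.1) for i in range(17)]
    print(name_keypoints(demo))
```

Because it only iterates rows and unpacks three values each, the same function accepts either a list of tuples or a NumPy array of shape (17, 3).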