【DSW Gallery】基于Top Down的关键点检测示例

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,100CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以关键点检测为例,为您介绍如何在PAI-DSW中使用EasyCV。

直接使用

请打开基于Top Down的关键点检测示例,并点击右上角 “ 在DSW中打开” 。

image.png

EasyCV关键点检测-TOP-DOWN

  关键点检测任务是计算机视觉任务领域的基础任务之一,包括人脸关键点、人体关键点以及特定物体(如手掌)关键点检测等,在姿态估计、行为识别、人机交互、虚拟现实以及无人驾驶等领域有重要的应用价值。

  本文将介绍如何在pai-dsw基于EasyCV快速进行人体关键点检测模型的训练和推理。

运行环境要求

PAI-Pytorch 1.7/1.8镜像, GPU机型 P100 or V100, 内存 32G

安装依赖包

注:在PAI-DSW docker中无需安装相关依赖,可跳过此1,2步骤, 在本地notebook环境中执行1,2 步骤安装环境

1、获取torch和cuda版本,并根据版本号修改mmcv安装命令,安装对应版本的mmcv和nvidia-dali

import torch
import os
os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')
os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')
!echo $CUDA
!echo $Torch
[2023-02-03 16:42:27,634.634 dsw-16577-67c64db7b-kslkp:5077 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
cu101
torch1.8.2
# install some python deps
! pip install --upgrade tqdm
! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl

2、安装EasyCV算法包 注:在PAI-DSW docker中预安装了pai-easycv库,可跳过该步骤,若训练测试过程中报错,尝试用下方命令更新easycv版本

#pip install pai-easycv
! echo y | pip uninstall pai-easycv easycv
!pip install pai-easycv
from easycv.apis import *

正式开始

数据准备

本案例我们提供了小型关键点检测的数据集,以便你快速跑通,你可以下载链接数据

图片文件夹结构示例如下, 文件夹路径为./pose

pose/
├── images
    ├── 0001.jpg
    ├── 0002.jpg
    ├── 0003.jpg
    |...
└── train_200.json
└── val_20.json

 执行如下命令下载解压

! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/configs/keypoint/pose_coco.tar.gz && tar -xpf pose_coco.tar.gz

训练模型

这个demo中我们采用litehrnet作为主干网络去进行训练

# 下载config文件
! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/configs/keypoint/litehrnet_30_coco_384x288.py 

为了缩短训练时间,打开配置文件 litehrnet_30_coco_384x288.py,修改total_epoch参数为10, 每隔1次迭代打印一次日志。

# runtime settings
total_epochs = 10
# log config
log_config=dict(interval=1)
# 查看easycv安装位置
import easycv
print(easycv.__file__)
/home/pai/lib/python3.6/site-packages/easycv/__init__.py
!python -m easycv.tools.train litehrnet_30_coco_384x288.py --work_dir work_dir/pose/litehrnet_30_coco

模型导出

# 查看训练产生的pt文件
! ls  work_dir/pose/litehrnet_30_coco/*pth
work_dir/pose/litehrnet_30_coco/CoCoPoseTopDownEvaluator_AP_best.pth
work_dir/pose/litehrnet_30_coco/epoch_10.pth
! python -m easycv.tools.export  litehrnet_30_coco_384x288.py work_dir/pose/litehrnet_30_coco/CoCoPoseTopDownEvaluator_AP_best.pth work_dir/pose/litehrnet_30_coco/export_best.pth
[2023-02-03 18:20:24,874.874 dsw-16577-67c64db7b-kslkp:6643 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
pose/litehrnet_30_coco_384x288.py
load checkpoint from local path: work_dir/pose/litehrnet_30_coco/CoCoPoseTopDownEvaluator_AP_best.pth

模型预测

  在预测之前、我们还要下载一个目标检测模型,top-Down是直接从单个人体中进行关键点的预测,所以我们需要先将我们输入的图片中的人体先一个个检测出来,然后再一个个人体进行关键点检测模型进行检测关键点

!wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/configs/pose/epoch_300.pt
Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.
ERROR: could not open HSTS store at '/root/.wget-hsts'. HSTS will be disabled.
--2023-02-03 17:54:09--  https://pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com/gl_pp/epoch_300.pt
Resolving pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com (pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com)... 39.98.20.19
Connecting to pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com (pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com)|39.98.20.19|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 36133977 (34M) [application/octet-stream]
Saving to: ‘epoch_300.pt’
epoch_300.pt        100%[===================>]  34.46M  16.3MB/s    in 2.1s    
2023-02-03 17:54:11 (16.3 MB/s) - ‘epoch_300.pt’ saved [36133977/36133977]

下面预测过程中如果出现以下报错,请手动卸载mmdet,至终端运行 pip uninstall mmdet

KeyError: 'YOLOXLrUpdaterHook is already registered in hook'
from PIL import Image
import numpy as np
from easycv.predictors.pose_predictor import TorchPoseTopDownPredictorWithDetector
# 修改output_ckpt指向
pose_model_path = 'work_dir/pose/litehrnet_30_coco/export_best.pth'
detection_model_path = 'epoch_300.pt'
model_path = ','.join((pose_model_path, detection_model_path))
model_config={
                'pose': {
                    'bbox_thr': 0.3,
                    'format': 'xywh'
                },
                'detection': {
                    'model_type': 'TorchYoloXPredictor'
                }
            }
fe = TorchPoseTopDownPredictorWithDetector(model_path=model_path, model_config=model_config)
input_img = 'small_coco/images/000000012754.jpg'
input_data_list = [np.asarray(Image.open(input_img))]
results = fe.predict(input_data_list)[0]
print(results['pose_results'])
print(results['pose_results'][0]['keypoints'].shape)
/home/pai/lib/python3.6/site-packages/easycv/predictors/pose_predictor.py:437: DeprecationWarning: Call to deprecated class TorchYoloXPredictor (Please use YoloXPredictor.).
  detection_model_path, model_config=model_config['detection'])
reparam: 0
load checkpoint from local path: epoch_300.pt
/home/pai/lib/python3.6/site-packages/easycv/datasets/detection/pipelines/mm_transforms.py:1447: DeprecationWarning: pad_val of float type is deprecated now, please use pad_val=dict(img=(114.0, 114.0, 114.0), masks=(114.0, 114.0, 114.0), seg=255) instead.
  f'masks={pad_val}, seg=255) instead.', DeprecationWarning)
/home/pai/lib/python3.6/site-packages/easycv/datasets/detection/pipelines/mm_transforms.py:1447: DeprecationWarning: pad_val of float type is deprecated now, please use pad_val=dict(img=(114.0, 114.0, 114.0), masks=(114.0, 114.0, 114.0), seg=255) instead.
  f'masks={pad_val}, seg=255) instead.', DeprecationWarning)
相关实践学习
使用PAI+LLaMA Factory微调Qwen2-VL模型,搭建文旅领域知识问答机器人
使用PAI和LLaMA Factory框架,基于全参方法微调 Qwen2-VL模型,使其能够进行文旅领域知识问答,同时通过人工测试验证了微调的效果。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
数据可视化 数据挖掘
【数据分析与可视化】对图像进行SVD分解并重构图像实战(附源码)
【数据分析与可视化】对图像进行SVD分解并重构图像实战(附源码)
457 0
|
机器学习/深度学习 数据采集 编解码
MMPose | 关于自顶向下 2D HPE 算法的,全都在这里啦!
2D Human Pose Estimation (以下简称 2D HPE )旨在从图像或者视频中预测人体关节点(或称关键点,比如头,左手,右脚等)的二维空间位置坐标。2D HPE 的应用场景非常广泛,包括动作识别,动画生成,增强现实等。
2104 0
MMPose | 关于自顶向下 2D HPE 算法的,全都在这里啦!
|
Windows
已解决Win11报错 OSError: [WinError 1455] 页面文件太小,无法完成操作。
Win11报错 OSError: [WinError 1455] 页面文件太小,无法完成操作。 Error loading "D:\aaaa\envs\gs\lib\site-packages\torch\lib\caffe2_detectron_ops_gpu.dll" or one of its dependencies.
7847 0
已解决Win11报错 OSError: [WinError 1455] 页面文件太小,无法完成操作。
|
机器学习/深度学习 Python
训练集、测试集与验证集:机器学习模型评估的基石
在机器学习中,数据集通常被划分为训练集、验证集和测试集,以评估模型性能并调整参数。训练集用于拟合模型,验证集用于调整超参数和防止过拟合,测试集则用于评估最终模型性能。本文详细介绍了这三个集合的作用,并通过代码示例展示了如何进行数据集的划分。合理的划分有助于提升模型的泛化能力。
|
Web App开发 移动开发 小程序
看我如何让手机秒变扫码枪
为解决无扫码枪问题,作者受到微信小程序“超级扫码枪”启发,决定自制手机扫码到电脑的应用。项目需求是手机扫描条形码或二维码后实时传送到电脑。实现步骤包括:电脑端用Java Swing和Robot模拟键盘输入,手机端H5调用摄像头扫码(借助html5-qrcode库),并通过WebSocket服务将结果发送至电脑。项目源码及演示视频链接提供。
2340 5
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
257 12
|
存储 开发框架 .NET
【博士每天一篇文献-综述】A Comprehensive Survey of Continual Learning Theory, Method and Application
本文综述了持续学习的理论基础、方法论和应用实践,探讨了五种主要的解决策略,包括基于回放、架构、表示、优化和正则化的方法,并深入分析了持续学习的不同场景、分类、评价指标以及面临的挑战和解决方案。
506 1
【博士每天一篇文献-综述】A Comprehensive Survey of Continual Learning Theory, Method and Application
|
12月前
|
Ubuntu 网络安全 数据库
使用官方开源项目搭建自有Overleaf服务
【10月更文挑战第6天】本文详细介绍了在服务器上部署 Overleaf 服务的步骤,包括服务器环境准备、域名与 SSL 证书配置、获取官方项目代码、配置与构建服务,以及测试和使用服务等内容。适用于希望自建 Overleaf 服务的用户。建议服务器配置为 Ubuntu 系统,具备至少 10GB 磁盘和 2GB 内存。
770 0
|
存储 缓存 NoSQL
Redis经典问题:BigKey问题
BigKey问题常困扰着Redis用户,其影响不容忽视。本文将深入探讨BigKey问题的本质及解决方案,帮助你优化Redis性能,提升系统稳定性。
763 2
|
Go 开发工具 git
推荐一个开源流媒体服务器-livgo
推荐一个开源流媒体服务器-livgo
550 0

热门文章

最新文章