【DSW Gallery】基于残差网络的度量学习示例

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以度量学习为例,为您介绍如何在PAI-DSW中使用EasyCV。

直接使用

请打开基于基于残差网络的度量学习示例,并点击右上角 “ 在DSW中打开” 。

image.png

EasyCV度量学习

  度量学习又称为相似性学习,有广泛的应用,可在数据集上构建合适的距离度量来建模回答实际问题。在机器学习中distance metric learning(也称 metric learning,度量学习)是一个很典型的任务,通常与很多熟知的 metric-based methods(如 KNN、K-means 等)结合起来使用以实现分类或者聚类 ,效果通常非常不错。

  本文将介绍如何在pai-dsw基于EasyCV快速进行度量模型的训练、推理。

运行环境要求

PAI-Pytorch 1.7/1.8镜像, GPU机型 P100 or V100, 内存 32G

安装依赖包

注:在PAI-DSW docker中无需安装相关依赖,可跳过此1,2步骤, 在本地notebook环境中执行1,2步骤安装环境

1、获取torch和cuda版本,并根据版本号修改mmcv安装命令,安装对应版本的mmcv和nvidia-dali

import torch
import os
os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')
os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')
!echo $CUDA
!echo $Torch
cu101
torch1.8.1+cu101
# install some python deps
! pip install --upgrade tqdm
! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl

2、安装EasyCV算法包 注:在PAI-DSW docker中预安装了pai-easycv库,可跳过该步骤,若训练测试过程中报错,尝试用下方命令更新easycv版本

#pip install pai-easycv
! echo y | pip uninstall pai-easycv easycv
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/2022_6/pai_easycv-0.3.0-py3-none-any.whl

3、简单验证

from easycv.apis import *

CUB200 度量学习

下面示例介绍如何利用CUB200数据,使用ResNet50模型快速进行图像分类模型的训练评估、模型预测过程

数据准备

下载cub200数据,解压到data/cub200目录, 目录结构如下

data/cub200
├── images
├── images.txt
├── image_class_labels.txt
├── train_test_split.txt
! mkdir -p data/ && wget https://s3.amazonaws.com/fast-ai-imageclas/CUB_200_2011.tgz && tar -xzf CUB_200_2011.tgz -C data/ && mv data/CUB_200_2011 data/cub200

训练模型

下载训练配置文件,该配置文件默认从基于ImageNet的预训练模型导入权重,如需自监督预训练模型可以从上文的链接中下载并替换config中的配置。

! rm -rf cub_resnet50_jpg.py
!wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/metric_learning/cub_resnet50_jpg.py

使用单卡gpu进行训练和验证集评估,为了快速跑通,可自行将cub_resnet50_jpg.py中的total_epoch参数设置成10。

!python -m torch.distributed.launch --nproc_per_node=1 --master_port=29500 /home/pai/lib/python3.6/site-packages/easycv/tools/train.py cub_resnet50_jpg.py --work_dir work_dirs/metric_learning/cub/r50 --launcher pytorch --fp16

导出模型

模型训练完成,使用export命令导出模型进行推理,导出的模型包含推理时所需的预处理信息、后处理信息

# 查看训练产生的pt文件
! ls  work_dirs/metric_learning/cub/r50*

RetrivalTopKEvaluator_R@K=1_best.pth 是训练过程中产生的acc最高的pth,导出该模型

! python -m easycv.tools.export ./cub_resnet50_jpg.py work_dirs/metric_learning/cub/r50/RetrivalTopKEvaluator_R@K=1_best.pth  work_dirs/metric_learning/cub/r50/best_export.pth

预测

下载测试图片

! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/fine_grain_cls/cub_raw/images/001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg

导入模型权重,并预测测试图片的分类结果

import cv2
from easycv.predictors.feature_extractor import TorchFeatureExtractor
output_ckpt = 'work_dirs/metric_learning/cub/r50/best_export.pth'
tcls = TorchFeatureExtractor(output_ckpt)
img = cv2.imread('Black_Footed_Albatross_0001_796111.jpg')
# input image should be RGB order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
output = tcls.predict([img])
相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
1月前
|
网络协议 网络虚拟化 数据中心
华为配置VXLAN构建虚拟网络实现相同网段互通示例(静态方式)
配置VXLAN构建虚拟网络实现相同网段互通示例(静态方式
|
1月前
|
消息中间件 网络协议 C++
C/C++网络编程基础知识超详细讲解第三部分(系统性学习day13)
C/C++网络编程基础知识超详细讲解第三部分(系统性学习day13)
|
1月前
|
监控 网络协议 Java
Linux 网络编程从入门到进阶 学习指南
在上一篇文章中,我们探讨了 Linux 系统编程的诸多基础构件,包括文件操作、进程管理和线程同步等,接下来,我们将视野扩展到网络世界。在这个新篇章里,我们要让应用跳出单机限制,学会在网络上跨机器交流信息。
Linux 网络编程从入门到进阶 学习指南
|
3月前
|
SQL 运维 安全
黑客(网络安全)技术自学——高效学习
黑客(网络安全)技术自学——高效学习
28 1
|
3月前
|
开发框架 网络协议 .NET
【网络奇缘】- 计算机网络|分层结构|深入学习ISO模型
【网络奇缘】- 计算机网络|分层结构|深入学习ISO模型
48 0
|
3月前
|
网络协议 Go API
Go语言学习-网络基础
Go语言学习-网络基础
32 0
|
3月前
|
网络协议
网络编程【TCP单向通信、TCP双向通信、一对多应用、一对多聊天服务器】(二)-全面详解(学习总结---从入门到深化)(下)
网络编程【TCP单向通信、TCP双向通信、一对多应用、一对多聊天服务器】(二)-全面详解(学习总结---从入门到深化)
35 2
|
3月前
|
网络协议 Linux 数据处理
网络编程【网络编程基本概念、 网络通信协议、IP地址 、 TCP协议和UDP协议】(一)-全面详解(学习总结---从入门到深化)
网络编程【网络编程基本概念、 网络通信协议、IP地址 、 TCP协议和UDP协议】(一)-全面详解(学习总结---从入门到深化)
82 3
|
2月前
|
机器学习/深度学习 测试技术 Ruby
YOLOv5改进 | 主干篇 | 反向残差块网络EMO一种轻量级的CNN架构(附完整代码 + 修改教程)
YOLOv5改进 | 主干篇 | 反向残差块网络EMO一种轻量级的CNN架构(附完整代码 + 修改教程)
129 2
|
2天前
|
存储 网络协议 关系型数据库
Python从入门到精通:2.3.2数据库操作与网络编程——学习socket编程,实现简单的TCP/UDP通信
Python从入门到精通:2.3.2数据库操作与网络编程——学习socket编程,实现简单的TCP/UDP通信

热门文章

最新文章