【DSW Gallery】基于残差网络的度量学习示例

本文涉及的产品
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以度量学习为例,为您介绍如何在PAI-DSW中使用EasyCV。

直接使用

请打开基于基于残差网络的度量学习示例,并点击右上角 “ 在DSW中打开” 。

image.png

EasyCV度量学习

  度量学习又称为相似性学习,有广泛的应用,可在数据集上构建合适的距离度量来建模回答实际问题。在机器学习中distance metric learning(也称 metric learning,度量学习)是一个很典型的任务,通常与很多熟知的 metric-based methods(如 KNN、K-means 等)结合起来使用以实现分类或者聚类 ,效果通常非常不错。

  本文将介绍如何在pai-dsw基于EasyCV快速进行度量模型的训练、推理。

运行环境要求

PAI-Pytorch 1.7/1.8镜像, GPU机型 P100 or V100, 内存 32G

安装依赖包

注:在PAI-DSW docker中无需安装相关依赖,可跳过此1,2步骤, 在本地notebook环境中执行1,2步骤安装环境

1、获取torch和cuda版本,并根据版本号修改mmcv安装命令,安装对应版本的mmcv和nvidia-dali

import torch
import os
os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')
os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')
!echo $CUDA
!echo $Torch
cu101
torch1.8.1+cu101
# install some python deps
! pip install --upgrade tqdm
! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl

2、安装EasyCV算法包 注:在PAI-DSW docker中预安装了pai-easycv库,可跳过该步骤,若训练测试过程中报错,尝试用下方命令更新easycv版本

#pip install pai-easycv
! echo y | pip uninstall pai-easycv easycv
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/2022_6/pai_easycv-0.3.0-py3-none-any.whl

3、简单验证

from easycv.apis import *

CUB200 度量学习

下面示例介绍如何利用CUB200数据,使用ResNet50模型快速进行图像分类模型的训练评估、模型预测过程

数据准备

下载cub200数据,解压到data/cub200目录, 目录结构如下

data/cub200
├── images
├── images.txt
├── image_class_labels.txt
├── train_test_split.txt
! mkdir -p data/ && wget https://s3.amazonaws.com/fast-ai-imageclas/CUB_200_2011.tgz && tar -xzf CUB_200_2011.tgz -C data/ && mv data/CUB_200_2011 data/cub200

训练模型

下载训练配置文件,该配置文件默认从基于ImageNet的预训练模型导入权重,如需自监督预训练模型可以从上文的链接中下载并替换config中的配置。

! rm -rf cub_resnet50_jpg.py
!wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/metric_learning/cub_resnet50_jpg.py

使用单卡gpu进行训练和验证集评估,为了快速跑通,可自行将cub_resnet50_jpg.py中的total_epoch参数设置成10。

!python -m torch.distributed.launch --nproc_per_node=1 --master_port=29500 /home/pai/lib/python3.6/site-packages/easycv/tools/train.py cub_resnet50_jpg.py --work_dir work_dirs/metric_learning/cub/r50 --launcher pytorch --fp16

导出模型

模型训练完成,使用export命令导出模型进行推理,导出的模型包含推理时所需的预处理信息、后处理信息

# 查看训练产生的pt文件
! ls  work_dirs/metric_learning/cub/r50*

RetrivalTopKEvaluator_R@K=1_best.pth 是训练过程中产生的acc最高的pth,导出该模型

! python -m easycv.tools.export ./cub_resnet50_jpg.py work_dirs/metric_learning/cub/r50/RetrivalTopKEvaluator_R@K=1_best.pth  work_dirs/metric_learning/cub/r50/best_export.pth

预测

下载测试图片

! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/fine_grain_cls/cub_raw/images/001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg

导入模型权重,并预测测试图片的分类结果

import cv2
from easycv.predictors.feature_extractor import TorchFeatureExtractor
output_ckpt = 'work_dirs/metric_learning/cub/r50/best_export.pth'
tcls = TorchFeatureExtractor(output_ckpt)
img = cv2.imread('Black_Footed_Albatross_0001_796111.jpg')
# input image should be RGB order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
output = tcls.predict([img])
相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
26天前
|
网络安全 Python
Python网络编程小示例:生成CIDR表示的IP地址范围
本文介绍了如何使用Python生成CIDR表示的IP地址范围,通过解析CIDR字符串,将其转换为二进制形式,应用子网掩码,最终生成该CIDR块内所有可用的IP地址列表。示例代码利用了Python的`ipaddress`模块,展示了从指定CIDR表达式中提取所有IP地址的过程。
40 6
|
3月前
|
监控 网络协议 Linux
网络学习
网络学习
155 68
|
1月前
|
编解码 安全 Linux
网络空间安全之一个WH的超前沿全栈技术深入学习之路(10-2):保姆级别教会你如何搭建白帽黑客渗透测试系统环境Kali——Liinux-Debian:就怕你学成黑客啦!)作者——LJS
保姆级别教会你如何搭建白帽黑客渗透测试系统环境Kali以及常见的报错及对应解决方案、常用Kali功能简便化以及详解如何具体实现
|
1月前
|
安全 网络协议 算法
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-1):主动信息收集之ping、Nmap 就怕你学成黑客啦!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-1):主动信息收集之ping、Nmap 就怕你学成黑客啦!
|
1月前
|
网络协议 安全 NoSQL
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-2):scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练、就怕你学成黑客啦!
scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练等具体操作详解步骤;精典图示举例说明、注意点及常见报错问题所对应的解决方法IKUN和I原们你这要是学不会我直接退出江湖;好吧!!!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-2):scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练、就怕你学成黑客啦!
|
1月前
|
网络协议 安全 算法
网络空间安全之一个WH的超前沿全栈技术深入学习之路(9):WireShark 简介和抓包原理及实战过程一条龙全线分析——就怕你学成黑客啦!
实战:WireShark 抓包及快速定位数据包技巧、使用 WireShark 对常用协议抓包并分析原理 、WireShark 抓包解决服务器被黑上不了网等具体操作详解步骤;精典图示举例说明、注意点及常见报错问题所对应的解决方法IKUN和I原们你这要是学不会我直接退出江湖;好吧!!!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(9):WireShark 简介和抓包原理及实战过程一条龙全线分析——就怕你学成黑客啦!
|
2月前
|
存储 安全 网络安全
浅谈网络安全的认识与学习规划
浅谈网络安全的认识与学习规划
47 6
|
1月前
|
人工智能 安全 Linux
网络空间安全之一个WH的超前沿全栈技术深入学习之路(4-2):渗透测试行业术语扫盲完结:就怕你学成黑客啦!)作者——LJS
网络空间安全之一个WH的超前沿全栈技术深入学习之路(4-2):渗透测试行业术语扫盲完结:就怕你学成黑客啦!)作者——LJS
|
1月前
|
安全 大数据 Linux
网络空间安全之一个WH的超前沿全栈技术深入学习之路(3-2):渗透测试行业术语扫盲)作者——LJS
网络空间安全之一个WH的超前沿全栈技术深入学习之路(3-2):渗透测试行业术语扫盲)作者——LJS

热门文章

最新文章