【DSW Gallery】基于MOCOV2的自监督学习示例

简介: EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以自监督学习-MOCO为例,为您介绍如何在PAI-DSW中使用EasyCV。

直接使用

请打开基于MOCOV2的自监督学习示例,并点击右上角 “ 在DSW中打开” 。

image.png

EasyCV自监督训练-MOCOv2

  自监督学习(Self-Supervised Learning)能利用大量无标注的数据进行表征学习,然后在特定下游任务上对参数进行微调。通过这样的方式,能够在较少有标注数据上取得优于有监督学习方法的精度。

  近年来,自监督学习受到了越来越多的关注,如Yann Lecun也在 AAAI 上讲 Self-Supervised Learning 是未来的大势所趋。在CV领域涌现了如SwAV、MOCO、DINO、MoBY等一系列工作。

  本文将介绍如何在pai-dsw基于EasyCV快速使用MOCOV2进行度量模型的训练、微调、推理。

运行环境要求

PAI-Pytorch 1.7/1.8镜像, GPU机型 P100 or V100, 内存 32G

安装依赖包

注:在PAI-DSW docker中无需安装相关依赖,可跳过此1,2步骤, 在本地notebook环境中执行1,2 步骤安装环境

1、获取torch和cuda版本,并根据版本号修改mmcv安装命令,安装对应版本的mmcv和nvidia-dali

import torch
import os
os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')
os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')
!echo $CUDA
!echo $Torch
[2023-02-03 11:17:50,705.705 dsw-16577-7f6b8db66d-qvgjd:382 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
cu101
torch1.8.2
# install some python deps
! pip install --upgrade tqdm
! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl

2、安装EasyCV算法包 注:在PAI-DSW docker中预安装了pai-easycv库,可跳过该步骤,若训练测试过程中报错,尝试用下方命令更新easycv版本

#pip install pai-easycv
! echo y | pip uninstall pai-easycv easycv
!pip install pai-easycv

3、简单验证

from easycv.apis import *

正式工作

数据准备

  自监督训练只需要提供无标注图片即可进行, 你可以下载ImageNet数据,或者使用你自己的图片数据。需要提供一个包含若干图片的文件夹路径p,以及一个文件列表,文件列表中是每个图片相对图片目录p的路径.

  图片文件夹结构示例如下, 文件夹路径为./images

images/
├── 0001.jpg
├── 0002.jpg
├── 0003.jpg
|...
└── 9999.jpg
文件列表内容示例如下
0001.jpg
0002.jpg
0003.jpg
...
9999.jpg

 为了快速走通流程,我们也提供了一个小的示例数据集,执行如下命令下载解压

! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/imagenet_raw_demo/imagenet_raw_demo.tar.gz && tar -zxf imagenet_raw_demo.tar.gz
Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.
ERROR: could not open HSTS store at '/root/.wget-hsts'. HSTS will be disabled.
--2023-02-03 11:20:43--  http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/imagenet_raw_demo/imagenet_raw_demo.tar.gz
Resolving pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com (pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com)... 39.98.20.13
Connecting to pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com (pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com)|39.98.20.13|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 33109699 (32M) [application/x-gzip]
Saving to: ‘imagenet_raw_demo.tar.gz’
imagenet_raw_demo.t 100%[===================>]  31.58M  17.5MB/s    in 1.8s    
2023-02-03 11:20:45 (17.5 MB/s) - ‘imagenet_raw_demo.tar.gz’ saved [33109699/33109699]
# 重命名文件夹
! mv imagenet_raw_demo  imagenet_raw

训练模型

这个Demo中我们采用mocov2自监督算法训练ResNet50 主干网络, 下载示例配置文件

! rm -rf mocov2_rn50_8xb32_200e_jpg.py
! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/doc/easycv/configs/selfsup/mocov2/mocov2_rn50_8xb32_200e_jpg.py
Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.
ERROR: could not open HSTS store at '/root/.wget-hsts'. HSTS will be disabled.
--2023-02-03 11:39:05--  http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/doc/easycv/configs/selfsup/mocov2/mocov2_rn50_8xb32_200e_jpg.py
Resolving pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com (pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com)... 39.98.20.13
Connecting to pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com (pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com)|39.98.20.13|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2182 (2.1K) [text/x-python]
Saving to: ‘mocov2_rn50_8xb32_200e_jpg.py’
mocov2_rn50_8xb32_2 100%[===================>]   2.13K  --.-KB/s    in 0s      
2023-02-03 11:39:05 (223 MB/s) - ‘mocov2_rn50_8xb32_200e_jpg.py’ saved [2182/2182]

为了缩短训练时间,打开配置文件 mae_vit_base_patch16_8xb64_1600e.py,修改total_epoch参数为10, 每隔1次迭代打印一次日志。

# runtime settings
total_epochs = 10
# log config
log_config=dict(interval=1)

正式训练时,建议使用单机8卡配合该配置文件使用,如果要使用单机单卡,建议调小optimizer.lr初始学习率

# 查看easycv安装位置
import easycv
print(easycv.__file__)
/home/pai/lib/python3.6/site-packages/easycv/__init__.py
!python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \
/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mocov2_rn50_8xb32_200e_jpg.py --work_dir work_dir/selfsup/jpg/rn50_mocov2 --launcher pytorch

使用自监督模型进行特征抽取

模型导出

模型导出会对自监督模型信息裁剪,保留特征抽取必要的backbone和head

# 查看训练产生的pt文件
! ls  work_dir/selfsup/jpg/rn50_mocov2/*.pth
work_dir/selfsup/jpg/rn50_mocov2/epoch_10.pth
work_dir/selfsup/jpg/rn50_mocov2/epoch_2.pth
work_dir/selfsup/jpg/rn50_mocov2/epoch_4.pth
work_dir/selfsup/jpg/rn50_mocov2/epoch_6.pth
work_dir/selfsup/jpg/rn50_mocov2/epoch_8.pth
! python -m easycv.tools.export  mocov2_rn50_8xb32_200e_jpg.py work_dir/selfsup/jpg/rn50_mocov2/epoch_10.pth work_dir/selfsup/jpg/rn50_mocov2/export.pth
[2023-02-03 11:30:39,553.553 dsw-16577-7f6b8db66d-qvgjd:1615 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
mocov2_rn50_8xb32_200e_jpg.py
load checkpoint from local path: work_dir/selfsup/jpg/rn50_mocov2/epoch_10.pth
#下载测试图片
! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/product_detection/248347732153_1040.jpg
Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.
ERROR: could not open HSTS store at '/root/.wget-hsts'. HSTS will be disabled.
--2023-02-03 11:30:53--  http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/product_detection/248347732153_1040.jpg
Resolving pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com (pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com)... 39.98.20.13
Connecting to pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com (pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com)|39.98.20.13|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 122788 (120K) [image/jpeg]
Saving to: ‘248347732153_1040.jpg’
248347732153_1040.j 100%[===================>] 119.91K  --.-KB/s    in 0.1s    
2023-02-03 11:30:53 (1.06 MB/s) - ‘248347732153_1040.jpg’ saved [122788/122788]
import cv2
from easycv.predictors.feature_extractor import TorchFeatureExtractor
# 修改output_ckpt指向
output_ckpt = 'work_dir/selfsup/jpg/rn50_mocov2/export.pth'
fe = TorchFeatureExtractor(output_ckpt)
img = cv2.imread('248347732153_1040.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
feature = fe.predict([img])
print(feature[0]['feature'].shape)
print(feature[0])
load model from init weights
load checkpoint from local path: work_dir/selfsup/jpg/rn50_mocov2/export.pth
The model and loaded state dict do not match exactly
missing keys in source state_dict: head_0.fc_cls.weight, head_0.fc_cls.bias
(2048,)
{'feature': array([0.60371155, 1.058624  , 0.8255347 , ..., 0.3980614 , 0.8326854 ,
       0.6025144 ], dtype=float32)}

自监督预训练+ 图像分类finetune

参考EasyCV图像分类的demo, 在训练时加上--load_from 参数,使用自监督预训练的模型权重, 注意这里不需要使用

! python -m easycv.tools.train  r50.py --work_dir work_dirs/classification/cifar10/r50  --load_from work_dir/selfsup/jpg/rn50_mocov2/epoch_10.pth


相关实践学习
使用PAI+LLaMA Factory微调Qwen2-VL模型,搭建文旅领域知识问答机器人
使用PAI和LLaMA Factory框架,基于全参方法微调 Qwen2-VL模型,使其能够进行文旅领域知识问答,同时通过人工测试验证了微调的效果。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
机器学习/深度学习 运维
Moment:又一个开源的时间序列基础模型
MOMENT团队推出Time-series Pile,一个大型公共时间序列数据集,用于预训练首个开源时间序列模型家族。模型基于Transformer,采用遮蔽预训练技术,适用于预测、分类、异常检测和输入任务。研究发现,随机初始化比使用语言模型权重更有效,且直接预训练的模型表现出色。MOMENT改进了Transformer架构,调整了Layer norm并引入关系位置嵌入。模型在长期预测和异常检测中表现优异,但对于数值预测的效果尚不明朗。论文贡献包括开源方法、数据集创建和资源有限情况下的性能评估框架。
1203 0
|
Ubuntu 开发工具 计算机视觉
RK3588 RGA 图像操作
RK3588 RGA 图像操作
|
JSON 网络协议 Go
golang使用resty库实现模拟请求正方教务
本文主要讲解了如何使用golang模拟请求正方教务
806 0
|
6月前
|
缓存 NoSQL Java
一些高频面试题
这篇文章整理了一些高频面试题
216 0
|
6月前
|
机器学习/深度学习 资源调度 算法
【图像去噪的滤波器】非局部均值滤波器的实现,用于鲁棒的图像去噪研究(Matlab代码实现)
【图像去噪的滤波器】非局部均值滤波器的实现,用于鲁棒的图像去噪研究(Matlab代码实现)
230 2
|
容器
echarts的grid——图表的位置配置
echarts的grid——图表的位置配置
2695 1
|
机器学习/深度学习 数据采集 人工智能
从零构建:深度学习模型的新手指南###
【10月更文挑战第21天】 本文将深入浅出地解析深度学习的核心概念,为初学者提供一条清晰的学习路径,涵盖从理论基础到实践应用的全过程。通过比喻和实例,让复杂概念变得易于理解,旨在帮助读者搭建起深度学习的知识框架,为进一步探索人工智能领域奠定坚实基础。 ###
340 3
|
机器学习/深度学习 数据可视化 算法
基于深度学习的瓶子检测软件(UI界面+YOLOv5+训练数据集)
基于深度学习的瓶子检测软件(UI界面+YOLOv5+训练数据集)
903 0
|
小程序
微信小程序开发---购物商城系统。【详细业务需求描述+实现效果】
这篇文章详细介绍了作者开发的微信小程序购物商城系统,包括功能列表、项目结构、具体页面展示和部分源码,涵盖了从首页、商品分类、商品列表、商品详情、购物车、支付、订单查询、个人中心到商品收藏和意见反馈等多个页面的实现效果和业务需求描述。
微信小程序开发---购物商城系统。【详细业务需求描述+实现效果】
|
数据采集 API 索引
异步任务处理系统问题之异步任务处理系统的问题如何解决
异步任务处理系统问题之异步任务处理系统的问题如何解决
385 2

热门文章

最新文章