【DSW Gallery】Image Classification Example Based on Residual Networks

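Summary: EasyCV is an all-in-one computer vision modeling toolkit built on PyTorch, with self-supervised learning and Transformer techniques at its core. It includes state-of-the-art (SOTA) algorithms for vision tasks such as image classification, metric learning, object detection, and pose estimation. Using image classification as an example, this article shows how to use EasyCV in PAI-DSW.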

Quick start

Open the "Image Classification Example Based on Residual Networks" sample and click "Open in DSW" in the upper-right corner.



EasyCV Image Classification

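The residual network (ResNet) is a model architecture proposed by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun of Microsoft Research. It won first place in the 2015 ILSVRC (ImageNet Large Scale Visual Recognition Challenge) and is now one of the most widely used networks in computer vision. In both industry and academia it serves as the backbone network for downstream tasks such as classification, segmentation, and detection.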

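EasyCV provides pretrained models for several ResNet variants that can be fine-tuned on downstream tasks. These include models pretrained on the ImageNet dataset (Link) and models pretrained with self-supervised algorithms such as MoCo and SwAV (Link).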

This article shows how to use EasyCV in PAI-DSW to quickly train a ResNet50 image classification model and run inference with it.

Environment requirements

PAI-PyTorch 1.7/1.8 image, P100 or V100 GPU instance, 32 GB of memory

Install dependencies

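1. Get the torch and CUDA versions. Adjust the mmcv install command to match those versions, then install the corresponding versions of mmcv and nvidia-dali.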

import torch
import os
os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')
os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')
!echo $CUDA
!echo $Torch
cu101
torch1.8.1+cu101
# install some python deps
! pip install --upgrade tqdm
! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
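The mmcv-full command above hardcodes cu101 and torch1.8.0. The sketch below shows one way to derive the find-links URL from the versions printed by the previous cell. It assumes the URL pattern used in that command. It also assumes that OpenMMLab publishes wheels per PyTorch minor version with the patch number set to 0: the environment above runs torch 1.8.1 but installs from the torch1.8.0 index. mmcv_index_url is a hypothetical helper and is not part of EasyCV or mmcv.

```python
def mmcv_index_url(torch_version: str, cuda_version: str) -> str:
    """Build the mmcv-full find-links URL (hypothetical helper).

    Assumes the pattern from the install command above:
    https://download.openmmlab.com/mmcv/dist/<cuda tag>/<torch tag>/index.html
    where the torch tag pins the patch version to 0 (1.8.1 -> torch1.8.0).
    """
    # Drop local build suffixes such as "+cu101" or "+PAI".
    base = torch_version.split("+")[0]
    major, minor = base.split(".")[:2]
    torch_tag = f"torch{major}.{minor}.0"
    # "10.1" -> "cu101", matching the $CUDA variable set above.
    cuda_tag = "cu" + cuda_version.replace(".", "")
    return f"https://download.openmmlab.com/mmcv/dist/{cuda_tag}/{torch_tag}/index.html"


if __name__ == "__main__":
    # Versions reported by the cell above.
    print(mmcv_index_url("1.8.1+cu101", "10.1"))
```

In the notebook, you could call mmcv_index_url(torch.__version__, torch.version.cuda) and pass the result to pip's -f option. Check that the resulting index actually exists before relying on it.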

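2. Install the EasyCV package. Note: the pai-easycv library is preinstalled in the PAI-DSW Docker image, so you can skip this step. If errors occur during training or testing, try updating EasyCV with the commands below.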

#pip install pai-easycv
! echo y | pip uninstall pai-easycv easycv
! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/2022_6/pai_easycv-0.3.0-py3-none-any.whl

Cifar10 classification

The following example uses the CIFAR-10 data to quickly train and evaluate a ResNet50 image classification model and then run predictions with it.

Data preparation

Download the CIFAR-10 data and extract it to the data/cifar directory. The directory structure looks like this:

data/cifar
└── cifar-10-batches-py
    ├── batches.meta
    ├── data_batch_1
    ├── data_batch_2
    ├── data_batch_3
    ├── data_batch_4
    ├── data_batch_5
    ├── readme.html
    ├── read.py
    └── test_batch
! mkdir -p data/cifar && wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/cifar-10-python.tar.gz &&  tar -zxf cifar-10-python.tar.gz -C data/cifar/
--2022-06-28 10:10:19--  http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/cifar-10-python.tar.gz
Resolving pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com... 39.98.20.13
Connecting to pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com|39.98.20.13|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170498071 (163M) [application/gzip]
Saving to: ‘cifar-10-python.tar.gz’
cifar-10-python.tar 100%[===================>] 162.60M  11.8MB/s    in 13s     
2022-06-28 10:10:32 (12.7 MB/s) - ‘cifar-10-python.tar.gz’ saved [170498071/170498071]
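Before training, you can confirm that the extracted files match the tree above. The following is a minimal sketch. It checks only batches.meta and the pickled batch files, on the assumption that those are what the CIFAR-10 data source reads. readme.html and read.py are ignored. missing_cifar_files is a hypothetical helper.

```python
from pathlib import Path

# Files expected under data/cifar/cifar-10-batches-py (see the tree above).
REQUIRED_FILES = ["batches.meta", "test_batch"] + [f"data_batch_{i}" for i in range(1, 6)]


def missing_cifar_files(root: str = "data/cifar") -> list:
    """Return the names of required CIFAR-10 files that are missing under root."""
    batch_dir = Path(root) / "cifar-10-batches-py"
    return [name for name in REQUIRED_FILES if not (batch_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_cifar_files()
    print("OK" if not missing else f"missing: {missing}")
```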

Train the model

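Download the training config file. By default, this config initializes the model from weights pretrained on ImageNet. To use a self-supervised pretrained model instead, download it from the links above and replace the corresponding setting in the config.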

! rm -rf r50_b128_10e_jpg.py
!wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/configs/classification/cifar10/r50_b128_10e_jpg.py
--2022-06-28 10:10:37--  http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/configs/classification/cifar10/r50_b128_10e_jpg.py
Resolving pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com... 39.98.20.13
Connecting to pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com|39.98.20.13|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2185 (2.1K) [text/x-python]
Saving to: ‘r50_b128_10e_jpg.py’
r50_b128_10e_jpg.py 100%[===================>]   2.13K  --.-KB/s    in 0s      
2022-06-28 10:10:37 (296 MB/s) - ‘r50_b128_10e_jpg.py’ saved [2185/2185]
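To switch to a self-supervised checkpoint as described above, you can edit the pretrained= entry in r50_b128_10e_jpg.py by hand or patch it programmatically. The sketch below assumes that the value is a single quoted string. In the downloaded config, that string sits on the line after pretrained=. set_pretrained and patch_config_file are hypothetical helpers, and you choose the checkpoint path or URL to pass in. EasyCV's log messages later in this example show checkpoints loading from both http paths and local paths, but verify this behavior against your EasyCV version.

```python
import re

# `pretrained=` followed by optional whitespace/newlines and a single- or double-quoted string.
_PRETRAINED_RE = re.compile(r"(pretrained=\s*)(['\"])([^'\"]*)\2")


def set_pretrained(config_text: str, new_path: str) -> str:
    """Return config_text with the first quoted pretrained= value replaced by new_path."""
    new_text, count = _PRETRAINED_RE.subn(
        lambda m: m.group(1) + m.group(2) + new_path + m.group(2),
        config_text,
        count=1,
    )
    if count == 0:
        raise ValueError("no quoted pretrained= entry found in config")
    return new_text


def patch_config_file(path: str, new_path: str) -> None:
    """Rewrite the pretrained= entry of the config file at `path` in place."""
    with open(path) as f:
        text = f.read()
    with open(path, "w") as f:
        f.write(set_pretrained(text, new_path))
```

For example, run patch_config_file("r50_b128_10e_jpg.py", "<path or URL of your checkpoint>") before training. Afterward, check the "do not match exactly" warnings in the training log, because key names can differ between checkpoints.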

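Train on a single GPU and evaluate on the validation set. So that the example finishes quickly, the number of epochs is set to 10 by default.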
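The iteration count per epoch in the training log below ([100/391]) follows from the config. The CIFAR-10 training split has 50,000 images, and imgs_per_gpu=128 on one GPU. A quick check, assuming the final partial batch is kept (the 391 in the log suggests it is):

```python
import math


def iters_per_epoch(num_images: int, imgs_per_gpu: int, num_gpus: int = 1) -> int:
    """Iterations per epoch when the last, partial batch is not dropped."""
    return math.ceil(num_images / (imgs_per_gpu * num_gpus))


if __name__ == "__main__":
    per_epoch = iters_per_epoch(50000, 128)  # CIFAR-10 train split, imgs_per_gpu=128
    total = per_epoch * 10  # total_epochs=10
    print(per_epoch, total)  # 391 3910
```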

! python -m easycv.tools.train  r50_b128_10e_jpg.py --work_dir work_dirs/classification/cifar10/r50
Read base config from /home/pai/lib/python3.6/site-packages/easycv/configs/base.py
/home/pai/lib/python3.6/site-packages/easycv/utils/setup_env.py:37: UserWarning: Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
  f'Setting OMP_NUM_THREADS environment variable for each process '
/home/pai/lib/python3.6/site-packages/easycv/utils/setup_env.py:47: UserWarning: Setting MKL_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
  f'Setting MKL_NUM_THREADS environment variable for each process '
2022-06-28 10:10:52,525 - easycv - INFO - Environment info:
------------------------------------------------------------
sys.platform: linux
Python: 3.6.15 | packaged by conda-forge | (default, Dec  3 2021, 18:49:41) [GCC 9.4.0]
CUDA available: True
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 10.1, V10.1.243
GPU 0: Tesla V100-SXM2-16GB
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.8.1+cu101
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.7.0 (Git Hash 7aed236906b1f7a05c0917e5257a1af05e9ff683)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=10.1, CUDNN_VERSION=7.6.3, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.8.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 
TorchVision: 0.9.1+cu101
OpenCV: 4.6.0
MMCV: 1.4.4
EasyCV: 0.3.0
------------------------------------------------------------
2022-06-28 10:10:52,525 - easycv - INFO - Distributed training: False
2022-06-28 10:10:52,525 - easycv - INFO - Config:
/home/pai/lib/python3.6/site-packages/easycv/configs/base.py
train_cfg = {}
test_cfg = {}
optimizer_config = dict()  # grad_clip, coalesce, bucket_size_mb
# yapf:disable
log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
dist_params = dict(backend='nccl')
cudnn_benchmark = False
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
/mnt/workspace/r50_b128_10e_jpg.py
_base_ = '../../base.py'
# model settings
model = dict(
    type='Classification',
    pretrained=
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth',
    backbone=dict(
        type='ResNet',
        depth=50,
        out_indices=[4],  # 4: stage-4
        norm_cfg=dict(type='BN')),
    head=dict(
        type='ClsHead', with_avg_pool=True, in_channels=2048, num_classes=10))
# dataset settings
class_list = [
    'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
    'ship', 'truck'
]
data_source_cfg = dict(type='ClsSourceCifar10', root='data/cifar/')
dataset_type = 'ClsDataset'
img_norm_cfg = dict(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.201])
train_pipeline = [
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomHorizontalFlip'),
    dict(type='ToTensor'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Collect', keys=['img', 'gt_labels'])
]
test_pipeline = [
    dict(type='ToTensor'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Collect', keys=['img', 'gt_labels'])
]
data = dict(
    imgs_per_gpu=128,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_source=dict(split='train', **data_source_cfg),
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_source=dict(split='test', **data_source_cfg),
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_source=dict(split='test', **data_source_cfg),
        pipeline=test_pipeline))
# additional hooks
eval_config = dict(initial=True, interval=1, gpu_collect=True)
eval_pipelines = [
    dict(
        mode='test',
        data=data['val'],
        dist_eval=True,
        evaluators=[dict(type='ClsEvaluator', topk=(1, 5))],
    )
]
custom_hooks = []
# optimizer
optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0005)
# learning policy
lr_config = dict(policy='step', step=[150, 250])
checkpoint_config = dict(interval=1)
# runtime settings
total_epochs = 10
# log setting
log_config = dict(interval=100)
# export config
export = dict(export_neck=True)
2022-06-28 10:10:52,526 - easycv - INFO - Config Dict:
{"train_cfg": {}, "test_cfg": {}, "optimizer_config": {}, "log_config": {"interval": 100, "hooks": [{"type": "TextLoggerHook"}]}, "dist_params": {"backend": "nccl"}, "cudnn_benchmark": false, "log_level": "INFO", "load_from": null, "resume_from": null, "workflow": [["train", 1]], "model": {"type": "Classification", "pretrained": "http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth", "backbone": {"type": "ResNet", "depth": 50, "out_indices": [4], "norm_cfg": {"type": "BN"}}, "head": {"type": "ClsHead", "with_avg_pool": true, "in_channels": 2048, "num_classes": 10}}, "class_list": ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"], "data_source_cfg": {"type": "ClsSourceCifar10", "root": "data/cifar/"}, "dataset_type": "ClsDataset", "img_norm_cfg": {"mean": [0.4914, 0.4822, 0.4465], "std": [0.2023, 0.1994, 0.201]}, "train_pipeline": [{"type": "RandomCrop", "size": 32, "padding": 4}, {"type": "RandomHorizontalFlip"}, {"type": "ToTensor"}, {"type": "Normalize", "mean": [0.4914, 0.4822, 0.4465], "std": [0.2023, 0.1994, 0.201]}, {"type": "Collect", "keys": ["img", "gt_labels"]}], "test_pipeline": [{"type": "ToTensor"}, {"type": "Normalize", "mean": [0.4914, 0.4822, 0.4465], "std": [0.2023, 0.1994, 0.201]}, {"type": "Collect", "keys": ["img", "gt_labels"]}], "data": {"imgs_per_gpu": 128, "workers_per_gpu": 2, "train": {"type": "ClsDataset", "data_source": {"split": "train", "type": "ClsSourceCifar10", "root": "data/cifar/"}, "pipeline": [{"type": "RandomCrop", "size": 32, "padding": 4}, {"type": "RandomHorizontalFlip"}, {"type": "ToTensor"}, {"type": "Normalize", "mean": [0.4914, 0.4822, 0.4465], "std": [0.2023, 0.1994, 0.201]}, {"type": "Collect", "keys": ["img", "gt_labels"]}]}, "val": {"type": "ClsDataset", "data_source": {"split": "test", "type": "ClsSourceCifar10", "root": "data/cifar/"}, "pipeline": [{"type": "ToTensor"}, {"type": "Normalize", "mean": [0.4914, 
0.4822, 0.4465], "std": [0.2023, 0.1994, 0.201]}, {"type": "Collect", "keys": ["img", "gt_labels"]}]}, "test": {"type": "ClsDataset", "data_source": {"split": "test", "type": "ClsSourceCifar10", "root": "data/cifar/"}, "pipeline": [{"type": "ToTensor"}, {"type": "Normalize", "mean": [0.4914, 0.4822, 0.4465], "std": [0.2023, 0.1994, 0.201]}, {"type": "Collect", "keys": ["img", "gt_labels"]}]}}, "eval_config": {"initial": true, "interval": 1, "gpu_collect": true}, "eval_pipelines": [{"mode": "test", "data": {"type": "ClsDataset", "data_source": {"split": "test", "type": "ClsSourceCifar10", "root": "data/cifar/"}, "pipeline": [{"type": "ToTensor"}, {"type": "Normalize", "mean": [0.4914, 0.4822, 0.4465], "std": [0.2023, 0.1994, 0.201]}, {"type": "Collect", "keys": ["img", "gt_labels"]}]}, "dist_eval": true, "evaluators": [{"type": "ClsEvaluator", "topk": [1, 5]}]}], "custom_hooks": [], "optimizer": {"type": "SGD", "lr": 0.001, "momentum": 0.9, "weight_decay": 0.0005}, "lr_config": {"policy": "step", "step": [150, 250]}, "checkpoint_config": {"interval": 1}, "total_epochs": 10, "export": {"export_neck": true}, "work_dir": "work_dirs/classification/cifar10/r50", "oss_work_dir": null, "gpus": 1}
2022-06-28 10:10:52,526 - easycv - INFO - GPU INFO : Tesla V100-SXM2-16GB
2022-06-28 10:10:52,743 - easycv - INFO - load checkpoint from http path: http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth
Downloading: "http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth" to /root/.cache/torch/hub/checkpoints/resnet50.pth
100%|██████████████████████████████████████| 97.8M/97.8M [00:07<00:00, 13.7MB/s]
2022-06-28 10:11:00,549 - easycv - WARNING - The model and loaded state dict do not match exactly
unexpected key in source state_dict: fc.weight, fc.bias
data shuffle: True
GPU INFO :  Tesla V100-SXM2-16GB
2022-06-28 10:11:05,463 - easycv - INFO - open validate hook
2022-06-28 10:11:05,896 - easycv - INFO - register EvaluationHook {'initial': True, 'evaluators': [<easycv.core.evaluation.classification_eval.ClsEvaluator object at 0x7f9d3aaf6048>]}
2022-06-28 10:11:05,896 - easycv - INFO - Start running, host: root@dsw-140073-56b898d496-rng6m, work_dir: /mnt/workspace/work_dirs/classification/cifar10/r50
2022-06-28 10:11:05,897 - easycv - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(ABOVE_NORMAL) OptimizerHook                      
(NORMAL      ) CheckpointHook                     
(NORMAL      ) EvalHook                           
(NORMAL      ) BestCkptSaverHook                  
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) StepLrUpdaterHook                  
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_iter:
(VERY_HIGH   ) StepLrUpdaterHook                  
(LOW         ) IterTimerHook                      
 -------------------- 
after_train_iter:
(ABOVE_NORMAL) OptimizerHook                      
(NORMAL      ) CheckpointHook                     
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
after_train_epoch:
(NORMAL      ) CheckpointHook                     
(NORMAL      ) EvalHook                           
(NORMAL      ) BestCkptSaverHook                  
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_val_epoch:
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_val_iter:
(LOW         ) IterTimerHook                      
 -------------------- 
after_val_iter:
(LOW         ) IterTimerHook                      
 -------------------- 
after_val_epoch:
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
after_run:
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
2022-06-28 10:11:05,898 - easycv - INFO - workflow: [('train', 1)], max: 10 epochs
2022-06-28 10:11:05,899 - easycv - INFO - Checkpoints will be saved to /mnt/workspace/work_dirs/classification/cifar10/r50 by HardDiskBackend.
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6519.5 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:11:07,470 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:11:07,652 - easycv - INFO - End SaveBest metric
2022-06-28 10:11:13,519 - easycv - INFO - Epoch [1][100/391]  lr: 1.000e-03, eta: 0:03:43, time: 0.059, data_time: 0.022, memory: 544, loss: 2.0069
2022-06-28 10:11:17,237 - easycv - INFO - Epoch [1][200/391]  lr: 1.000e-03, eta: 0:02:57, time: 0.037, data_time: 0.001, memory: 544, loss: 1.5517
2022-06-28 10:11:20,945 - easycv - INFO - Epoch [1][300/391]  lr: 1.000e-03, eta: 0:02:39, time: 0.037, data_time: 0.001, memory: 544, loss: 1.3852
2022-06-28 10:11:24,361 - easycv - INFO - Saving checkpoint at 1 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6869.4 task/s, elapsed: 1s, ETA:     0s
2022-06-28 10:11:26,171 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:11:26,724 - easycv - INFO - End SaveBest metric
2022-06-28 10:11:26,725 - easycv - INFO - Epoch(val) [1][391] ClsEvaluator_neck_top1: 67.8100, neck_top1: 67.8100, neck_top5: 97.2900
2022-06-28 10:11:32,589 - easycv - INFO - Epoch [2][100/391]  lr: 1.000e-03, eta: 0:02:13, time: 0.059, data_time: 0.022, memory: 544, loss: 1.2302
2022-06-28 10:11:36,410 - easycv - INFO - Epoch [2][200/391]  lr: 1.000e-03, eta: 0:02:08, time: 0.038, data_time: 0.001, memory: 544, loss: 1.1861
2022-06-28 10:11:40,123 - easycv - INFO - Epoch [2][300/391]  lr: 1.000e-03, eta: 0:02:04, time: 0.037, data_time: 0.001, memory: 544, loss: 1.1583
2022-06-28 10:11:43,538 - easycv - INFO - Saving checkpoint at 2 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6582.5 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:11:45,427 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:11:45,998 - easycv - INFO - End SaveBest metric
2022-06-28 10:11:45,998 - easycv - INFO - Epoch(val) [2][391] ClsEvaluator_neck_top1: 75.3600, neck_top1: 75.3600, neck_top5: 98.4600
2022-06-28 10:11:51,889 - easycv - INFO - Epoch [3][100/391]  lr: 1.000e-03, eta: 0:01:51, time: 0.059, data_time: 0.022, memory: 544, loss: 1.0935
2022-06-28 10:11:55,659 - easycv - INFO - Epoch [3][200/391]  lr: 1.000e-03, eta: 0:01:48, time: 0.038, data_time: 0.001, memory: 544, loss: 1.0723
2022-06-28 10:11:59,441 - easycv - INFO - Epoch [3][300/391]  lr: 1.000e-03, eta: 0:01:44, time: 0.038, data_time: 0.001, memory: 544, loss: 1.0663
2022-06-28 10:12:02,826 - easycv - INFO - Saving checkpoint at 3 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6455.9 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:12:04,726 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:12:05,288 - easycv - INFO - End SaveBest metric
2022-06-28 10:12:05,288 - easycv - INFO - Epoch(val) [3][391] ClsEvaluator_neck_top1: 78.3200, neck_top1: 78.3200, neck_top5: 98.8100
2022-06-28 10:12:11,133 - easycv - INFO - Epoch [4][100/391]  lr: 1.000e-03, eta: 0:01:35, time: 0.058, data_time: 0.022, memory: 544, loss: 1.0227
2022-06-28 10:12:14,905 - easycv - INFO - Epoch [4][200/391]  lr: 1.000e-03, eta: 0:01:31, time: 0.038, data_time: 0.001, memory: 544, loss: 1.0175
2022-06-28 10:12:18,635 - easycv - INFO - Epoch [4][300/391]  lr: 1.000e-03, eta: 0:01:28, time: 0.037, data_time: 0.001, memory: 544, loss: 1.0005
2022-06-28 10:12:22,109 - easycv - INFO - Saving checkpoint at 4 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6586.6 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:12:23,978 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:12:24,555 - easycv - INFO - End SaveBest metric
2022-06-28 10:12:24,555 - easycv - INFO - Epoch(val) [4][391] ClsEvaluator_neck_top1: 80.3300, neck_top1: 80.3300, neck_top5: 99.0100
2022-06-28 10:12:30,554 - easycv - INFO - Epoch [5][100/391]  lr: 1.000e-03, eta: 0:01:20, time: 0.060, data_time: 0.022, memory: 544, loss: 0.9649
2022-06-28 10:12:34,256 - easycv - INFO - Epoch [5][200/391]  lr: 1.000e-03, eta: 0:01:16, time: 0.037, data_time: 0.001, memory: 544, loss: 0.9584
2022-06-28 10:12:37,994 - easycv - INFO - Epoch [5][300/391]  lr: 1.000e-03, eta: 0:01:13, time: 0.037, data_time: 0.001, memory: 544, loss: 0.9724
2022-06-28 10:12:41,542 - easycv - INFO - Saving checkpoint at 5 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6427.7 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:12:43,447 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:12:44,037 - easycv - INFO - End SaveBest metric
2022-06-28 10:12:44,037 - easycv - INFO - Epoch(val) [5][391] ClsEvaluator_neck_top1: 81.8100, neck_top1: 81.8100, neck_top5: 99.0700
2022-06-28 10:12:49,890 - easycv - INFO - Epoch [6][100/391]  lr: 1.000e-03, eta: 0:01:05, time: 0.058, data_time: 0.022, memory: 544, loss: 0.9375
2022-06-28 10:12:53,609 - easycv - INFO - Epoch [6][200/391]  lr: 1.000e-03, eta: 0:01:02, time: 0.037, data_time: 0.001, memory: 544, loss: 0.9250
2022-06-28 10:12:57,373 - easycv - INFO - Epoch [6][300/391]  lr: 1.000e-03, eta: 0:00:58, time: 0.038, data_time: 0.001, memory: 544, loss: 0.9330
2022-06-28 10:13:00,831 - easycv - INFO - Saving checkpoint at 6 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6454.0 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:13:02,762 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:13:03,347 - easycv - INFO - End SaveBest metric
2022-06-28 10:13:03,347 - easycv - INFO - Epoch(val) [6][391] ClsEvaluator_neck_top1: 82.7200, neck_top1: 82.7200, neck_top5: 99.1400
2022-06-28 10:13:09,252 - easycv - INFO - Epoch [7][100/391]  lr: 1.000e-03, eta: 0:00:51, time: 0.059, data_time: 0.022, memory: 544, loss: 0.9097
2022-06-28 10:13:12,981 - easycv - INFO - Epoch [7][200/391]  lr: 1.000e-03, eta: 0:00:48, time: 0.037, data_time: 0.002, memory: 544, loss: 0.9089
2022-06-28 10:13:16,745 - easycv - INFO - Epoch [7][300/391]  lr: 1.000e-03, eta: 0:00:44, time: 0.038, data_time: 0.002, memory: 544, loss: 0.9057
2022-06-28 10:13:20,153 - easycv - INFO - Saving checkpoint at 7 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6489.4 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:13:22,066 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:13:22,647 - easycv - INFO - End SaveBest metric
2022-06-28 10:13:22,647 - easycv - INFO - Epoch(val) [7][391] ClsEvaluator_neck_top1: 83.5900, neck_top1: 83.5900, neck_top5: 99.1600
2022-06-28 10:13:28,536 - easycv - INFO - Epoch [8][100/391]  lr: 1.000e-03, eta: 0:00:37, time: 0.059, data_time: 0.023, memory: 544, loss: 0.8750
2022-06-28 10:13:32,282 - easycv - INFO - Epoch [8][200/391]  lr: 1.000e-03, eta: 0:00:34, time: 0.037, data_time: 0.001, memory: 544, loss: 0.8953
2022-06-28 10:13:36,017 - easycv - INFO - Epoch [8][300/391]  lr: 1.000e-03, eta: 0:00:30, time: 0.037, data_time: 0.001, memory: 544, loss: 0.8790
2022-06-28 10:13:39,425 - easycv - INFO - Saving checkpoint at 8 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6329.1 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:13:41,386 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:13:41,970 - easycv - INFO - End SaveBest metric
2022-06-28 10:13:41,970 - easycv - INFO - Epoch(val) [8][391] ClsEvaluator_neck_top1: 84.2600, neck_top1: 84.2600, neck_top5: 99.1700
2022-06-28 10:13:47,907 - easycv - INFO - Epoch [9][100/391]  lr: 1.000e-03, eta: 0:00:23, time: 0.059, data_time: 0.022, memory: 544, loss: 0.8629
2022-06-28 10:13:51,654 - easycv - INFO - Epoch [9][200/391]  lr: 1.000e-03, eta: 0:00:20, time: 0.037, data_time: 0.001, memory: 544, loss: 0.8670
2022-06-28 10:13:55,404 - easycv - INFO - Epoch [9][300/391]  lr: 1.000e-03, eta: 0:00:16, time: 0.037, data_time: 0.002, memory: 544, loss: 0.8525
2022-06-28 10:13:58,823 - easycv - INFO - Saving checkpoint at 9 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 6273.3 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:14:00,771 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:14:01,358 - easycv - INFO - End SaveBest metric
2022-06-28 10:14:01,358 - easycv - INFO - Epoch(val) [9][391] ClsEvaluator_neck_top1: 84.6100, neck_top1: 84.6100, neck_top5: 99.1300
2022-06-28 10:14:07,422 - easycv - INFO - Epoch [10][100/391] lr: 1.000e-03, eta: 0:00:10, time: 0.061, data_time: 0.023, memory: 544, loss: 0.8362
2022-06-28 10:14:11,350 - easycv - INFO - Epoch [10][200/391] lr: 1.000e-03, eta: 0:00:06, time: 0.039, data_time: 0.002, memory: 544, loss: 0.8471
2022-06-28 10:14:15,212 - easycv - INFO - Epoch [10][300/391] lr: 1.000e-03, eta: 0:00:03, time: 0.039, data_time: 0.002, memory: 544, loss: 0.8429
2022-06-28 10:14:18,680 - easycv - INFO - Saving checkpoint at 10 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 5919.4 task/s, elapsed: 2s, ETA:     0s
2022-06-28 10:14:20,735 - easycv - INFO - SaveBest metric_name: ['ClsEvaluator_neck_top1']
2022-06-28 10:14:21,332 - easycv - INFO - End SaveBest metric
2022-06-28 10:14:21,332 - easycv - INFO - Epoch(val) [10][391]  ClsEvaluator_neck_top1: 84.9100, neck_top1: 84.9100, neck_top5: 99.1500
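Note that lr stays at 1.000e-03 for all 10 epochs. The config sets lr_config = dict(policy='step', step=[150, 250]). Those milestones are epochs, and both lie beyond total_epochs = 10, so no decay takes effect in this short run. The sketch below is a simplified step schedule, not mmcv's own code. It assumes 0-indexed epochs and a decay factor of gamma=0.1, which is mmcv's default. For a longer run with decay, move the milestones below total_epochs.

```python
def step_lr(base_lr: float, epoch: int, steps, gamma: float = 0.1) -> float:
    """LR for a 0-indexed epoch: multiply base_lr by gamma once per milestone already reached."""
    num_decays = sum(1 for s in steps if epoch >= s)
    return base_lr * gamma ** num_decays


if __name__ == "__main__":
    # Values from the config above: lr=0.001, step=[150, 250], total_epochs=10.
    print([step_lr(0.001, e, [150, 250]) for e in range(10)])
```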

Export the model

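After training finishes, use the export command to export the model for inference. The exported model includes the preprocessing and postprocessing information needed at inference time.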
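The preprocessing in the config's test_pipeline is ToTensor followed by Normalize with img_norm_cfg. Presumably this is the preprocessing information the export keeps. The sketch below shows what those two steps do to a single pixel. It assumes that ToTensor scales 0-255 values to [0, 1], as torchvision's ToTensor does. normalize_pixel is a hypothetical helper. The mean and std are the usual CIFAR-10 per-channel statistics in RGB order, which is why the prediction code converts OpenCV's BGR image to RGB.

```python
# Per-channel statistics copied from img_norm_cfg in the training config (RGB order).
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.201)


def normalize_pixel(rgb):
    """Scale one 0-255 RGB pixel to [0, 1] (ToTensor-style), then subtract mean and divide by std."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))


if __name__ == "__main__":
    print(normalize_pixel((255, 255, 255)))  # a white pixel
```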

# List the .pth files produced by training
! ls  work_dirs/classification/cifar10/r50*
20220628_101052.log    epoch_1.pth  epoch_5.pth  epoch_9.pth
20220628_101052.log.json   epoch_2.pth  epoch_6.pth
ClsEvaluator_neck_top1_best.pth  epoch_3.pth  epoch_7.pth
epoch_10.pth       epoch_4.pth  epoch_8.pth
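The .log.json file in the listing records metrics per epoch. The sketch below finds which epoch produced the best checkpoint. It assumes the mmcv-style format: one JSON object per line, where validation records carry "mode": "val" and an "epoch" field, with the metric stored under the same name as in the text log (ClsEvaluator_neck_top1). Check these key names against your own file. best_val_epoch is a hypothetical helper. In the text log above, the highest ClsEvaluator_neck_top1 is 84.91, at epoch 10.

```python
import json
import os


def best_val_epoch(log_lines, metric="ClsEvaluator_neck_top1"):
    """Return (epoch, value) for the validation record with the highest `metric`, or None."""
    best = None
    for line in log_lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        # Skip environment-info and training records; keep validation records that have the metric.
        if record.get("mode") != "val" or metric not in record:
            continue
        if best is None or record[metric] > best[1]:
            best = (record.get("epoch"), record[metric])
    return best


if __name__ == "__main__":
    # The timestamped file name differs on each run.
    path = "work_dirs/classification/cifar10/r50/20220628_101052.log.json"
    if os.path.exists(path):
        with open(path) as f:
            print(best_val_epoch(f))
```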

ClsEvaluator_neck_top1_best.pth is the checkpoint with the highest validation top-1 accuracy (ClsEvaluator_neck_top1) during training. Export this model:

! python -m easycv.tools.export r50_b128_10e_jpg.py work_dirs/classification/cifar10/r50/ClsEvaluator_neck_top1_best.pth  work_dirs/classification/cifar10/r50/best_export.pth
r50_b128_10e_jpg.py
Read base config from /home/pai/lib/python3.6/site-packages/easycv/configs/base.py
2022-06-28 10:15:03,604 - easycv - INFO - load checkpoint from http path: http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth
2022-06-28 10:15:03,832 - easycv - WARNING - The model and loaded state dict do not match exactly
unexpected key in source state_dict: fc.weight, fc.bias
load checkpoint from local path: work_dirs/classification/cifar10/r50/ClsEvaluator_neck_top1_best.pth

Prediction

Download a test image

! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/qince_data/predict/aeroplane_s_000004.png
--2022-06-28 10:15:08--  http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/qince_data/predict/aeroplane_s_000004.png
Resolving pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com... 39.98.20.13
Connecting to pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com|39.98.20.13|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2391 (2.3K) [image/png]
Saving to: ‘aeroplane_s_000004.png’
aeroplane_s_000004. 100%[===================>]   2.33K  --.-KB/s    in 0s      
2022-06-28 10:15:08 (333 MB/s) - ‘aeroplane_s_000004.png’ saved [2391/2391]

Load the exported model weights and predict the class of the test image:

import cv2
from easycv.predictors.classifier import TorchClassifier
output_ckpt = 'work_dirs/classification/cifar10/r50/best_export.pth'
tcls = TorchClassifier(output_ckpt, topk=1)
img = cv2.imread('aeroplane_s_000004.png')
# input image should be RGB order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
output = tcls.predict([img])
print(output)
2022-06-28 10:15:22,071 - easycv - INFO - load model from default path: http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth
2022-06-28 10:15:22,073 - easycv - INFO - load checkpoint from http path: http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth
2022-06-28 10:15:22,302 - easycv - WARNING - The model and loaded state dict do not match exactly
unexpected key in source state_dict: fc.weight, fc.bias
load checkpoint from local path: work_dirs/classification/cifar10/r50/best_export.pth
[{'class': [0], 'class_name': ['airplane'], 'class_probs': {'airplane': 0.88359135, 'automobile': 0.014351366, 'bird': 0.017601196, 'cat': 0.008215358, 'deer': 0.009162315, 'dog': 0.011745055, 'frog': 0.0109530995, 'horse': 0.019746583, 'ship': 0.008050704, 'truck': 0.016582947}}]
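TorchClassifier returns all ten class probabilities under class_probs, even with topk=1. The sketch below ranks them directly. It assumes the output structure shown above (a list with one dict per image, each containing class_probs), and top_k is a hypothetical helper. The probabilities copied from the output sum to about 1.

```python
# class_probs copied from the prediction output above.
PROBS = {
    'airplane': 0.88359135, 'automobile': 0.014351366, 'bird': 0.017601196,
    'cat': 0.008215358, 'deer': 0.009162315, 'dog': 0.011745055,
    'frog': 0.0109530995, 'horse': 0.019746583, 'ship': 0.008050704,
    'truck': 0.016582947,
}


def top_k(class_probs: dict, k: int = 3):
    """Return the k (class_name, probability) pairs with the highest probability."""
    return sorted(class_probs.items(), key=lambda item: item[1], reverse=True)[:k]


if __name__ == "__main__":
    # With the live result, use output[0]['class_probs'] instead of PROBS.
    for name, p in top_k(PROBS):
        print(f"{name}: {p:.4f}")  # airplane: 0.8836, horse: 0.0197, bird: 0.0176
```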