Excellent Open-Source Work on Vision Neural Network Models: the PyTorch Image Models (timm) Library (Part 1)




PyTorch Image Models, known as timm, is a large collection of PyTorch code that bundles:


image models

layers

utilities

optimizers

schedulers

data-loaders / augmentations

training / validation scripts


Its goal is to pull a wide range of SOTA models together in one place, with the ability to reproduce ImageNet training results.


PyTorch Image Models (timm) is an excellent Python library for image classification that ships a large number of image models, optimizers, schedulers, augmentations, and more.


Besides the pretrained models in torchvision.models, another widely used pretrained-model library is timm, created by Ross Wightman from Vancouver, Canada. It offers many SOTA computer-vision models and can be seen as an extended version of torchvision, and its models generally reach higher accuracy. In this chapter we focus on using the library's pretrained models; for the other parts (data augmentation, optimizers, etc.), the following links are a good starting point.


GitHub: https://github.com/rwightman/pytorch-image-models

Brief docs: https://rwightman.github.io/pytorch-image-models/

Detailed docs: https://fastai.github.io/timmdocs/


Installation




timm provides reference training and validation scripts for reproducing results on ImageNet, along with official documentation and the timmdocs project:


https://rwightman.github.io/pytorch-image-models/


https://fastai.github.io/timmdocs/


However, timm offers so much functionality that it can be hard to know where to start when customizing it. This article gives an overview.


pip install timm==0.5.4


All development and testing was done on Linux x86-64 systems in Conda Python 3 environments, specifically Python 3.6, 3.7, 3.8, and 3.9.

PyTorch versions 1.4, 1.5.x, 1.6, 1.7.x, and 1.8 have been tested with this code.


import timm


Loading a pretrained model


A simple call to create_model is all it takes to get a model; to load pretrained weights as well, just pass pretrained=True:


import timm
m = timm.create_model('mobilenetv3_large_100', pretrained=True)
m.eval()
MobileNetV3(
  (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): Hardswish()
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (se): Identity()
        (conv_pw): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Identity()
      )
    )
    (1): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): Identity()
        (conv_pwl): Conv2d(64, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False)
        (bn2): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): Identity()
        (conv_pwl): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(72, 72, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=72, bias=False)
        (bn2): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(72, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
        (bn2): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): InvertedResidual(
        (conv_pw): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
        (bn2): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=240, bias=False)
        (bn2): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): Identity()
        (conv_pwl): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(80, 200, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=200, bias=False)
        (bn2): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): Identity()
        (conv_pwl): Conv2d(200, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): InvertedResidual(
        (conv_pw): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)
        (bn2): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): Identity()
        (conv_pwl): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): InvertedResidual(
        (conv_pw): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)
        (bn2): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): Identity()
        (conv_pwl): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (4): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=480, bias=False)
        (bn2): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(672, 672, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=672, bias=False)
        (bn2): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=672, bias=False)
        (bn2): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(672, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
        (bn2): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
        (bn2): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): Sequential(
      (0): ConvBnAct(
        (conv): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
      )
    )
  )
  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
  (conv_head): Conv2d(960, 1280, kernel_size=(1, 1), stride=(1, 1))
  (act2): Hardswish()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (classifier): Linear(in_features=1280, out_features=1000, bias=True)
)


Listing models with pretrained weights


A quick look shows that there are roughly 400-plus such models, all freely available for use:

import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
['adv_inception_v3',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'cspdarknet53',
 'cspresnet50',
 'cspresnext50',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dm_nfnet_f0',
 'dm_nfnet_f1',
 'dm_nfnet_f2',
 'dm_nfnet_f3',
 'dm_nfnet_f4',
 'dm_nfnet_f5',
 'dm_nfnet_f6',
 'dpn68',
 'dpn68b',
 'dpn92',
 'dpn98',
 'dpn107',
 'dpn131',
 'eca_nfnet_l0',
 'eca_nfnet_l1',
 'eca_nfnet_l2',
 'ecaresnet26t',
 'ecaresnet50d',
 'ecaresnet50d_pruned',
 'ecaresnet50t',
 'ecaresnet101d',
 'ecaresnet101d_pruned',
 'ecaresnet269d',
 'ecaresnetlight',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b1_pruned',
 'efficientnet_b2',
 'efficientnet_b2_pruned',
 'efficientnet_b3',
 'efficientnet_b3_pruned',
 'efficientnet_b4',
 'efficientnet_el',
 'efficientnet_el_pruned',
 'efficientnet_em',
 'efficientnet_es',
 'efficientnet_es_pruned',
 'efficientnet_lite0',
 'efficientnetv2_rw_m',
 'efficientnetv2_rw_s',
 'ens_adv_inception_resnet_v2',
 'ese_vovnet19b_dw',
 'ese_vovnet39b',
 'fbnetc_100',
 'gernet_l',
 'gernet_m',
 'gernet_s',
 'ghostnet_100',
 'gluon_inception_v3',
 'gluon_resnet18_v1b',
 'gluon_resnet34_v1b',
 'gluon_resnet50_v1b',
 'gluon_resnet50_v1c',
 'gluon_resnet50_v1d',
 'gluon_resnet50_v1s',
 'gluon_resnet101_v1b',
 'gluon_resnet101_v1c',
 'gluon_resnet101_v1d',
 'gluon_resnet101_v1s',
 'gluon_resnet152_v1b',
 'gluon_resnet152_v1c',
 'gluon_resnet152_v1d',
 'gluon_resnet152_v1s',
 'gluon_resnext50_32x4d',
 'gluon_resnext101_32x4d',
 'gluon_resnext101_64x4d',
 'gluon_senet154',
 'gluon_seresnext50_32x4d',
 'gluon_seresnext101_32x4d',
 'gluon_seresnext101_64x4d',
 'gluon_xception65',
 'gmixer_24_224',
 'gmlp_s16_224',
 'hardcorenas_a',
 'hardcorenas_b',
 'hardcorenas_c',
 'hardcorenas_d',
 'hardcorenas_e',
 'hardcorenas_f',
 'hrnet_w18',
 'hrnet_w18_small',
 'hrnet_w18_small_v2',
 'hrnet_w30',
 'hrnet_w32',
 'hrnet_w40',
 'hrnet_w44',
 'hrnet_w48',
 'hrnet_w64',
 'ig_resnext101_32x8d',
 'ig_resnext101_32x16d',
 'ig_resnext101_32x32d',
 'ig_resnext101_32x48d',
 'inception_resnet_v2',
 'inception_v3',
 'inception_v4',
 'legacy_senet154',
 'legacy_seresnet18',
 'legacy_seresnet34',
 'legacy_seresnet50',
 'legacy_seresnet101',
 'legacy_seresnet152',
 'legacy_seresnext26_32x4d',
 'legacy_seresnext50_32x4d',
 'legacy_seresnext101_32x4d',
 'levit_128',
 'levit_128s',
 'levit_192',
 'levit_256',
 'levit_384',
 'mixer_b16_224',
 'mixer_b16_224_in21k',
 'mixer_b16_224_miil',
 'mixer_b16_224_miil_in21k',
 'mixer_l16_224',
 'mixer_l16_224_in21k',
 'mixnet_l',
 'mixnet_m',
 'mixnet_s',
 'mixnet_xl',
 'mnasnet_100',
 'mobilenetv2_100',
 'mobilenetv2_110d',
 'mobilenetv2_120d',
 'mobilenetv2_140',
 'mobilenetv3_large_100',
 'mobilenetv3_large_100_miil',
 'mobilenetv3_large_100_miil_in21k',
 'mobilenetv3_rw',
 'nasnetalarge',
 'nf_regnet_b1',
 'nf_resnet50',
 'nfnet_l0',
 'pit_b_224',
 'pit_b_distilled_224',
 'pit_s_224',
 'pit_s_distilled_224',
 'pit_ti_224',
 'pit_ti_distilled_224',
 'pit_xs_224',
 'pit_xs_distilled_224',
 'pnasnet5large',
 'regnetx_002',
 'regnetx_004',
 'regnetx_006',
 'regnetx_008',
 'regnetx_016',
 'regnetx_032',
 'regnetx_040',
 'regnetx_064',
 'regnetx_080',
 'regnetx_120',
 'regnetx_160',
 'regnetx_320',
 'regnety_002',
 'regnety_004',
 'regnety_006',
 'regnety_008',
 'regnety_016',
 'regnety_032',
 'regnety_040',
 'regnety_064',
 'regnety_080',
 'regnety_120',
 'regnety_160',
 'regnety_320',
 'repvgg_a2',
 'repvgg_b0',
 'repvgg_b1',
 'repvgg_b1g4',
 'repvgg_b2',
 'repvgg_b2g4',
 'repvgg_b3',
 'repvgg_b3g4',
 'res2net50_14w_8s',
 'res2net50_26w_4s',
 'res2net50_26w_6s',
 'res2net50_26w_8s',
 'res2net50_48w_2s',
 'res2net101_26w_4s',
 'res2next50',
 'resmlp_12_224',
 'resmlp_12_distilled_224',
 'resmlp_24_224',
 'resmlp_24_distilled_224',
 'resmlp_36_224',
 'resmlp_36_distilled_224',
 'resmlp_big_24_224',
 'resmlp_big_24_224_in22ft1k',
 'resmlp_big_24_distilled_224',
 'resnest14d',
 'resnest26d',
 'resnest50d',
 'resnest50d_1s4x24d',
 'resnest50d_4s2x40d',
 'resnest101e',
 'resnest200e',
 'resnest269e',
 'resnet18',
 'resnet18d',
 'resnet26',
 'resnet26d',
 'resnet34',
 'resnet34d',
 'resnet50',
 'resnet50d',
 'resnet51q',
 'resnet101d',
 'resnet152d',
 'resnet200d',
 'resnetblur50',
 'resnetrs50',
 'resnetrs101',
 'resnetrs152',
 'resnetrs200',
 'resnetrs270',
 'resnetrs350',
 'resnetrs420',
 'resnetv2_50x1_bit_distilled',
 'resnetv2_50x1_bitm',
 'resnetv2_50x1_bitm_in21k',
 'resnetv2_50x3_bitm',
 'resnetv2_50x3_bitm_in21k',
 'resnetv2_101x1_bitm',
 'resnetv2_101x1_bitm_in21k',
 'resnetv2_101x3_bitm',
 'resnetv2_101x3_bitm_in21k',
 'resnetv2_152x2_bit_teacher',
 'resnetv2_152x2_bit_teacher_384',
 'resnetv2_152x2_bitm',
 'resnetv2_152x2_bitm_in21k',
 'resnetv2_152x4_bitm',
 'resnetv2_152x4_bitm_in21k',
 'resnext50_32x4d',
 'resnext50d_32x4d',
 'resnext101_32x8d',
 'rexnet_100',
 'rexnet_130',
 'rexnet_150',
 'rexnet_200',
 'selecsls42b',
 'selecsls60',
 'selecsls60b',
 'semnasnet_100',
 'seresnet50',
 'seresnet152d',
 'seresnext26d_32x4d',
 'seresnext26t_32x4d',
 'seresnext50_32x4d',
 'skresnet18',
 'skresnet34',
 'skresnext50_32x4d',
 'spnasnet_100',
 'ssl_resnet18',
 'ssl_resnet50',
 'ssl_resnext50_32x4d',
 'ssl_resnext101_32x4d',
 'ssl_resnext101_32x8d',
 'ssl_resnext101_32x16d',
 'swin_base_patch4_window7_224',
 'swin_base_patch4_window7_224_in22k',
 'swin_base_patch4_window12_384',
 'swin_base_patch4_window12_384_in22k',
 'swin_large_patch4_window7_224',
 'swin_large_patch4_window7_224_in22k',
 'swin_large_patch4_window12_384',
 'swin_large_patch4_window12_384_in22k',
 'swin_small_patch4_window7_224',
 'swin_tiny_patch4_window7_224',
 'swsl_resnet18',
 'swsl_resnet50',
 'swsl_resnext50_32x4d',
 'swsl_resnext101_32x4d',
 'swsl_resnext101_32x8d',
 'swsl_resnext101_32x16d',
 'tf_efficientnet_b0',
 'tf_efficientnet_b0_ap',
 'tf_efficientnet_b0_ns',
 'tf_efficientnet_b1',
 'tf_efficientnet_b1_ap',
 'tf_efficientnet_b1_ns',
 'tf_efficientnet_b2',
 'tf_efficientnet_b2_ap',
 'tf_efficientnet_b2_ns',
 'tf_efficientnet_b3',
 'tf_efficientnet_b3_ap',
 'tf_efficientnet_b3_ns',
 'tf_efficientnet_b4',
 'tf_efficientnet_b4_ap',
 'tf_efficientnet_b4_ns',
 'tf_efficientnet_b5',
 'tf_efficientnet_b5_ap',
 'tf_efficientnet_b5_ns',
 'tf_efficientnet_b6',
 'tf_efficientnet_b6_ap',
 'tf_efficientnet_b6_ns',
 'tf_efficientnet_b7',
 'tf_efficientnet_b7_ap',
 'tf_efficientnet_b7_ns',
 'tf_efficientnet_b8',
 'tf_efficientnet_b8_ap',
 'tf_efficientnet_cc_b0_4e',
 'tf_efficientnet_cc_b0_8e',
 'tf_efficientnet_cc_b1_8e',
 'tf_efficientnet_el',
 'tf_efficientnet_em',
 'tf_efficientnet_es',
 'tf_efficientnet_l2_ns',
 'tf_efficientnet_l2_ns_475',
 'tf_efficientnet_lite0',
 'tf_efficientnet_lite1',
 'tf_efficientnet_lite2',
 'tf_efficientnet_lite3',
 'tf_efficientnet_lite4',
 'tf_efficientnetv2_b0',
 'tf_efficientnetv2_b1',
 'tf_efficientnetv2_b2',
 'tf_efficientnetv2_b3',
 'tf_efficientnetv2_l',
 'tf_efficientnetv2_l_in21ft1k',
 'tf_efficientnetv2_l_in21k',
 'tf_efficientnetv2_m',
 'tf_efficientnetv2_m_in21ft1k',
 'tf_efficientnetv2_m_in21k',
 'tf_efficientnetv2_s',
 'tf_efficientnetv2_s_in21ft1k',
 'tf_efficientnetv2_s_in21k',
 'tf_inception_v3',
 'tf_mixnet_l',
 'tf_mixnet_m',
 'tf_mixnet_s',
 'tf_mobilenetv3_large_075',
 'tf_mobilenetv3_large_100',
 'tf_mobilenetv3_large_minimal_100',
 'tf_mobilenetv3_small_075',
 'tf_mobilenetv3_small_100',
 'tf_mobilenetv3_small_minimal_100',
 'tnt_s_patch16_224',
 'tresnet_l',
 'tresnet_l_448',
 'tresnet_m',
 'tresnet_m_448',
 'tresnet_m_miil_in21k',
 'tresnet_xl',
 'tresnet_xl_448',
 'tv_densenet121',
 'tv_resnet34',
 'tv_resnet50',
 'tv_resnet101',
 'tv_resnet152',
 'tv_resnext50_32x4d',
 'twins_pcpvt_base',
 'twins_pcpvt_large',
 'twins_pcpvt_small',
 'twins_svt_base',
 'twins_svt_large',
 'twins_svt_small',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'visformer_small',
 'vit_base_patch16_224',
 'vit_base_patch16_224_in21k',
 'vit_base_patch16_224_miil',
 'vit_base_patch16_224_miil_in21k',
 'vit_base_patch16_384',
 'vit_base_patch32_224',
 'vit_base_patch32_224_in21k',
 'vit_base_patch32_384',
 'vit_base_r50_s16_224_in21k',
 'vit_base_r50_s16_384',
 'vit_huge_patch14_224_in21k',
 'vit_large_patch16_224',
 'vit_large_patch16_224_in21k',
 'vit_large_patch16_384',
 'vit_large_patch32_224_in21k',
 'vit_large_patch32_384',
 'vit_large_r50_s32_224',
 'vit_large_r50_s32_224_in21k',
 'vit_large_r50_s32_384',
 'vit_small_patch16_224',
 'vit_small_patch16_224_in21k',
 'vit_small_patch16_384',
 'vit_small_patch32_224',
 'vit_small_patch32_224_in21k',
 'vit_small_patch32_384',
 'vit_small_r26_s32_224',
 'vit_small_r26_s32_224_in21k',
 'vit_small_r26_s32_384',
 'vit_tiny_patch16_224',
 'vit_tiny_patch16_224_in21k',
 'vit_tiny_patch16_384',
 'vit_tiny_r_s16_p8_224',
 'vit_tiny_r_s16_p8_224_in21k',
 'vit_tiny_r_s16_p8_384',
 'wide_resnet50_2',
 'wide_resnet101_2',
 'xception',
 'xception41',
 'xception65',
 'xception71']

Selecting model architectures with wildcards


This lets us quickly find the models we need, which makes the subsequent create_model call easier:

model_names = timm.list_models('*resne*t*')
pprint(model_names)
['bat_resnext26ts',
 'cspresnet50',
 'cspresnet50d',
 'cspresnet50w',
 'cspresnext50',
 'cspresnext50_iabn',
 'eca_lambda_resnext26ts',
 'ecaresnet26t',
 'ecaresnet50d',
 'ecaresnet50d_pruned',
 'ecaresnet50t',
 'ecaresnet101d',
 'ecaresnet101d_pruned',
 'ecaresnet200d',
 'ecaresnet269d',
 'ecaresnetlight',
 'ecaresnext26t_32x4d',
 'ecaresnext50t_32x4d',
 'ens_adv_inception_resnet_v2',
 'gcresnet50t',
 'gcresnext26ts',
 'geresnet50t',
 'gluon_resnet18_v1b',
 'gluon_resnet34_v1b',
 'gluon_resnet50_v1b',
 'gluon_resnet50_v1c',
 'gluon_resnet50_v1d',
 'gluon_resnet50_v1s',
 'gluon_resnet101_v1b',
 'gluon_resnet101_v1c',
 'gluon_resnet101_v1d',
 'gluon_resnet101_v1s',
 'gluon_resnet152_v1b',
 'gluon_resnet152_v1c',
 'gluon_resnet152_v1d',
 'gluon_resnet152_v1s',
 'gluon_resnext50_32x4d',
 'gluon_resnext101_32x4d',
 'gluon_resnext101_64x4d',
 'gluon_seresnext50_32x4d',
 'gluon_seresnext101_32x4d',
 'gluon_seresnext101_64x4d',
 'ig_resnext101_32x8d',
 'ig_resnext101_32x16d',
 'ig_resnext101_32x32d',
 'ig_resnext101_32x48d',
 'inception_resnet_v2',
 'lambda_resnet26t',
 'lambda_resnet50t',
 'legacy_seresnet18',
 'legacy_seresnet34',
 'legacy_seresnet50',
 'legacy_seresnet101',
 'legacy_seresnet152',
 'legacy_seresnext26_32x4d',
 'legacy_seresnext50_32x4d',
 'legacy_seresnext101_32x4d',
 'nf_ecaresnet26',
 'nf_ecaresnet50',
 'nf_ecaresnet101',
 'nf_resnet26',
 'nf_resnet50',
 'nf_resnet101',
 'nf_seresnet26',
 'nf_seresnet50',
 'nf_seresnet101',
 'resnest14d',
 'resnest26d',
 'resnest50d',
 'resnest50d_1s4x24d',
 'resnest50d_4s2x40d',
 'resnest101e',
 'resnest200e',
 'resnest269e',
 'resnet18',
 'resnet18d',
 'resnet26',
 'resnet26d',
 'resnet26t',
 'resnet34',
 'resnet34d',
 'resnet50',
 'resnet50d',
 'resnet50t',
 'resnet51q',
 'resnet61q',
 'resnet101',
 'resnet101d',
 'resnet152',
 'resnet152d',
 'resnet200',
 'resnet200d',
 'resnetblur18',
 'resnetblur50',
 'resnetrs50',
 'resnetrs101',
 'resnetrs152',
 'resnetrs200',
 'resnetrs270',
 'resnetrs350',
 'resnetrs420',
 'resnetv2_50',
 'resnetv2_50d',
 'resnetv2_50t',
 'resnetv2_50x1_bit_distilled',
 'resnetv2_50x1_bitm',
 'resnetv2_50x1_bitm_in21k',
 'resnetv2_50x3_bitm',
 'resnetv2_50x3_bitm_in21k',
 'resnetv2_101',
 'resnetv2_101d',
 'resnetv2_101x1_bitm',
 'resnetv2_101x1_bitm_in21k',
 'resnetv2_101x3_bitm',
 'resnetv2_101x3_bitm_in21k',
 'resnetv2_152',
 'resnetv2_152d',
 'resnetv2_152x2_bit_teacher',
 'resnetv2_152x2_bit_teacher_384',
 'resnetv2_152x2_bitm',
 'resnetv2_152x2_bitm_in21k',
 'resnetv2_152x4_bitm',
 'resnetv2_152x4_bitm_in21k',
 'resnext50_32x4d',
 'resnext50d_32x4d',
 'resnext101_32x4d',
 'resnext101_32x8d',
 'resnext101_64x4d',
 'seresnet18',
 'seresnet34',
 'seresnet50',
 'seresnet50t',
 'seresnet101',
 'seresnet152',
 'seresnet152d',
 'seresnet200d',
 'seresnet269d',
 'seresnext26d_32x4d',
 'seresnext26t_32x4d',
 'seresnext26tn_32x4d',
 'seresnext50_32x4d',
 'seresnext101_32x4d',
 'seresnext101_32x8d',
 'skresnet18',
 'skresnet34',
 'skresnet50',
 'skresnet50d',
 'skresnext50_32x4d',
 'ssl_resnet18',
 'ssl_resnet50',
 'ssl_resnext50_32x4d',
 'ssl_resnext101_32x4d',
 'ssl_resnext101_32x8d',
 'ssl_resnext101_32x16d',
 'swsl_resnet18',
 'swsl_resnet50',
 'swsl_resnext50_32x4d',
 'swsl_resnext101_32x4d',
 'swsl_resnext101_32x8d',
 'swsl_resnext101_32x16d',
 'tresnet_l',
 'tresnet_l_448',
 'tresnet_m',
 'tresnet_m_448',
 'tresnet_m_miil_in21k',
 'tresnet_xl',
 'tresnet_xl_448',
 'tv_resnet34',
 'tv_resnet50',
 'tv_resnet101',
 'tv_resnet152',
 'tv_resnext50_32x4d',
 'vit_base_resnet26d_224',
 'vit_base_resnet50_224_in21k',
 'vit_base_resnet50_384',
 'vit_base_resnet50d_224',
 'vit_small_resnet26d_224',
 'vit_small_resnet50d_s16_224',
 'wide_resnet50_2',
 'wide_resnet101_2']
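The wildcard above is shell-style pattern matching, as implemented by Python's standard fnmatch module. A self-contained sketch of how list_models-style filtering behaves, using a small hypothetical four-name subset rather than the real registry:

```python
from fnmatch import fnmatch

# A tiny hypothetical subset of model names, not timm's full registry.
names = ['resnet18', 'resnet50', 'resnext50_32x4d', 'vit_base_patch16_224']

# Shell-style matching: '*resne*t*' matches any name containing 'resne'
# followed somewhere later by a 't'.
matches = [n for n in names if fnmatch(n, '*resne*t*')]
print(matches)  # ['resnet18', 'resnet50', 'resnext50_32x4d']
```

This is why both `resnet*` and `*resnext*` family names show up under the `'*resne*t*'` pattern in the output above.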


https://rwightman.github.io/pytorch-image-models/models/ describes the network models implemented in timm, with their papers and reference code.

https://paperswithcode.com/lib/timm also lists them.


Models and papers


CNN models:

Classics such as NFNet, RegNet, TResNet, Lambda Networks, GhostNet, and ByoaNet were added, along with ImageNet-21k-trained weights for TResNet, MobileNet-V3, and ViT, and ImageNet-1k / ImageNet-21k-trained weights for EfficientNet-V2.


Transformer models:

Classics such as TNT, Swin Transformer, PiT, Bottleneck Transformers, HaloNets, CoaT, CaiT, LeViT, Visformer, ConViT, Twins, and BiT were added.


MLP models:

Classics such as MLP-Mixer, ResMLP, and gMLP were added.


Optimizers:

The AdaBelief optimizer and others were updated.


So this article is an up-to-date walkthrough of the timm codebase, not limited to vision transformer models.


All the PyTorch models and their corresponding arXiv links are as follows:


Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723

Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370

Bottleneck Transformers - https://arxiv.org/abs/2101.11605

CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239

CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399

ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697

CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929

DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877

DenseNet - https://arxiv.org/abs/1608.06993

DLA - https://arxiv.org/abs/1707.06484

DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629

EfficientNet (MBConvNet Family)

EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252

EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665

EfficientNet (B0-B7) - https://arxiv.org/abs/1905.11946

EfficientNet-EdgeTPU (S, M, L) - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html

EfficientNet V2 - https://arxiv.org/abs/2104.00298

FBNet-C - https://arxiv.org/abs/1812.03443

MixNet - https://arxiv.org/abs/1907.09595

MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626

MobileNet-V2 - https://arxiv.org/abs/1801.04381

Single-Path NAS - https://arxiv.org/abs/1904.02877

GhostNet - https://arxiv.org/abs/1911.11907

gMLP - https://arxiv.org/abs/2105.08050

GPU-Efficient Networks - https://arxiv.org/abs/2006.14090

Halo Nets - https://arxiv.org/abs/2103.12731

HardCoRe-NAS - https://arxiv.org/abs/2102.11646

HRNet - https://arxiv.org/abs/1908.07919

Inception-V3 - https://arxiv.org/abs/1512.00567

Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261

Lambda Networks - https://arxiv.org/abs/2102.08602

LeViT (Vision Transformer in ConvNet’s Clothing) - https://arxiv.org/abs/2104.01136

MLP-Mixer - https://arxiv.org/abs/2105.01601

MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244

NASNet-A - https://arxiv.org/abs/1707.07012

NFNet-F - https://arxiv.org/abs/2102.06171

NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692

PNasNet - https://arxiv.org/abs/1712.00559

Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302

RegNet - https://arxiv.org/abs/2003.13678

RepVGG - https://arxiv.org/abs/2101.03697

ResMLP - https://arxiv.org/abs/2105.03404

ResNet/ResNeXt

ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385

ResNeXt - https://arxiv.org/abs/1611.05431

‘Bag of Tricks’ / Gluon C, D, E, S variations - https://arxiv.org/abs/1812.01187

Weakly-supervised (WSL) Instagram pretrained / ImageNet tuned ResNeXt101 - https://arxiv.org/abs/1805.00932

Semi-supervised (SSL) / Semi-weakly Supervised (SWSL) ResNet/ResNeXts - https://arxiv.org/abs/1905.00546

ECA-Net (ECAResNet) - https://arxiv.org/abs/1910.03151v4

Squeeze-and-Excitation Networks (SEResNet) - https://arxiv.org/abs/1709.01507

ResNet-RS - https://arxiv.org/abs/2103.07579

Res2Net - https://arxiv.org/abs/1904.01169

ResNeSt - https://arxiv.org/abs/2004.08955

ReXNet - https://arxiv.org/abs/2007.00992

SelecSLS - https://arxiv.org/abs/1907.00837

Selective Kernel Networks - https://arxiv.org/abs/1903.06586

Swin Transformer - https://arxiv.org/abs/2103.14030

Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112

TResNet - https://arxiv.org/abs/2003.13630

Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf

Vision Transformer - https://arxiv.org/abs/2010.11929

VovNet V2 and V1 - https://arxiv.org/abs/1911.06667

Xception - https://arxiv.org/abs/1610.02357

Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611

Xception (Modified Aligned, TF) - https://arxiv.org/abs/1802.02611

XCiT (Cross-Covariance Image Transformers) - https://arxiv.org/abs/2106.09681


1. Models


timm provides a large collection of model architectures. Many of them ship with pretrained weights, either trained in PyTorch or ported from Jax and TensorFlow, which makes them easy to download and use.


Model list: https://paperswithcode.com/lib/timm


To list the available models:

# print the list of models provided by timm
print(timm.list_models())
print(len(timm.list_models())) # 739
# models that come with pretrained weights
print(timm.list_models(pretrained=True))
print(len(timm.list_models(pretrained=True))) # 592


The timm.list_models() function has the following signature:


list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False)


To list a specific family of models, e.g.:

print(timm.list_models('gluon_resnet*'))
print(timm.list_models('*resnext*', 'resnet') )
print(timm.list_models('resnet*', pretrained=True))

1.1. create_model: basic usage


The simplest way to create a model with timm is create_model.


Take the ResNet-D model as an example (from the paper Bag of Tricks for Image Classification with Convolutional Neural Networks). It is a ResNet variant that uses average pooling for downsampling.

model = timm.create_model('resnet50d', pretrained=True)
print(model)
# view the model's default configuration
print(model.default_cfg)
'''
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv1.0',
 'classifier': 'fc',
 'architecture': 'resnet50d'}
'''


1.2. create_model: changing the number of input channels


A very useful feature of timm models is that they can handle input images with an arbitrary number of channels, something most other libraries do not support. The underlying mechanism is explained at:


https://fastai.github.io/timmdocs/models#So-how-is-timm-able-to-load-these-weights?


import torch

model = timm.create_model('resnet50d', pretrained=True, in_chans=1)
print(model)
# test with a single-channel image
x = torch.randn(1, 1, 224, 224)
out = model(x)
print(out.shape) # torch.Size([1, 1000])


1.3. create_model: customizing the model


timm's create_model function accepts many parameters for customizing the model. Its signature is:

create_model(model_name, pretrained=False, checkpoint_path='', scriptable=None, exportable=None, no_jit=None, **kwargs)


Example **kwargs include:

global_pool - the type of global pooling used before the final classification layer; only relevant for architectures that use global pooling.

drop_rate - the dropout rate used during training, 0 by default.

num_classes - the number of output classes.


1.3.1. Changing the number of classes


To inspect the current model's output layer:

# if the output layer is named fc:
print(model.fc)
# Linear(in_features=2048, out_features=1000, bias=True)
# generic way to inspect the output layer:
print(model.get_classifier())

To change the number of output classes:

model = timm.create_model('resnet50d', pretrained=True, num_classes=10)
print(model)
print(model.get_classifier())
#Linear(in_features=2048, out_features=10, bias=True)

If you do not need the final layer at all, set num_classes to 0; the model then uses an identity function as its last layer, which is useful for inspecting the output of the penultimate layer.

model = timm.create_model('resnet50d', pretrained=True, num_classes=0)
print(model)
print(model.get_classifier())
#Identity()

1.3.2. Global pooling


The pool_size entry in model.default_cfg indicates that a global pooling layer is applied before the classifier, e.g.:

print(model.global_pool)
#SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))

The supported pool_type values are:


avg - average pooling

max - max pooling

avgmax - the sum of average pooling and max pooling, each weighted by 0.5

catavgmax - the concatenation of the average-pooling and max-pooling outputs along the feature dimension; doubles the feature dimension

'' - no pooling; replaced by an identity operation (Identity)

pool_types = ['avg', 'max', 'avgmax', 'catavgmax', '']
x = torch.randn(1, 3, 224, 224)
for pool_type in pool_types:
    model = timm.create_model('resnet50d', pretrained=True, num_classes=0, global_pool=pool_type)
    model.eval()
    out = model(x)
    print(out.shape)


1.3.3. Modifying an existing model


For example:

model = timm.create_model('resnet50d', pretrained=True)
print(f'[INFO] Original Pooling: {model.global_pool}')
print(f'[INFO] Original Classifier: {model.get_classifier()}')
model.reset_classifier(10, 'max') # reset_classifier modifies the model in place and returns None
print(f'[INFO] Modified Pooling: {model.global_pool}')
print(f'[INFO] Modified Classifier: {model.get_classifier()}')


1.3.4. Creating a new classification head


Although a single linear layer is usually enough for good results, a larger classification head can sometimes improve performance.


import torch
import torch.nn as nn

model = timm.create_model('resnet50d', pretrained=True, num_classes=10, global_pool='catavgmax')
print(model)
num_in_features = model.get_classifier().in_features
print(num_in_features) # 4096
model.fc = nn.Sequential(
    nn.BatchNorm1d(num_in_features),
    nn.Linear(in_features=num_in_features, out_features=512, bias=False),
    nn.ReLU(),
    nn.BatchNorm1d(512),
    nn.Dropout(0.4),
    nn.Linear(in_features=512, out_features=10, bias=False))
model.eval()
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape) # torch.Size([1, 10])

