Outstanding open-source work for vision neural network models: the PyTorch Image Models (timm) library (Part 1)



PyTorch Image Models, or timm for short, is a large collection of PyTorch code that includes:


image models

layers

utilities

optimizers

schedulers

data-loaders / augmentations

training / validation scripts


It aims to bring a wide range of SOTA models together in one place, with the ability to reproduce ImageNet training results.


PyTorch Image Models (timm) is an excellent Python library for image classification. It contains a large number of image models, optimizers, schedulers, augmentations, and more.


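Besides torchvision.models, another commonly used source of pretrained models is timm, a library created by Ross Wightman of Vancouver, Canada. It provides many SOTA computer vision models and can be viewed as an extended version of torchvision, and many of its models also report higher accuracy. This chapter focuses on using the library's pretrained models. If you are interested in the other parts (data augmentation, optimizers, and so on), see the links below.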


GitHub: https://github.com/rwightman/pytorch-image-models

Brief documentation: https://rwightman.github.io/pytorch-image-models/

Detailed documentation: https://fastai.github.io/timmdocs/


Installation




timm provides reference training and validation scripts for reproducing ImageNet training results. More material is available in the official documentation and in the timmdocs project:


https://rwightman.github.io/pytorch-image-models/


https://fastai.github.io/timmdocs/


However, because timm offers so many features, it can be hard to know where to start when customizing it. This article gives an overview.


pip install timm==0.5.4


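All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically with Python 3.6, 3.7, 3.8, and 3.9.

PyTorch versions 1.4, 1.5.x, 1.6, 1.7.x, and 1.8 have been tested with this code.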


import timm


Loading a pretrained model


A single call to create_model gives us a model. To load its pretrained weights, just add the argument pretrained=True.


import timm
m = timm.create_model('mobilenetv3_large_100', pretrained=True)
m.eval()
MobileNetV3(
  (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): Hardswish()
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (se): Identity()
        (conv_pw): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Identity()
      )
    )
    (1): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): Identity()
        (conv_pwl): Conv2d(64, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False)
        (bn2): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): Identity()
        (conv_pwl): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(72, 72, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=72, bias=False)
        (bn2): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(72, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
        (bn2): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): InvertedResidual(
        (conv_pw): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv_dw): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
        (bn2): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=240, bias=False)
        (bn2): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): Identity()
        (conv_pwl): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(80, 200, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=200, bias=False)
        (bn2): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): Identity()
        (conv_pwl): Conv2d(200, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): InvertedResidual(
        (conv_pw): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)
        (bn2): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): Identity()
        (conv_pwl): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): InvertedResidual(
        (conv_pw): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)
        (bn2): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): Identity()
        (conv_pwl): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (4): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=480, bias=False)
        (bn2): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(672, 672, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=672, bias=False)
        (bn2): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=672, bias=False)
        (bn2): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(672, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
        (bn2): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
        (conv_dw): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
        (bn2): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Hardswish()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): Sequential(
      (0): ConvBnAct(
        (conv): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Hardswish()
      )
    )
  )
  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
  (conv_head): Conv2d(960, 1280, kernel_size=(1, 1), stride=(1, 1))
  (act2): Hardswish()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (classifier): Linear(in_features=1280, out_features=1000, bias=True)
)


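Listing models with pretrained weights


A quick look shows that there are more than 400 such models in total, all of which are freely available to use.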

import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
['adv_inception_v3',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'cspdarknet53',
 'cspresnet50',
 'cspresnext50',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dm_nfnet_f0',
 'dm_nfnet_f1',
 'dm_nfnet_f2',
 'dm_nfnet_f3',
 'dm_nfnet_f4',
 'dm_nfnet_f5',
 'dm_nfnet_f6',
 'dpn68',
 'dpn68b',
 'dpn92',
 'dpn98',
 'dpn107',
 'dpn131',
 'eca_nfnet_l0',
 'eca_nfnet_l1',
 'eca_nfnet_l2',
 'ecaresnet26t',
 'ecaresnet50d',
 'ecaresnet50d_pruned',
 'ecaresnet50t',
 'ecaresnet101d',
 'ecaresnet101d_pruned',
 'ecaresnet269d',
 'ecaresnetlight',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b1_pruned',
 'efficientnet_b2',
 'efficientnet_b2_pruned',
 'efficientnet_b3',
 'efficientnet_b3_pruned',
 'efficientnet_b4',
 'efficientnet_el',
 'efficientnet_el_pruned',
 'efficientnet_em',
 'efficientnet_es',
 'efficientnet_es_pruned',
 'efficientnet_lite0',
 'efficientnetv2_rw_m',
 'efficientnetv2_rw_s',
 'ens_adv_inception_resnet_v2',
 'ese_vovnet19b_dw',
 'ese_vovnet39b',
 'fbnetc_100',
 'gernet_l',
 'gernet_m',
 'gernet_s',
 'ghostnet_100',
 'gluon_inception_v3',
 'gluon_resnet18_v1b',
 'gluon_resnet34_v1b',
 'gluon_resnet50_v1b',
 'gluon_resnet50_v1c',
 'gluon_resnet50_v1d',
 'gluon_resnet50_v1s',
 'gluon_resnet101_v1b',
 'gluon_resnet101_v1c',
 'gluon_resnet101_v1d',
 'gluon_resnet101_v1s',
 'gluon_resnet152_v1b',
 'gluon_resnet152_v1c',
 'gluon_resnet152_v1d',
 'gluon_resnet152_v1s',
 'gluon_resnext50_32x4d',
 'gluon_resnext101_32x4d',
 'gluon_resnext101_64x4d',
 'gluon_senet154',
 'gluon_seresnext50_32x4d',
 'gluon_seresnext101_32x4d',
 'gluon_seresnext101_64x4d',
 'gluon_xception65',
 'gmixer_24_224',
 'gmlp_s16_224',
 'hardcorenas_a',
 'hardcorenas_b',
 'hardcorenas_c',
 'hardcorenas_d',
 'hardcorenas_e',
 'hardcorenas_f',
 'hrnet_w18',
 'hrnet_w18_small',
 'hrnet_w18_small_v2',
 'hrnet_w30',
 'hrnet_w32',
 'hrnet_w40',
 'hrnet_w44',
 'hrnet_w48',
 'hrnet_w64',
 'ig_resnext101_32x8d',
 'ig_resnext101_32x16d',
 'ig_resnext101_32x32d',
 'ig_resnext101_32x48d',
 'inception_resnet_v2',
 'inception_v3',
 'inception_v4',
 'legacy_senet154',
 'legacy_seresnet18',
 'legacy_seresnet34',
 'legacy_seresnet50',
 'legacy_seresnet101',
 'legacy_seresnet152',
 'legacy_seresnext26_32x4d',
 'legacy_seresnext50_32x4d',
 'legacy_seresnext101_32x4d',
 'levit_128',
 'levit_128s',
 'levit_192',
 'levit_256',
 'levit_384',
 'mixer_b16_224',
 'mixer_b16_224_in21k',
 'mixer_b16_224_miil',
 'mixer_b16_224_miil_in21k',
 'mixer_l16_224',
 'mixer_l16_224_in21k',
 'mixnet_l',
 'mixnet_m',
 'mixnet_s',
 'mixnet_xl',
 'mnasnet_100',
 'mobilenetv2_100',
 'mobilenetv2_110d',
 'mobilenetv2_120d',
 'mobilenetv2_140',
 'mobilenetv3_large_100',
 'mobilenetv3_large_100_miil',
 'mobilenetv3_large_100_miil_in21k',
 'mobilenetv3_rw',
 'nasnetalarge',
 'nf_regnet_b1',
 'nf_resnet50',
 'nfnet_l0',
 'pit_b_224',
 'pit_b_distilled_224',
 'pit_s_224',
 'pit_s_distilled_224',
 'pit_ti_224',
 'pit_ti_distilled_224',
 'pit_xs_224',
 'pit_xs_distilled_224',
 'pnasnet5large',
 'regnetx_002',
 'regnetx_004',
 'regnetx_006',
 'regnetx_008',
 'regnetx_016',
 'regnetx_032',
 'regnetx_040',
 'regnetx_064',
 'regnetx_080',
 'regnetx_120',
 'regnetx_160',
 'regnetx_320',
 'regnety_002',
 'regnety_004',
 'regnety_006',
 'regnety_008',
 'regnety_016',
 'regnety_032',
 'regnety_040',
 'regnety_064',
 'regnety_080',
 'regnety_120',
 'regnety_160',
 'regnety_320',
 'repvgg_a2',
 'repvgg_b0',
 'repvgg_b1',
 'repvgg_b1g4',
 'repvgg_b2',
 'repvgg_b2g4',
 'repvgg_b3',
 'repvgg_b3g4',
 'res2net50_14w_8s',
 'res2net50_26w_4s',
 'res2net50_26w_6s',
 'res2net50_26w_8s',
 'res2net50_48w_2s',
 'res2net101_26w_4s',
 'res2next50',
 'resmlp_12_224',
 'resmlp_12_distilled_224',
 'resmlp_24_224',
 'resmlp_24_distilled_224',
 'resmlp_36_224',
 'resmlp_36_distilled_224',
 'resmlp_big_24_224',
 'resmlp_big_24_224_in22ft1k',
 'resmlp_big_24_distilled_224',
 'resnest14d',
 'resnest26d',
 'resnest50d',
 'resnest50d_1s4x24d',
 'resnest50d_4s2x40d',
 'resnest101e',
 'resnest200e',
 'resnest269e',
 'resnet18',
 'resnet18d',
 'resnet26',
 'resnet26d',
 'resnet34',
 'resnet34d',
 'resnet50',
 'resnet50d',
 'resnet51q',
 'resnet101d',
 'resnet152d',
 'resnet200d',
 'resnetblur50',
 'resnetrs50',
 'resnetrs101',
 'resnetrs152',
 'resnetrs200',
 'resnetrs270',
 'resnetrs350',
 'resnetrs420',
 'resnetv2_50x1_bit_distilled',
 'resnetv2_50x1_bitm',
 'resnetv2_50x1_bitm_in21k',
 'resnetv2_50x3_bitm',
 'resnetv2_50x3_bitm_in21k',
 'resnetv2_101x1_bitm',
 'resnetv2_101x1_bitm_in21k',
 'resnetv2_101x3_bitm',
 'resnetv2_101x3_bitm_in21k',
 'resnetv2_152x2_bit_teacher',
 'resnetv2_152x2_bit_teacher_384',
 'resnetv2_152x2_bitm',
 'resnetv2_152x2_bitm_in21k',
 'resnetv2_152x4_bitm',
 'resnetv2_152x4_bitm_in21k',
 'resnext50_32x4d',
 'resnext50d_32x4d',
 'resnext101_32x8d',
 'rexnet_100',
 'rexnet_130',
 'rexnet_150',
 'rexnet_200',
 'selecsls42b',
 'selecsls60',
 'selecsls60b',
 'semnasnet_100',
 'seresnet50',
 'seresnet152d',
 'seresnext26d_32x4d',
 'seresnext26t_32x4d',
 'seresnext50_32x4d',
 'skresnet18',
 'skresnet34',
 'skresnext50_32x4d',
 'spnasnet_100',
 'ssl_resnet18',
 'ssl_resnet50',
 'ssl_resnext50_32x4d',
 'ssl_resnext101_32x4d',
 'ssl_resnext101_32x8d',
 'ssl_resnext101_32x16d',
 'swin_base_patch4_window7_224',
 'swin_base_patch4_window7_224_in22k',
 'swin_base_patch4_window12_384',
 'swin_base_patch4_window12_384_in22k',
 'swin_large_patch4_window7_224',
 'swin_large_patch4_window7_224_in22k',
 'swin_large_patch4_window12_384',
 'swin_large_patch4_window12_384_in22k',
 'swin_small_patch4_window7_224',
 'swin_tiny_patch4_window7_224',
 'swsl_resnet18',
 'swsl_resnet50',
 'swsl_resnext50_32x4d',
 'swsl_resnext101_32x4d',
 'swsl_resnext101_32x8d',
 'swsl_resnext101_32x16d',
 'tf_efficientnet_b0',
 'tf_efficientnet_b0_ap',
 'tf_efficientnet_b0_ns',
 'tf_efficientnet_b1',
 'tf_efficientnet_b1_ap',
 'tf_efficientnet_b1_ns',
 'tf_efficientnet_b2',
 'tf_efficientnet_b2_ap',
 'tf_efficientnet_b2_ns',
 'tf_efficientnet_b3',
 'tf_efficientnet_b3_ap',
 'tf_efficientnet_b3_ns',
 'tf_efficientnet_b4',
 'tf_efficientnet_b4_ap',
 'tf_efficientnet_b4_ns',
 'tf_efficientnet_b5',
 'tf_efficientnet_b5_ap',
 'tf_efficientnet_b5_ns',
 'tf_efficientnet_b6',
 'tf_efficientnet_b6_ap',
 'tf_efficientnet_b6_ns',
 'tf_efficientnet_b7',
 'tf_efficientnet_b7_ap',
 'tf_efficientnet_b7_ns',
 'tf_efficientnet_b8',
 'tf_efficientnet_b8_ap',
 'tf_efficientnet_cc_b0_4e',
 'tf_efficientnet_cc_b0_8e',
 'tf_efficientnet_cc_b1_8e',
 'tf_efficientnet_el',
 'tf_efficientnet_em',
 'tf_efficientnet_es',
 'tf_efficientnet_l2_ns',
 'tf_efficientnet_l2_ns_475',
 'tf_efficientnet_lite0',
 'tf_efficientnet_lite1',
 'tf_efficientnet_lite2',
 'tf_efficientnet_lite3',
 'tf_efficientnet_lite4',
 'tf_efficientnetv2_b0',
 'tf_efficientnetv2_b1',
 'tf_efficientnetv2_b2',
 'tf_efficientnetv2_b3',
 'tf_efficientnetv2_l',
 'tf_efficientnetv2_l_in21ft1k',
 'tf_efficientnetv2_l_in21k',
 'tf_efficientnetv2_m',
 'tf_efficientnetv2_m_in21ft1k',
 'tf_efficientnetv2_m_in21k',
 'tf_efficientnetv2_s',
 'tf_efficientnetv2_s_in21ft1k',
 'tf_efficientnetv2_s_in21k',
 'tf_inception_v3',
 'tf_mixnet_l',
 'tf_mixnet_m',
 'tf_mixnet_s',
 'tf_mobilenetv3_large_075',
 'tf_mobilenetv3_large_100',
 'tf_mobilenetv3_large_minimal_100',
 'tf_mobilenetv3_small_075',
 'tf_mobilenetv3_small_100',
 'tf_mobilenetv3_small_minimal_100',
 'tnt_s_patch16_224',
 'tresnet_l',
 'tresnet_l_448',
 'tresnet_m',
 'tresnet_m_448',
 'tresnet_m_miil_in21k',
 'tresnet_xl',
 'tresnet_xl_448',
 'tv_densenet121',
 'tv_resnet34',
 'tv_resnet50',
 'tv_resnet101',
 'tv_resnet152',
 'tv_resnext50_32x4d',
 'twins_pcpvt_base',
 'twins_pcpvt_large',
 'twins_pcpvt_small',
 'twins_svt_base',
 'twins_svt_large',
 'twins_svt_small',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'visformer_small',
 'vit_base_patch16_224',
 'vit_base_patch16_224_in21k',
 'vit_base_patch16_224_miil',
 'vit_base_patch16_224_miil_in21k',
 'vit_base_patch16_384',
 'vit_base_patch32_224',
 'vit_base_patch32_224_in21k',
 'vit_base_patch32_384',
 'vit_base_r50_s16_224_in21k',
 'vit_base_r50_s16_384',
 'vit_huge_patch14_224_in21k',
 'vit_large_patch16_224',
 'vit_large_patch16_224_in21k',
 'vit_large_patch16_384',
 'vit_large_patch32_224_in21k',
 'vit_large_patch32_384',
 'vit_large_r50_s32_224',
 'vit_large_r50_s32_224_in21k',
 'vit_large_r50_s32_384',
 'vit_small_patch16_224',
 'vit_small_patch16_224_in21k',
 'vit_small_patch16_384',
 'vit_small_patch32_224',
 'vit_small_patch32_224_in21k',
 'vit_small_patch32_384',
 'vit_small_r26_s32_224',
 'vit_small_r26_s32_224_in21k',
 'vit_small_r26_s32_384',
 'vit_tiny_patch16_224',
 'vit_tiny_patch16_224_in21k',
 'vit_tiny_patch16_384',
 'vit_tiny_r_s16_p8_224',
 'vit_tiny_r_s16_p8_224_in21k',
 'vit_tiny_r_s16_p8_384',
 'wide_resnet50_2',
 'wide_resnet101_2',
 'xception',
 'xception41',
 'xception65',
 'xception71']

Selecting model architectures with wildcards


This lets us quickly find the model we need, which makes the subsequent create_model call easier.

model_names = timm.list_models('*resne*t*')
pprint(model_names)
['bat_resnext26ts',
 'cspresnet50',
 'cspresnet50d',
 'cspresnet50w',
 'cspresnext50',
 'cspresnext50_iabn',
 'eca_lambda_resnext26ts',
 'ecaresnet26t',
 'ecaresnet50d',
 'ecaresnet50d_pruned',
 'ecaresnet50t',
 'ecaresnet101d',
 'ecaresnet101d_pruned',
 'ecaresnet200d',
 'ecaresnet269d',
 'ecaresnetlight',
 'ecaresnext26t_32x4d',
 'ecaresnext50t_32x4d',
 'ens_adv_inception_resnet_v2',
 'gcresnet50t',
 'gcresnext26ts',
 'geresnet50t',
 'gluon_resnet18_v1b',
 'gluon_resnet34_v1b',
 'gluon_resnet50_v1b',
 'gluon_resnet50_v1c',
 'gluon_resnet50_v1d',
 'gluon_resnet50_v1s',
 'gluon_resnet101_v1b',
 'gluon_resnet101_v1c',
 'gluon_resnet101_v1d',
 'gluon_resnet101_v1s',
 'gluon_resnet152_v1b',
 'gluon_resnet152_v1c',
 'gluon_resnet152_v1d',
 'gluon_resnet152_v1s',
 'gluon_resnext50_32x4d',
 'gluon_resnext101_32x4d',
 'gluon_resnext101_64x4d',
 'gluon_seresnext50_32x4d',
 'gluon_seresnext101_32x4d',
 'gluon_seresnext101_64x4d',
 'ig_resnext101_32x8d',
 'ig_resnext101_32x16d',
 'ig_resnext101_32x32d',
 'ig_resnext101_32x48d',
 'inception_resnet_v2',
 'lambda_resnet26t',
 'lambda_resnet50t',
 'legacy_seresnet18',
 'legacy_seresnet34',
 'legacy_seresnet50',
 'legacy_seresnet101',
 'legacy_seresnet152',
 'legacy_seresnext26_32x4d',
 'legacy_seresnext50_32x4d',
 'legacy_seresnext101_32x4d',
 'nf_ecaresnet26',
 'nf_ecaresnet50',
 'nf_ecaresnet101',
 'nf_resnet26',
 'nf_resnet50',
 'nf_resnet101',
 'nf_seresnet26',
 'nf_seresnet50',
 'nf_seresnet101',
 'resnest14d',
 'resnest26d',
 'resnest50d',
 'resnest50d_1s4x24d',
 'resnest50d_4s2x40d',
 'resnest101e',
 'resnest200e',
 'resnest269e',
 'resnet18',
 'resnet18d',
 'resnet26',
 'resnet26d',
 'resnet26t',
 'resnet34',
 'resnet34d',
 'resnet50',
 'resnet50d',
 'resnet50t',
 'resnet51q',
 'resnet61q',
 'resnet101',
 'resnet101d',
 'resnet152',
 'resnet152d',
 'resnet200',
 'resnet200d',
 'resnetblur18',
 'resnetblur50',
 'resnetrs50',
 'resnetrs101',
 'resnetrs152',
 'resnetrs200',
 'resnetrs270',
 'resnetrs350',
 'resnetrs420',
 'resnetv2_50',
 'resnetv2_50d',
 'resnetv2_50t',
 'resnetv2_50x1_bit_distilled',
 'resnetv2_50x1_bitm',
 'resnetv2_50x1_bitm_in21k',
 'resnetv2_50x3_bitm',
 'resnetv2_50x3_bitm_in21k',
 'resnetv2_101',
 'resnetv2_101d',
 'resnetv2_101x1_bitm',
 'resnetv2_101x1_bitm_in21k',
 'resnetv2_101x3_bitm',
 'resnetv2_101x3_bitm_in21k',
 'resnetv2_152',
 'resnetv2_152d',
 'resnetv2_152x2_bit_teacher',
 'resnetv2_152x2_bit_teacher_384',
 'resnetv2_152x2_bitm',
 'resnetv2_152x2_bitm_in21k',
 'resnetv2_152x4_bitm',
 'resnetv2_152x4_bitm_in21k',
 'resnext50_32x4d',
 'resnext50d_32x4d',
 'resnext101_32x4d',
 'resnext101_32x8d',
 'resnext101_64x4d',
 'seresnet18',
 'seresnet34',
 'seresnet50',
 'seresnet50t',
 'seresnet101',
 'seresnet152',
 'seresnet152d',
 'seresnet200d',
 'seresnet269d',
 'seresnext26d_32x4d',
 'seresnext26t_32x4d',
 'seresnext26tn_32x4d',
 'seresnext50_32x4d',
 'seresnext101_32x4d',
 'seresnext101_32x8d',
 'skresnet18',
 'skresnet34',
 'skresnet50',
 'skresnet50d',
 'skresnext50_32x4d',
 'ssl_resnet18',
 'ssl_resnet50',
 'ssl_resnext50_32x4d',
 'ssl_resnext101_32x4d',
 'ssl_resnext101_32x8d',
 'ssl_resnext101_32x16d',
 'swsl_resnet18',
 'swsl_resnet50',
 'swsl_resnext50_32x4d',
 'swsl_resnext101_32x4d',
 'swsl_resnext101_32x8d',
 'swsl_resnext101_32x16d',
 'tresnet_l',
 'tresnet_l_448',
 'tresnet_m',
 'tresnet_m_448',
 'tresnet_m_miil_in21k',
 'tresnet_xl',
 'tresnet_xl_448',
 'tv_resnet34',
 'tv_resnet50',
 'tv_resnet101',
 'tv_resnet152',
 'tv_resnext50_32x4d',
 'vit_base_resnet26d_224',
 'vit_base_resnet50_224_in21k',
 'vit_base_resnet50_384',
 'vit_base_resnet50d_224',
 'vit_small_resnet26d_224',
 'vit_small_resnet50d_s16_224',
 'wide_resnet50_2',
 'wide_resnet101_2']
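The wildcard pattern follows shell-style matching rules, where * stands for any run of characters. As a rough illustration of why '*resne*t*' catches resnet, resnext, and resnest names alike, the sketch below applies the same pattern with the standard library's fnmatch to a small hand-picked subset of names. It is not timm's own code, and the subset is chosen here purely for demonstration.

```python
import fnmatch

# A hand-picked subset of model names, chosen only for illustration.
names = [
    'cspresnet50',
    'densenet121',
    'resnest50d',
    'resnet50',
    'resnext50_32x4d',
    'vgg16',
    'vit_base_resnet50_384',
]

# '*' matches any run of characters, so this pattern requires the text
# 'resne' followed, somewhere later in the name, by a 't'.
pattern = '*resne*t*'
matched = fnmatch.filter(names, pattern)
print(matched)
```

Here densenet121 and vgg16 are filtered out because neither contains the substring resne, while the other five names are kept in their original order.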


https://rwightman.github.io/pytorch-image-models/models/ describes some of the network models implemented in timm, together with their papers and reference code.

They are also listed at https://paperswithcode.com/lib/timm


Models and papers


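CNN models:

Additions include NFNet, RegNet, TResNet, Lambda Networks, GhostNet, ByoaNet, and others, along with ImageNet-21k-trained weights for TResNet, MobileNet-V3, and ViT, and ImageNet-1k- and ImageNet-21k-trained weights for EfficientNet-V2.


Transformer models:

Additions include TNT, Swin Transformer, PiT, Bottleneck Transformers, Halo Nets, CoaT, CaiT, LeViT, Visformer, ConViT, Twins, BiT (Big Transfer, which is ResNetV2-based), and others.


MLP models:

Additions include MLP-Mixer, ResMLP, gMLP, and others.


Optimizers:

Additions include the AdaBelief optimizer, among others.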


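This article is therefore an up-to-date walkthrough of the timm library code, and it is not limited to vision transformer models.


All of the PyTorch models and their corresponding arXiv links are listed below: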


Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723

Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370

Bottleneck Transformers - https://arxiv.org/abs/2101.11605

CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239

CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399

ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697

CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929

DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877

DenseNet - https://arxiv.org/abs/1608.06993

DLA - https://arxiv.org/abs/1707.06484

DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629

EfficientNet (MBConvNet Family):

    EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252

    EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665

    EfficientNet (B0-B7) - https://arxiv.org/abs/1905.11946

    EfficientNet-EdgeTPU (S, M, L) - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html

    EfficientNet V2 - https://arxiv.org/abs/2104.00298

    FBNet-C - https://arxiv.org/abs/1812.03443

    MixNet - https://arxiv.org/abs/1907.09595

    MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626

    MobileNet-V2 - https://arxiv.org/abs/1801.04381

    Single-Path NAS - https://arxiv.org/abs/1904.02877

GhostNet - https://arxiv.org/abs/1911.11907

gMLP - https://arxiv.org/abs/2105.08050

GPU-Efficient Networks - https://arxiv.org/abs/2006.14090

Halo Nets - https://arxiv.org/abs/2103.12731

HardCoRe-NAS - https://arxiv.org/abs/2102.11646

HRNet - https://arxiv.org/abs/1908.07919

Inception-V3 - https://arxiv.org/abs/1512.00567

Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261

Lambda Networks - https://arxiv.org/abs/2102.08602

LeViT (Vision Transformer in ConvNet’s Clothing) - https://arxiv.org/abs/2104.01136

MLP-Mixer - https://arxiv.org/abs/2105.01601

MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244

NASNet-A - https://arxiv.org/abs/1707.07012

NFNet-F - https://arxiv.org/abs/2102.06171

NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692

PNasNet - https://arxiv.org/abs/1712.00559

Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302

RegNet - https://arxiv.org/abs/2003.13678

RepVGG - https://arxiv.org/abs/2101.03697

ResMLP - https://arxiv.org/abs/2105.03404

ResNet/ResNeXt:

    ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385

    ResNeXt - https://arxiv.org/abs/1611.05431

    ‘Bag of Tricks’ / Gluon C, D, E, S variations - https://arxiv.org/abs/1812.01187

    Weakly-supervised (WSL) Instagram pretrained / ImageNet tuned ResNeXt101 - https://arxiv.org/abs/1805.00932

    Semi-supervised (SSL) / Semi-weakly Supervised (SWSL) ResNet/ResNeXts - https://arxiv.org/abs/1905.00546

    ECA-Net (ECAResNet) - https://arxiv.org/abs/1910.03151v4

    Squeeze-and-Excitation Networks (SEResNet) - https://arxiv.org/abs/1709.01507

    ResNet-RS - https://arxiv.org/abs/2103.07579

Res2Net - https://arxiv.org/abs/1904.01169

ResNeSt - https://arxiv.org/abs/2004.08955

ReXNet - https://arxiv.org/abs/2007.00992

SelecSLS - https://arxiv.org/abs/1907.00837

Selective Kernel Networks - https://arxiv.org/abs/1903.06586

Swin Transformer - https://arxiv.org/abs/2103.14030

Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112

TResNet - https://arxiv.org/abs/2003.13630

Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf

Vision Transformer - https://arxiv.org/abs/2010.11929

VovNet V2 and V1 - https://arxiv.org/abs/1911.06667

Xception - https://arxiv.org/abs/1610.02357

Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611

Xception (Modified Aligned, TF) - https://arxiv.org/abs/1802.02611

XCiT (Cross-Covariance Image Transformers) - https://arxiv.org/abs/2106.09681


1. Models


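timm provides a large collection of model architectures. Many of them ship with pretrained weights, either trained natively in PyTorch or ported from JAX and TensorFlow, so they are easy to download and use.


Model list: https://paperswithcode.com/lib/timm


List the available models: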

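# Print the names of all models provided by timm
print(timm.list_models())
print(len(timm.list_models()))  # 739 (the count depends on the installed timm version)
# Names of the models that have pretrained weights
print(timm.list_models(pretrained=True))
print(len(timm.list_models(pretrained=True)))  # 592 (also version dependent)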


The signature of timm.list_models() is:


list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False)


To list the models of a particular family, pass a wildcard filter, for example:

print(timm.list_models('gluon_resnet*'))
print(timm.list_models('*resnext*', 'resnet'))  # the second positional argument is the module filter
print(timm.list_models('resnet*', pretrained=True))
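
The filter strings above are shell-style wildcard patterns. As a minimal sketch of how such patterns select names, the snippet below applies Python's standard fnmatch module to a small, hypothetical list of model names. The list and the helper filter_names are made up for illustration and are not part of timm.

```python
import fnmatch

# Hypothetical subset of model names, for illustration only.
names = ['gluon_resnet50_v1b', 'resnet50d', 'resnext50_32x4d', 'vit_base_patch16_224']

def filter_names(names, pattern):
    """Return the sorted names that match a shell-style wildcard pattern."""
    return sorted(fnmatch.filter(names, pattern))

gluon = filter_names(names, 'gluon_resnet*')
resnets = filter_names(names, 'resnet*')
resnexts = filter_names(names, '*resnext*')
print(gluon, resnets, resnexts)
```

A leading or trailing * decides whether the pattern must match at the start of the name or anywhere inside it, which is why 'resnet*' does not pick up the gluon_ or resnext entries.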

1.1. Basic usage of create_model


The simplest way to create a model with timm is create_model.


Take ResNet-D as an example (from the paper Bag of Tricks for Image Classification with Convolutional Neural Networks). It is a ResNet variant that uses average pooling when downsampling.

model = timm.create_model('resnet50d', pretrained=True)
print(model)
# Inspect the model's default configuration
print(model.default_cfg)
'''
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv1.0',
 'classifier': 'fc',
 'architecture': 'resnet50d'}
'''
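
The fields printed above are what an evaluation pipeline typically consumes. The sketch below assumes the common convention that crop_pct is the ratio between the center-crop size and the size the image is resized to first, and that mean and std are applied per channel to pixel values in [0, 1]. The helper names eval_resize_size and normalize_pixel are hypothetical and only illustrate the arithmetic.

```python
import math

# Values copied from the default_cfg printed above.
cfg = {
    'input_size': (3, 224, 224),
    'crop_pct': 0.875,
    'mean': (0.485, 0.456, 0.406),
    'std': (0.229, 0.224, 0.225),
}

def eval_resize_size(cfg):
    """Size to resize to before center-cropping down to input_size."""
    crop_size = cfg['input_size'][-1]
    return int(math.floor(crop_size / cfg['crop_pct']))

def normalize_pixel(rgb, cfg):
    """Normalize one RGB pixel with values in [0, 1], channel by channel."""
    return tuple((v - m) / s for v, m, s in zip(rgb, cfg['mean'], cfg['std']))

resize_size = eval_resize_size(cfg)
print(resize_size)
print(normalize_pixel(cfg['mean'], cfg))
```

Under these assumptions a 224-pixel crop with crop_pct 0.875 corresponds to resizing the image to 256 pixels before cropping.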


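1.2. Changing the number of input channels with create_model


A very useful feature of timm models is that they can handle input images with an arbitrary number of channels, something many other libraries lack. For how this works, see: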


https://fastai.github.io/timmdocs/models#So-how-is-timm-able-to-load-these-weights?


import torch

model = timm.create_model('resnet50d', pretrained=True, in_chans=1)
print(model)
# Test with a single-channel image
x = torch.randn(1, 1, 224, 224)
out = model(x)
print(out.shape)  # torch.Size([1, 1000])
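
To give an intuition for how RGB-pretrained weights can be reused with a different channel count, here is a rough sketch of one possible strategy: sum the first convolution's filters over the channel axis for single-channel input, and tile and rescale them for other channel counts. The function name adapt_first_conv is hypothetical, and the tiling and 3 / in_chans rescaling are assumptions about one reasonable approach rather than a statement of timm's exact code. The linked page describes the actual behaviour.

```python
import math
import torch

def adapt_first_conv(weight, in_chans):
    """Adapt a 3-channel conv weight of shape (O, 3, kH, kW) to in_chans input channels."""
    out_ch, in_ch, kh, kw = weight.shape
    assert in_ch == 3, 'this sketch only handles RGB-pretrained weights'
    if in_chans == 1:
        # Collapse the RGB filters into one by summing over the channel axis.
        return weight.sum(dim=1, keepdim=True)
    if in_chans == 3:
        return weight
    # Tile the RGB filters, trim to in_chans, and rescale so activations keep a similar scale.
    repeat = math.ceil(in_chans / 3)
    tiled = weight.repeat(1, repeat, 1, 1)[:, :in_chans]
    return tiled * (3 / in_chans)

w = torch.randn(32, 3, 3, 3)   # stand-in for a pretrained first-layer weight
w1 = adapt_first_conv(w, 1)
w8 = adapt_first_conv(w, 8)
print(w1.shape, w8.shape)
```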


1.3. Customizing a model with create_model


The timm create_model function accepts many parameters for customizing a model. Its signature is:

create_model(model_name, pretrained=False, checkpoint_path='', scriptable=None, exportable=None, no_jit=None, **kwargs)


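Examples of **kwargs arguments:

global_pool - the type of global pooling applied before the final classification layer. Whether it applies depends on whether the architecture uses a global pooling layer.

drop_rate - the dropout rate used during training; the default is 0.

num_classes - the number of output classes.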


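1.3.1. Changing the number of classes


Inspect the current model's output layer:

# If the output layer is named fc:
print(model.fc)
#Linear(in_features=2048, out_features=1000, bias=True)
# Generic way to get the output layer:
print(model.get_classifier())

Change the number of output classes: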

model = timm.create_model('resnet50d', pretrained=True, num_classes=10)
print(model)
print(model.get_classifier())
#Linear(in_features=2048, out_features=10, bias=True)

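If you do not need the final layer at all, set num_classes to 0. The model then uses an identity function as its last layer, which is useful for inspecting the output of the penultimate layer.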

model = timm.create_model('resnet50d', pretrained=True, num_classes=0)
print(model)
print(model.get_classifier())
#Identity()

1.3.2. Global pooling


The pool_size entry in model.default_cfg indicates that a global pooling layer is used before the classifier, for example:

print(model.global_pool)
#SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))

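The supported pool_type values are:


avg - average pooling

max - max pooling

avgmax - the sum of average pooling and max pooling, each weighted by 0.5

catavgmax - the average-pooling and max-pooling outputs concatenated along the feature dimension, which doubles the feature dimension

'' - no pooling; the layer is replaced by an identity operation (Identity)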

pool_types = ['avg', 'max', 'avgmax', 'catavgmax', '']
x = torch.randn(1, 3, 224, 224)
for pool_type in pool_types:
    model = timm.create_model('resnet50d', pretrained=True, num_classes=0, global_pool=pool_type)
    model.eval()
    out = model(x)
    print(out.shape)
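
The pooling modes listed above can also be written out by hand to make the resulting shapes explicit. The following is a minimal sketch in plain PyTorch, assuming a ResNet-50-like feature map of shape (1, 2048, 7, 7). The function global_pool is a hypothetical stand-in and not timm's SelectAdaptivePool2d.

```python
import torch
import torch.nn.functional as F

def global_pool(x, pool_type):
    """Apply one of the listed pooling modes to a (N, C, H, W) feature map."""
    if pool_type == '':
        return x  # identity: spatial dimensions are kept
    avg = F.adaptive_avg_pool2d(x, 1).flatten(1)
    mx = F.adaptive_max_pool2d(x, 1).flatten(1)
    if pool_type == 'avg':
        return avg
    if pool_type == 'max':
        return mx
    if pool_type == 'avgmax':
        return 0.5 * (avg + mx)
    if pool_type == 'catavgmax':
        return torch.cat((avg, mx), dim=1)
    raise ValueError(f'unknown pool_type: {pool_type!r}')

feats = torch.randn(1, 2048, 7, 7)  # assumed backbone output
shapes = {p: tuple(global_pool(feats, p).shape)
          for p in ['avg', 'max', 'avgmax', 'catavgmax', '']}
print(shapes)
```

Only catavgmax changes the feature dimension (to 4096 here), and the empty string leaves the 7x7 spatial map unpooled, so any classifier attached afterwards has to account for that.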


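1.3.3. Modifying an existing model


For example, the classifier and pooling of an existing model can be replaced with reset_classifier:

model = timm.create_model('resnet50d', pretrained=True)
print(f'[INFO]Original Pooling: {model.global_pool}')
print(f'[INFO]Original Classifier: {model.get_classifier()}')
# reset_classifier modifies the model in place and returns None, so do not reassign model
model.reset_classifier(10, 'max')
print(f'[INFO]Modified Pooling: {model.global_pool}')
print(f'[INFO]Modified Classifier: {model.get_classifier()}')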


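1.3.4. Creating a new classification head


A single linear layer is often enough to get good results, but sometimes a larger classification head is needed to improve performance.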


import torch
import torch.nn as nn

model = timm.create_model('resnet50d', pretrained=True, num_classes=10, global_pool='catavgmax')
print(model)
# catavgmax doubles the pooled feature dimension, so this is 4096 for resnet50d
num_in_features = model.get_classifier().in_features
print(num_in_features)
# Replace the single linear classifier with a larger head
model.fc = nn.Sequential(
    nn.BatchNorm1d(num_in_features),
    nn.Linear(in_features=num_in_features, out_features=512, bias=False),
    nn.ReLU(),
    nn.BatchNorm1d(512),
    nn.Dropout(0.4),
    nn.Linear(in_features=512, out_features=10, bias=False))
model.eval()
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)

