视觉神经网络模型优秀开源工作:PyTorch Image Models(timm)库
PyTorch Image Models,简称 timm,是一个巨大的 PyTorch 代码集合,包括了一系列:
image models
layers
utilities
optimizers
schedulers
data-loaders / augmentations
training / validation scripts
旨在将各种SOTA模型整合在一起,并具有复现ImageNet训练结果的能力。
PyTorch Image Models(timm) 是一个优秀的图像分类 Python 库,其包含了大量的图像模型(Image Models)、Optimizers、Schedulers、Augmentations 等等.
除了使用 torchvision.models 加载预训练模型以外,还有一个常见的预训练模型库 timm,由来自加拿大温哥华的 Ross Wightman 创建。它提供了许多计算机视觉的 SOTA 模型,可以看作 torchvision 的扩充版本,并且其中模型的准确度也较高。本章主要介绍这个库中预训练模型的使用,其他部分内容(数据扩增、优化器等)如果大家感兴趣,可以参考以下几个链接。
Github链接:https://github.com/rwightman/pytorch-image-models
简略文档:https://rwightman.github.io/pytorch-image-models/
详细文档:https://fastai.github.io/timmdocs/
安装
timm 提供了参考的 training 和 validation 脚本,用于复现 ImageNet 上的训练结果;更多内容可见官方文档和 timmdocs 项目:
https://rwightman.github.io/pytorch-image-models/
https://fastai.github.io/timmdocs/
不过,timm 功能众多,定制使用时往往难以入手,下文主要对其使用方式进行概述。
pip install timm==0.5.4
所有的开发和测试都是在 Linux x86-64 系统上的 Conda Python 3 环境中完成的,具体为 Python 3.6、3.7、3.8 和 3.9。
该代码已在 PyTorch 1.4、1.5.x、1.6、1.7.x 和 1.8 版本上测试过。
import timm
加载预先训练好的模型
只需要简单地调用 create_model 就可以得到模型;如果需要加载预训练权重,只需加上参数 pretrained=True 即可:
import timm
m = timm.create_model('mobilenetv3_large_100', pretrained=True)
m.eval()
MobileNetV3( (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (blocks): Sequential( (0): Sequential( (0): DepthwiseSeparableConv( (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False) (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (se): Identity() (conv_pw): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Identity() ) ) (1): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv_dw): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (se): Identity() (conv_pwl): Conv2d(64, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv_dw): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False) (bn2): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (se): Identity() (conv_pwl): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 
) ) (2): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv_dw): Conv2d(72, 72, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=72, bias=False) (bn2): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): ReLU(inplace=True) (conv_expand): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1)) (gate): Hardsigmoid() ) (conv_pwl): Conv2d(72, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv_dw): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False) (bn2): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1)) (act1): ReLU(inplace=True) (conv_expand): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1)) (gate): Hardsigmoid() ) (conv_pwl): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv_dw): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False) (bn2): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (act2): ReLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1)) (act1): ReLU(inplace=True) (conv_expand): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1)) (gate): Hardsigmoid() ) (conv_pwl): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (3): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=240, bias=False) (bn2): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Hardswish() (se): Identity() (conv_pwl): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(80, 200, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=200, bias=False) (bn2): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Hardswish() (se): Identity() (conv_pwl): Conv2d(200, 80, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False) (bn2): BatchNorm2d(184, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True) (act2): Hardswish() (se): Identity() (conv_pwl): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): InvertedResidual( (conv_pw): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False) (bn2): BatchNorm2d(184, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Hardswish() (se): Identity() (conv_pwl): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (4): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=480, bias=False) (bn2): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Hardswish() (se): SqueezeExcite( (conv_reduce): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1)) (act1): ReLU(inplace=True) (conv_expand): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1)) (gate): Hardsigmoid() ) (conv_pwl): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(672, 672, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=672, bias=False) (bn2): BatchNorm2d(672, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True) (act2): Hardswish() (se): SqueezeExcite( (conv_reduce): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1)) (act1): ReLU(inplace=True) (conv_expand): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1)) (gate): Hardsigmoid() ) (conv_pwl): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (5): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=672, bias=False) (bn2): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Hardswish() (se): SqueezeExcite( (conv_reduce): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1)) (act1): ReLU(inplace=True) (conv_expand): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1)) (gate): Hardsigmoid() ) (conv_pwl): Conv2d(672, 160, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False) (bn2): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Hardswish() (se): SqueezeExcite( (conv_reduce): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1)) (act1): ReLU(inplace=True) (conv_expand): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1)) (gate): Hardsigmoid() ) (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(160, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() (conv_dw): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False) (bn2): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Hardswish() (se): SqueezeExcite( (conv_reduce): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1)) (act1): ReLU(inplace=True) (conv_expand): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1)) (gate): Hardsigmoid() ) (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (6): Sequential( (0): ConvBnAct( (conv): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): Hardswish() ) ) ) (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity()) (conv_head): Conv2d(960, 1280, kernel_size=(1, 1), stride=(1, 1)) (act2): Hardswish() (flatten): Flatten(start_dim=1, end_dim=-1) (classifier): Linear(in_features=1280, out_features=1000, bias=True) )
列出具有预训练权重的模型
可以简单浏览一下:带预训练权重的模型有几百个(具体数量随 timm 版本而变化),都可以直接使用。
import timm
from pprint import pprint

model_names = timm.list_models(pretrained=True)
pprint(model_names)
['adv_inception_v3', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_tiny', 'convit_base', 'convit_small', 'convit_tiny', 'cspdarknet53', 'cspresnet50', 'cspresnext50', 'deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_384', 'deit_base_patch16_224', 'deit_base_patch16_384', 'deit_small_distilled_patch16_224', 'deit_small_patch16_224', 'deit_tiny_distilled_patch16_224', 'deit_tiny_patch16_224', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenetblur121d', 'dla34', 'dla46_c', 'dla46x_c', 'dla60', 'dla60_res2net', 'dla60_res2next', 'dla60x', 'dla60x_c', 'dla102', 'dla102x', 'dla102x2', 'dla169', 'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4', 'dm_nfnet_f5', 'dm_nfnet_f6', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'eca_nfnet_l0', 'eca_nfnet_l1', 'eca_nfnet_l2', 'ecaresnet26t', 'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t', 'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet269d', 'ecaresnetlight', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b1_pruned', 'efficientnet_b2', 'efficientnet_b2_pruned', 'efficientnet_b3', 'efficientnet_b3_pruned', 'efficientnet_b4', 'efficientnet_el', 'efficientnet_el_pruned', 'efficientnet_em', 'efficientnet_es', 'efficientnet_es_pruned', 'efficientnet_lite0', 'efficientnetv2_rw_m', 'efficientnetv2_rw_s', 'ens_adv_inception_resnet_v2', 'ese_vovnet19b_dw', 'ese_vovnet39b', 'fbnetc_100', 'gernet_l', 'gernet_m', 'gernet_s', 'ghostnet_100', 'gluon_inception_v3', 'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet50_v1c', 'gluon_resnet50_v1d', 'gluon_resnet50_v1s', 'gluon_resnet101_v1b', 'gluon_resnet101_v1c', 'gluon_resnet101_v1d', 
'gluon_resnet101_v1s', 'gluon_resnet152_v1b', 'gluon_resnet152_v1c', 'gluon_resnet152_v1d', 'gluon_resnet152_v1s', 'gluon_resnext50_32x4d', 'gluon_resnext101_32x4d', 'gluon_resnext101_64x4d', 'gluon_senet154', 'gluon_seresnext50_32x4d', 'gluon_seresnext101_32x4d', 'gluon_seresnext101_64x4d', 'gluon_xception65', 'gmixer_24_224', 'gmlp_s16_224', 'hardcorenas_a', 'hardcorenas_b', 'hardcorenas_c', 'hardcorenas_d', 'hardcorenas_e', 'hardcorenas_f', 'hrnet_w18', 'hrnet_w18_small', 'hrnet_w18_small_v2', 'hrnet_w30', 'hrnet_w32', 'hrnet_w40', 'hrnet_w44', 'hrnet_w48', 'hrnet_w64', 'ig_resnext101_32x8d', 'ig_resnext101_32x16d', 'ig_resnext101_32x32d', 'ig_resnext101_32x48d', 'inception_resnet_v2', 'inception_v3', 'inception_v4', 'legacy_senet154', 'legacy_seresnet18', 'legacy_seresnet34', 'legacy_seresnet50', 'legacy_seresnet101', 'legacy_seresnet152', 'legacy_seresnext26_32x4d', 'legacy_seresnext50_32x4d', 'legacy_seresnext101_32x4d', 'levit_128', 'levit_128s', 'levit_192', 'levit_256', 'levit_384', 'mixer_b16_224', 'mixer_b16_224_in21k', 'mixer_b16_224_miil', 'mixer_b16_224_miil_in21k', 'mixer_l16_224', 'mixer_l16_224_in21k', 'mixnet_l', 'mixnet_m', 'mixnet_s', 'mixnet_xl', 'mnasnet_100', 'mobilenetv2_100', 'mobilenetv2_110d', 'mobilenetv2_120d', 'mobilenetv2_140', 'mobilenetv3_large_100', 'mobilenetv3_large_100_miil', 'mobilenetv3_large_100_miil_in21k', 'mobilenetv3_rw', 'nasnetalarge', 'nf_regnet_b1', 'nf_resnet50', 'nfnet_l0', 'pit_b_224', 'pit_b_distilled_224', 'pit_s_224', 'pit_s_distilled_224', 'pit_ti_224', 'pit_ti_distilled_224', 'pit_xs_224', 'pit_xs_distilled_224', 'pnasnet5large', 'regnetx_002', 'regnetx_004', 'regnetx_006', 'regnetx_008', 'regnetx_016', 'regnetx_032', 'regnetx_040', 'regnetx_064', 'regnetx_080', 'regnetx_120', 'regnetx_160', 'regnetx_320', 'regnety_002', 'regnety_004', 'regnety_006', 'regnety_008', 'regnety_016', 'regnety_032', 'regnety_040', 'regnety_064', 'regnety_080', 'regnety_120', 'regnety_160', 'regnety_320', 'repvgg_a2', 'repvgg_b0', 
'repvgg_b1', 'repvgg_b1g4', 'repvgg_b2', 'repvgg_b2g4', 'repvgg_b3', 'repvgg_b3g4', 'res2net50_14w_8s', 'res2net50_26w_4s', 'res2net50_26w_6s', 'res2net50_26w_8s', 'res2net50_48w_2s', 'res2net101_26w_4s', 'res2next50', 'resmlp_12_224', 'resmlp_12_distilled_224', 'resmlp_24_224', 'resmlp_24_distilled_224', 'resmlp_36_224', 'resmlp_36_distilled_224', 'resmlp_big_24_224', 'resmlp_big_24_224_in22ft1k', 'resmlp_big_24_distilled_224', 'resnest14d', 'resnest26d', 'resnest50d', 'resnest50d_1s4x24d', 'resnest50d_4s2x40d', 'resnest101e', 'resnest200e', 'resnest269e', 'resnet18', 'resnet18d', 'resnet26', 'resnet26d', 'resnet34', 'resnet34d', 'resnet50', 'resnet50d', 'resnet51q', 'resnet101d', 'resnet152d', 'resnet200d', 'resnetblur50', 'resnetrs50', 'resnetrs101', 'resnetrs152', 'resnetrs200', 'resnetrs270', 'resnetrs350', 'resnetrs420', 'resnetv2_50x1_bit_distilled', 'resnetv2_50x1_bitm', 'resnetv2_50x1_bitm_in21k', 'resnetv2_50x3_bitm', 'resnetv2_50x3_bitm_in21k', 'resnetv2_101x1_bitm', 'resnetv2_101x1_bitm_in21k', 'resnetv2_101x3_bitm', 'resnetv2_101x3_bitm_in21k', 'resnetv2_152x2_bit_teacher', 'resnetv2_152x2_bit_teacher_384', 'resnetv2_152x2_bitm', 'resnetv2_152x2_bitm_in21k', 'resnetv2_152x4_bitm', 'resnetv2_152x4_bitm_in21k', 'resnext50_32x4d', 'resnext50d_32x4d', 'resnext101_32x8d', 'rexnet_100', 'rexnet_130', 'rexnet_150', 'rexnet_200', 'selecsls42b', 'selecsls60', 'selecsls60b', 'semnasnet_100', 'seresnet50', 'seresnet152d', 'seresnext26d_32x4d', 'seresnext26t_32x4d', 'seresnext50_32x4d', 'skresnet18', 'skresnet34', 'skresnext50_32x4d', 'spnasnet_100', 'ssl_resnet18', 'ssl_resnet50', 'ssl_resnext50_32x4d', 'ssl_resnext101_32x4d', 'ssl_resnext101_32x8d', 'ssl_resnext101_32x16d', 'swin_base_patch4_window7_224', 'swin_base_patch4_window7_224_in22k', 'swin_base_patch4_window12_384', 'swin_base_patch4_window12_384_in22k', 'swin_large_patch4_window7_224', 'swin_large_patch4_window7_224_in22k', 'swin_large_patch4_window12_384', 'swin_large_patch4_window12_384_in22k', 
'swin_small_patch4_window7_224', 'swin_tiny_patch4_window7_224', 'swsl_resnet18', 'swsl_resnet50', 'swsl_resnext50_32x4d', 'swsl_resnext101_32x4d', 'swsl_resnext101_32x8d', 'swsl_resnext101_32x16d', 'tf_efficientnet_b0', 'tf_efficientnet_b0_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b2_ns', 'tf_efficientnet_b3', 'tf_efficientnet_b3_ap', 'tf_efficientnet_b3_ns', 'tf_efficientnet_b4', 'tf_efficientnet_b4_ap', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b6_ns', 'tf_efficientnet_b7', 'tf_efficientnet_b7_ap', 'tf_efficientnet_b7_ns', 'tf_efficientnet_b8', 'tf_efficientnet_b8_ap', 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', 'tf_efficientnet_el', 'tf_efficientnet_em', 'tf_efficientnet_es', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475', 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', 'tf_efficientnet_lite4', 'tf_efficientnetv2_b0', 'tf_efficientnetv2_b1', 'tf_efficientnetv2_b2', 'tf_efficientnetv2_b3', 'tf_efficientnetv2_l', 'tf_efficientnetv2_l_in21ft1k', 'tf_efficientnetv2_l_in21k', 'tf_efficientnetv2_m', 'tf_efficientnetv2_m_in21ft1k', 'tf_efficientnetv2_m_in21k', 'tf_efficientnetv2_s', 'tf_efficientnetv2_s_in21ft1k', 'tf_efficientnetv2_s_in21k', 'tf_inception_v3', 'tf_mixnet_l', 'tf_mixnet_m', 'tf_mixnet_s', 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100', 'tnt_s_patch16_224', 'tresnet_l', 'tresnet_l_448', 'tresnet_m', 'tresnet_m_448', 'tresnet_m_miil_in21k', 'tresnet_xl', 'tresnet_xl_448', 'tv_densenet121', 'tv_resnet34', 'tv_resnet50', 'tv_resnet101', 'tv_resnet152', 'tv_resnext50_32x4d', 
'twins_pcpvt_base', 'twins_pcpvt_large', 'twins_pcpvt_small', 'twins_svt_base', 'twins_svt_large', 'twins_svt_small', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'visformer_small', 'vit_base_patch16_224', 'vit_base_patch16_224_in21k', 'vit_base_patch16_224_miil', 'vit_base_patch16_224_miil_in21k', 'vit_base_patch16_384', 'vit_base_patch32_224', 'vit_base_patch32_224_in21k', 'vit_base_patch32_384', 'vit_base_r50_s16_224_in21k', 'vit_base_r50_s16_384', 'vit_huge_patch14_224_in21k', 'vit_large_patch16_224', 'vit_large_patch16_224_in21k', 'vit_large_patch16_384', 'vit_large_patch32_224_in21k', 'vit_large_patch32_384', 'vit_large_r50_s32_224', 'vit_large_r50_s32_224_in21k', 'vit_large_r50_s32_384', 'vit_small_patch16_224', 'vit_small_patch16_224_in21k', 'vit_small_patch16_384', 'vit_small_patch32_224', 'vit_small_patch32_224_in21k', 'vit_small_patch32_384', 'vit_small_r26_s32_224', 'vit_small_r26_s32_224_in21k', 'vit_small_r26_s32_384', 'vit_tiny_patch16_224', 'vit_tiny_patch16_224_in21k', 'vit_tiny_patch16_384', 'vit_tiny_r_s16_p8_224', 'vit_tiny_r_s16_p8_224_in21k', 'vit_tiny_r_s16_p8_384', 'wide_resnet50_2', 'wide_resnet101_2', 'xception', 'xception41', 'xception65', 'xception71']
通过通配符选择模型架构
通过通配符可以快速筛选出所需的模型名称,方便后续传给 create_model:
model_names = timm.list_models('*resne*t*')
pprint(model_names)
['bat_resnext26ts', 'cspresnet50', 'cspresnet50d', 'cspresnet50w', 'cspresnext50', 'cspresnext50_iabn', 'eca_lambda_resnext26ts', 'ecaresnet26t', 'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t', 'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet200d', 'ecaresnet269d', 'ecaresnetlight', 'ecaresnext26t_32x4d', 'ecaresnext50t_32x4d', 'ens_adv_inception_resnet_v2', 'gcresnet50t', 'gcresnext26ts', 'geresnet50t', 'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet50_v1c', 'gluon_resnet50_v1d', 'gluon_resnet50_v1s', 'gluon_resnet101_v1b', 'gluon_resnet101_v1c', 'gluon_resnet101_v1d', 'gluon_resnet101_v1s', 'gluon_resnet152_v1b', 'gluon_resnet152_v1c', 'gluon_resnet152_v1d', 'gluon_resnet152_v1s', 'gluon_resnext50_32x4d', 'gluon_resnext101_32x4d', 'gluon_resnext101_64x4d', 'gluon_seresnext50_32x4d', 'gluon_seresnext101_32x4d', 'gluon_seresnext101_64x4d', 'ig_resnext101_32x8d', 'ig_resnext101_32x16d', 'ig_resnext101_32x32d', 'ig_resnext101_32x48d', 'inception_resnet_v2', 'lambda_resnet26t', 'lambda_resnet50t', 'legacy_seresnet18', 'legacy_seresnet34', 'legacy_seresnet50', 'legacy_seresnet101', 'legacy_seresnet152', 'legacy_seresnext26_32x4d', 'legacy_seresnext50_32x4d', 'legacy_seresnext101_32x4d', 'nf_ecaresnet26', 'nf_ecaresnet50', 'nf_ecaresnet101', 'nf_resnet26', 'nf_resnet50', 'nf_resnet101', 'nf_seresnet26', 'nf_seresnet50', 'nf_seresnet101', 'resnest14d', 'resnest26d', 'resnest50d', 'resnest50d_1s4x24d', 'resnest50d_4s2x40d', 'resnest101e', 'resnest200e', 'resnest269e', 'resnet18', 'resnet18d', 'resnet26', 'resnet26d', 'resnet26t', 'resnet34', 'resnet34d', 'resnet50', 'resnet50d', 'resnet50t', 'resnet51q', 'resnet61q', 'resnet101', 'resnet101d', 'resnet152', 'resnet152d', 'resnet200', 'resnet200d', 'resnetblur18', 'resnetblur50', 'resnetrs50', 'resnetrs101', 'resnetrs152', 'resnetrs200', 'resnetrs270', 'resnetrs350', 'resnetrs420', 'resnetv2_50', 'resnetv2_50d', 
'resnetv2_50t', 'resnetv2_50x1_bit_distilled', 'resnetv2_50x1_bitm', 'resnetv2_50x1_bitm_in21k', 'resnetv2_50x3_bitm', 'resnetv2_50x3_bitm_in21k', 'resnetv2_101', 'resnetv2_101d', 'resnetv2_101x1_bitm', 'resnetv2_101x1_bitm_in21k', 'resnetv2_101x3_bitm', 'resnetv2_101x3_bitm_in21k', 'resnetv2_152', 'resnetv2_152d', 'resnetv2_152x2_bit_teacher', 'resnetv2_152x2_bit_teacher_384', 'resnetv2_152x2_bitm', 'resnetv2_152x2_bitm_in21k', 'resnetv2_152x4_bitm', 'resnetv2_152x4_bitm_in21k', 'resnext50_32x4d', 'resnext50d_32x4d', 'resnext101_32x4d', 'resnext101_32x8d', 'resnext101_64x4d', 'seresnet18', 'seresnet34', 'seresnet50', 'seresnet50t', 'seresnet101', 'seresnet152', 'seresnet152d', 'seresnet200d', 'seresnet269d', 'seresnext26d_32x4d', 'seresnext26t_32x4d', 'seresnext26tn_32x4d', 'seresnext50_32x4d', 'seresnext101_32x4d', 'seresnext101_32x8d', 'skresnet18', 'skresnet34', 'skresnet50', 'skresnet50d', 'skresnext50_32x4d', 'ssl_resnet18', 'ssl_resnet50', 'ssl_resnext50_32x4d', 'ssl_resnext101_32x4d', 'ssl_resnext101_32x8d', 'ssl_resnext101_32x16d', 'swsl_resnet18', 'swsl_resnet50', 'swsl_resnext50_32x4d', 'swsl_resnext101_32x4d', 'swsl_resnext101_32x8d', 'swsl_resnext101_32x16d', 'tresnet_l', 'tresnet_l_448', 'tresnet_m', 'tresnet_m_448', 'tresnet_m_miil_in21k', 'tresnet_xl', 'tresnet_xl_448', 'tv_resnet34', 'tv_resnet50', 'tv_resnet101', 'tv_resnet152', 'tv_resnext50_32x4d', 'vit_base_resnet26d_224', 'vit_base_resnet50_224_in21k', 'vit_base_resnet50_384', 'vit_base_resnet50d_224', 'vit_small_resnet26d_224', 'vit_small_resnet50d_s16_224', 'wide_resnet50_2', 'wide_resnet101_2']
https://rwightman.github.io/pytorch-image-models/models/ 介绍了timm实现的一些网络模型及其论文和参考代码
https://paperswithcode.com/lib/timm 也有列出
模型及论文
CNN模型:
添加了经典的 NFNet、RegNet、TResNet、Lambda Networks、GhostNet、ByoaNet 等,以及 TResNet、MobileNet-V3、ViT 的 ImageNet-21k 训练权重,和 EfficientNet-V2 的 ImageNet-1k、ImageNet-21k 训练权重。
Transformer模型:
添加了经典的 TNT,Swin Transformer,PiT,Bottleneck Transformers,Halo Nets,CoaT,CaiT,LeViT, Visformer, ConViT,Twins,BiT 等。
MLP模型:
添加了经典的 MLP-Mixer,ResMLP,gMLP等。
优化器层面:
更新了 AdaBelief optimizer 等。
因此本文是基于较新版本 timm 库的解读,不只限于视觉 Transformer 模型。
所有的PyTorch模型及其对应arxiv链接如下:
Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
Bottleneck Transformers - https://arxiv.org/abs/2101.11605
CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
ConViT (Soft Convolutional Inductive Biases Vision Transformers) - https://arxiv.org/abs/2103.10697
CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
DenseNet - https://arxiv.org/abs/1608.06993
DLA - https://arxiv.org/abs/1707.06484
DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
EfficientNet (MBConvNet Family)
EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665
EfficientNet (B0-B7) - https://arxiv.org/abs/1905.11946
EfficientNet-EdgeTPU (S, M, L) - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html
EfficientNet V2 - https://arxiv.org/abs/2104.00298
FBNet-C - https://arxiv.org/abs/1812.03443
MixNet - https://arxiv.org/abs/1907.09595
MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626
MobileNet-V2 - https://arxiv.org/abs/1801.04381
Single-Path NAS - https://arxiv.org/abs/1904.02877
GhostNet - https://arxiv.org/abs/1911.11907
gMLP - https://arxiv.org/abs/2105.08050
GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
Halo Nets - https://arxiv.org/abs/2103.12731
HardCoRe-NAS - https://arxiv.org/abs/2102.11646
HRNet - https://arxiv.org/abs/1908.07919
Inception-V3 - https://arxiv.org/abs/1512.00567
Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
Lambda Networks - https://arxiv.org/abs/2102.08602
LeViT (Vision Transformer in ConvNet’s Clothing) - https://arxiv.org/abs/2104.01136
MLP-Mixer - https://arxiv.org/abs/2105.01601
MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
NASNet-A - https://arxiv.org/abs/1707.07012
NFNet-F - https://arxiv.org/abs/2102.06171
NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
PNasNet - https://arxiv.org/abs/1712.00559
Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
RegNet - https://arxiv.org/abs/2003.13678
RepVGG - https://arxiv.org/abs/2101.03697
ResMLP - https://arxiv.org/abs/2105.03404
ResNet/ResNeXt
ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385
ResNeXt - https://arxiv.org/abs/1611.05431
‘Bag of Tricks’ / Gluon C, D, E, S variations - https://arxiv.org/abs/1812.01187
Weakly-supervised (WSL) Instagram pretrained / ImageNet tuned ResNeXt101 - https://arxiv.org/abs/1805.00932
Semi-supervised (SSL) / Semi-weakly Supervised (SWSL) ResNet/ResNeXts - https://arxiv.org/abs/1905.00546
ECA-Net (ECAResNet) - https://arxiv.org/abs/1910.03151v4
Squeeze-and-Excitation Networks (SEResNet) - https://arxiv.org/abs/1709.01507
ResNet-RS - https://arxiv.org/abs/2103.07579
Res2Net - https://arxiv.org/abs/1904.01169
ResNeSt - https://arxiv.org/abs/2004.08955
ReXNet - https://arxiv.org/abs/2007.00992
SelecSLS - https://arxiv.org/abs/1907.00837
Selective Kernel Networks - https://arxiv.org/abs/1903.06586
Swin Transformer - https://arxiv.org/abs/2103.14030
Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
TResNet - https://arxiv.org/abs/2003.13630
Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
Vision Transformer - https://arxiv.org/abs/2010.11929
VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
Xception - https://arxiv.org/abs/1610.02357
Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611
Xception (Modified Aligned, TF) - https://arxiv.org/abs/1802.02611
XCiT (Cross-Covariance Image Transformers) - https://arxiv.org/abs/2106.09681
1. Models
timm 提供了大量的模型结构,其中很多模型都带有预训练权重,或由 PyTorch 训练得到,或从 JAX 和 TensorFlow 移植而来,下载使用都很方便。
查看模型列表:
#打印 timm 提供的模型列表
print(timm.list_models())
print(len(timm.list_models())) #739

#带有预训练权重的模型列表
print(timm.list_models(pretrained=True))
print(len(timm.list_models(pretrained=True))) #592
其中,timm.list_models() 函数的签名为:
list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False)
查看特定族模型,如:
print(timm.list_models('gluon_resnet*'))
print(timm.list_models('*resnext*', 'resnet'))
print(timm.list_models('resnet*', pretrained=True))
1.1. create_model 一般用法
timm 创建模型最简单的方式是使用 create_model 函数。
以 ResNet-D 模型为例(出自论文 Bag of Tricks for Image Classification with Convolutional Neural Networks),它是 ResNet 的一个变体,在下采样时采用 average pooling。
model = timm.create_model('resnet50d', pretrained=True)
print(model)

#查看模型配置参数
print(model.default_cfg)
'''
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'conv1.0', 'classifier': 'fc', 'architecture': 'resnet50d'}
'''
1.2. create_model 修改输入通道
timm 的模型有个非常实用的特点:可以处理任意通道数的输入图像,这是很多其他库所不具备的。其实现原理可参考:
https://fastai.github.io/timmdocs/models#So-how-is-timm-able-to-load-these-weights?
model = timm.create_model('resnet50d', pretrained=True, in_chans=1)
print(model)

#测试:单通道图像输入
x = torch.randn(1, 1, 224, 224)
out = model(x)
print(out.shape) #torch.Size([1, 1000])
1.3. create_model 定制模型
timm 的 create_model 函数提供了很多用于模型定制的参数,函数定义如下:
create_model(model_name, pretrained=False, checkpoint_path='', scriptable=None, exportable=None, no_jit=None, **kwargs)
**kwargs 中的常用参数例如:
global_pool - 定义最终分类层所采用的 global pooling 类型. 取决于网络结构是否用到了全局池化层.
drop_rate - 设定训练时的 dropout 比例,默认是 0.
num_classes - 输出类别数
1.3.1. 修改类别数
查看当前模型输出层:
#如果输出层是 fc,可以直接查看:
print(model.fc) #Linear(in_features=2048, out_features=1000, bias=True)

#通用方式:查看输出层
print(model.get_classifier())
修改输出层类别数:
model = timm.create_model('resnet50d', pretrained=True, num_classes=10)
print(model)
print(model.get_classifier()) #Linear(in_features=2048, out_features=10, bias=True)
如果完全不需要最后的分类层,可以将 num_classes 设为 0,模型将用恒等映射(Identity)作为最后一层,这在查看倒数第二层的输出时很有用。
model = timm.create_model('resnet50d', pretrained=True, num_classes=0)
print(model)
print(model.get_classifier()) #Identity()
1.3.2. Global pooling
model.default_cfg 中出现的 pool_size 配置,说明分类器之前用到了一个全局池化层,如:
print(model.global_pool)
#SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
其中,pool_type 支持:
avg - 平均池化
max - 最大池化
avgmax - 平均池化和最大池化的求和,加权 0.5
catavgmax - 沿着特征维度拼接平均池化和最大池化的输出,特征维度会翻倍
'' - 不采用 pooling,其被替换为恒等操作(Identity)
pool_types = ['avg', 'max', 'avgmax', 'catavgmax', '']
x = torch.randn(1, 3, 224, 224)
for pool_type in pool_types:
    model = timm.create_model('resnet50d', pretrained=True, num_classes=0, global_pool=pool_type)
    model.eval()
    out = model(x)
    print(out.shape)
1.3.3. 修改已有模型
例如:
model = timm.create_model('resnet50d', pretrained=True)
print(f'[INFO] Original Pooling: {model.global_pool}')
print(f'[INFO] Original Classifier: {model.get_classifier()}')

#注意:reset_classifier 原地修改模型,没有返回值
model.reset_classifier(10, 'max')
print(f'[INFO] Modified Pooling: {model.global_pool}')
print(f'[INFO] Modified Classifier: {model.get_classifier()}')
1.3.4. 创建新的分类 head
虽然单个线性层已经足够得到比较好的结果,但有些时候需要更大的分类 head 来提升性能.
import torch
import torch.nn as nn
import timm

model = timm.create_model('resnet50d', pretrained=True, num_classes=10, global_pool='catavgmax')
print(model)

num_in_features = model.get_classifier().in_features
print(num_in_features) #4096,catavgmax 使特征维度翻倍

model.fc = nn.Sequential(
    nn.BatchNorm1d(num_in_features),
    nn.Linear(in_features=num_in_features, out_features=512, bias=False),
    nn.ReLU(),
    nn.BatchNorm1d(512),
    nn.Dropout(0.4),
    nn.Linear(in_features=512, out_features=10, bias=False))
model.eval()

x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape) #torch.Size([1, 10])