7.5k stars on GitHub: a curated collection of PyTorch implementations of all kinds of Vision Transformers
Table of Contents
Introduction
Project Overview
Distillation
Deep ViT
CaiT
Token-to-Token ViT
CCT
CrossViT
PiT
LeViT
CvT
Twins SVT
RegionViT
CrossFormer
NesT
MobileViT
Simple Masked Image Modeling
Masked Autoencoder
Introduction
In the past year or two, Transformers crossing over into computer vision tasks has stopped being news.
Since Google proposed the Vision Transformer (ViT) in October 2020, all kinds of vision Transformers have been making their mark in image synthesis, point cloud processing, vision-language modeling, and other areas.
Implementing Vision Transformers in PyTorch has since become a research hotspot, and many excellent projects have appeared on GitHub. The one introduced here is among them.
The project, named "vit-pytorch", is a collection of Vision Transformer implementations that demonstrates a simple way to achieve SOTA results in visual classification in PyTorch using only a single transformer encoder.
The project currently has 7.5k stars. Its creator, Phil Wang, has 147 repositories on GitHub.
The project author also provides an animated demo in the repository.
Project Overview
Let's first walk through installation, usage, parameters, and distillation for vit-pytorch.
Step one is installation:
$ pip install vit-pytorch
Step two is usage:
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
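The snippet above stops at the forward pass; training is ordinary PyTorch. Below is a minimal sketch of one optimization step, assuming a standard cross-entropy classification setup and a hypothetical DataLoader named train_loader (not part of the repository).

import torch
import torch.nn.functional as F

# hypothetical training loop for the ViT instance 'v' created above;
# 'train_loader' is an assumed DataLoader yielding (images, labels) batches
optimizer = torch.optim.AdamW(v.parameters(), lr = 3e-4)

for images, labels in train_loader:       # images: (B, 3, 256, 256), labels: (B,)
    logits = v(images)                    # (B, 1000) class logits
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()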
Step three is the required parameters, which include the following:
image_size: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
patch_size: int.
Size of each patch. image_size must be divisible by patch_size. The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16.
num_classes: int.
Number of classes to classify.
dim: int.
Last dimension of the output tensor after the linear transformation nn.Linear(..., dim).
depth: int.
Number of Transformer blocks.
heads: int.
Number of heads in the multi-head attention layers.
mlp_dim: int.
Dimension of the MLP (feed-forward) layer.
channels: int, default 3.
Number of image channels.
dropout: float between [0, 1], default 0.
Dropout rate.
emb_dropout: float between [0, 1], default 0.
Embedding dropout rate.
pool: string, either cls token pooling or mean pooling.
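To make the channels and pool arguments concrete, here is a minimal sketch using the same vit-pytorch API as above, for single-channel (grayscale) input with mean pooling; the specific hyperparameter values are only illustrative.

import torch
from vit_pytorch import ViT

v_gray = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    channels = 1,     # grayscale input instead of the default 3 channels
    pool = 'mean',    # mean pooling over tokens instead of the cls token
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 1, 256, 256)  # note the single channel
preds = v_gray(img) # (1, 10)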
Distillation
A recent paper has shown that distilling knowledge from a convolutional network into a vision transformer with a distillation token yields small, efficient vision transformers. This repository provides the means to do distillation easily.
For example, to distill from a ResNet50 (or any teacher) into a vision transformer:
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,    # temperature of distillation
    alpha = 0.5,        # trade between main loss and distillation loss
    hard = False        # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)
The DistillableViT class is identical to ViT except for how the forward pass is handled, so you should be able to load the parameters back into a ViT once distillation training is finished.
You can also use the convenient .to_vit method on a DistillableViT instance to get back a ViT instance.
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
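Building on the note above, one way to persist the distilled weights and reload them later is the usual state_dict round trip. A minimal sketch, assuming the plain ViT is rebuilt with the same hyperparameters as the distillation snippet; the file path is illustrative.

import torch
from vit_pytorch import ViT

# 'v' is now a plain ViT (converted above); save its distilled weights
torch.save(v.state_dict(), './distilled-vit.pt')

# later: rebuild a ViT with the same hyperparameters used during distillation and reload
v2 = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
v2.load_state_dict(torch.load('./distilled-vit.pt'))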
Deep ViT
This paper notes that ViT struggles to attend at greater depths (past 12 layers) and proposes mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results are in line with the Talking Heads paper from NLP.
You can use it as follows:
import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
CaiT
This paper also notes the difficulty of training vision transformers at greater depths and proposes two solutions. First, it proposes a per-channel multiplication of the output of each residual block. Second, it proposes having the patches attend to one another, and only allowing the CLS token to attend to the patches in the last few layers.
They also add Talking Heads, noting improvements.
You can use this scheme as follows:
import torch
from vit_pytorch.cait import CaiT

v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,             # depth of transformer for patch to patch attention only
    cls_depth = 2,          # depth of cross attention of CLS tokens to patch
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05    # randomly dropout 5% of the layers
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
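To give a sense of the first fix described above (the per-channel multiplication of the residual branch, often called LayerScale), here is an illustrative sketch of the idea in plain PyTorch; it shows the concept only, not the code inside vit_pytorch.cait.

import torch
from torch import nn

class ResidualWithLayerScale(nn.Module):
    # wraps any sub-block fn (attention or feed-forward) and rescales its output
    # channel-wise with a small learnable vector before the residual addition
    def __init__(self, dim, fn, init_eps = 1e-5):
        super().__init__()
        self.fn = fn
        self.scale = nn.Parameter(init_eps * torch.ones(dim))

    def forward(self, x):
        return x + self.scale * self.fn(x)

# usage sketch: rescale a feed-forward block operating on 1024-dim tokens
block = ResidualWithLayerScale(1024, nn.Sequential(nn.Linear(1024, 2048), nn.GELU(), nn.Linear(2048, 1024)))
tokens = torch.randn(1, 64, 1024)
out = block(tokens)  # (1, 64, 1024)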
Token-to-Token ViT
This paper proposes that the first couple of layers should downsample the image sequence by unfolding, leading to overlapping image data in each token. You can use this variant of ViT as follows:
import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2))  # tuples of the kernel size and stride of each consecutive layer of the initial token-to-token module
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)
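As a rough illustration of the unfolding described above (and of what the t2t_layers kernel/stride tuples control), the sketch below uses torch.nn.Unfold directly; it is not the library's token-to-token module, just a demonstration of how a 7x7 kernel with stride 4 yields overlapping tokens.

import torch
from torch import nn

# a 7x7 window slid with stride 4: neighbouring tokens share 3 pixels of overlap
unfold = nn.Unfold(kernel_size = 7, stride = 4)

img = torch.randn(1, 3, 224, 224)
tokens = unfold(img)             # (1, 147, 3025): 3 * 7 * 7 values per token, over a 55 x 55 token grid
tokens = tokens.transpose(1, 2)  # (1, 3025, 147): a sequence of overlapping image tokens for a transformer layer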
CCT
CCT proposes compact transformers by using convolutions instead of patching and by performing sequence pooling. This allows CCT to achieve high accuracy with a small number of parameters.
You can use it in one of two ways:
import torch
from vit_pytorch.cct import CCT

model = CCT(
    img_size=224,
    embedding_dim=384,
    n_conv_layers=2,
    kernel_size=7,
    stride=2,
    padding=3,
    pooling_kernel_size=3,
    pooling_stride=2,
    pooling_padding=1,
    num_layers=14,
    num_heads=6,
    mlp_radio=3.,
    num_classes=1000,
    positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
Alternatively, you can use one of several predefined models [2, 4, 6, 7, 8, 14, 16], which have the number of layers, number of attention heads, MLP ratio, and embedding dimension predefined.
import torch
from vit_pytorch.cct import cct_14

model = cct_14(
    img_size=224,
    n_conv_layers=1,
    kernel_size=7,
    stride=2,
    padding=3,
    pooling_kernel_size=3,
    pooling_stride=2,
    pooling_padding=1,
    num_classes=1000,
    positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
The official repository includes links to pretrained model checkpoints.
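Either construction above yields a standard image classifier; a short sketch of the forward pass, assuming (as in the other examples) that the model returns class logits:

import torch

img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)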
CrossViT
This paper proposes having two vision transformers process the image at different scales, cross-attending to one another every so often. They show improvements on top of a base vision transformer.
import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 64,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8,    # cross attention heads
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

pred = v(img) # (1, 1000)
PiT
This paper proposes downsampling the tokens through a pooling procedure that uses depthwise convolutions.
import torch
from vit_pytorch.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),     # list of depths, indicating the number of rounds of each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)
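To illustrate the pooling idea described above, the sketch below downsamples a 2D token map with a strided depthwise convolution; it is an illustration of the concept only, not the code inside vit_pytorch.pit (which also has to handle the class token).

import torch
from torch import nn

class DepthwisePool(nn.Module):
    # strided depthwise convolution: each input channel is filtered independently,
    # halving the spatial resolution while expanding the channel dimension
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size = 3, stride = 2,
                              padding = 1, groups = dim_in)

    def forward(self, x):       # x: (B, dim_in, H, W) token map
        return self.conv(x)     # (B, dim_out, H/2, W/2)

pool = DepthwisePool(256, 512)
x = torch.randn(1, 256, 16, 16)
y = pool(x)                     # (1, 512, 8, 8): a quarter as many tokens, twice the dimension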
LeViT
This paper proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection, (2) downsampling in stages, (3) an extra non-linearity in attention, (4) 2D relative positional biases instead of an initial absolute positional bias, and (5) BatchNorm in place of LayerNorm.
import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,             # number of stages
    dim = (256, 384, 512),  # dimensions at each stage
    depth = 4,              # transformer of depth 4 at each stage
    heads = (4, 6, 8),      # heads at each stage
    mlp_mult = 2,
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

levit(img) # (1, 1000)
CvT
This paper proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature maps in three stages, and depthwise convolutions are used to project the queries, keys, and values for attention.
import torch
from vit_pytorch.cvt import CvT

v = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,        # stage 1 - dimension
    s1_emb_kernel = 7,      # stage 1 - conv kernel
    s1_emb_stride = 4,      # stage 1 - conv stride
    s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
    s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
    s1_heads = 1,           # stage 1 - heads
    s1_depth = 1,           # stage 1 - depth
    s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
    s2_emb_dim = 192,       # stage 2 - (same as above)
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,       # stage 3 - (same as above)
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

img = torch.randn(1, 3, 224, 224)

pred = v(img) # (1, 1000)
Twins SVT
This paper proposes mixing local and global attention, together with a positional encoding generator (proposed in CPVT) and global average pooling, to achieve the same results as Swin without the extra complexity of shifted windows, CLS tokens, or positional embeddings.
import torch
from vit_pytorch.twins_svt import TwinsSVT

model = TwinsSVT(
    num_classes = 1000,       # number of output classes
    s1_emb_dim = 64,          # stage 1 - patch embedding projected dimension
    s1_patch_size = 4,        # stage 1 - patch size for patch embedding
    s1_local_patch_size = 7,  # stage 1 - patch size for local attention
    s1_global_k = 7,          # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
    s1_depth = 1,             # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
    s2_emb_dim = 128,         # stage 2 (same as above)
    s2_patch_size = 2,
    s2_local_patch_size = 7,
    s2_global_k = 7,
    s2_depth = 1,
    s3_emb_dim = 256,         # stage 3 (same as above)
    s3_patch_size = 2,
    s3_local_patch_size = 7,
    s3_global_k = 7,
    s3_depth = 5,
    s4_emb_dim = 512,         # stage 4 (same as above)
    s4_patch_size = 2,
    s4_local_patch_size = 7,
    s4_global_k = 7,
    s4_depth = 4,
    peg_kernel_size = 3,      # positional encoding generator kernel size
    dropout = 0.               # dropout
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)
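For intuition about the peg_kernel_size argument, here is an illustrative sketch of a CPVT-style positional encoding generator: a depthwise convolution over the 2D token map added back as a residual, which injects position information implicitly through zero padding. It illustrates the idea only, not the code inside vit_pytorch.twins_svt.

import torch
from torch import nn

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        # depthwise conv: groups = dim keeps it lightweight and channel-separable
        self.proj = nn.Conv2d(dim, dim, kernel_size, padding = kernel_size // 2, groups = dim)

    def forward(self, x):        # x: (B, dim, H, W) token map
        return x + self.proj(x)  # residual addition of the conditional positional encoding

peg = PEG(64)
tokens = torch.randn(1, 64, 56, 56)
out = peg(tokens)                # (1, 64, 56, 56)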
RegionViT
This paper proposes dividing the feature map into local regions, where the local tokens attend to one another. Each local region has its own regional token, which attends to all of its local tokens as well as to the other regional tokens.
You can use it as follows:
import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)
CrossFormer
This paper beats PVT and Swin by alternating local and global attention. The global attention is done across the windowing dimension to reduce complexity, much like the scheme used for axial attention.
They also have a cross-scale embedding layer, which they show to be a generic layer that can improve all vision transformers. A dynamic relative positional bias is also formulated to allow the network to generalize to images of greater resolution.
import torch
from vit_pytorch.crossformer import CrossFormer

model = CrossFormer(
    num_classes = 1000,                # number of output classes
    dim = (64, 128, 256, 512),         # dimension at each stage
    depth = (2, 2, 8, 2),              # depth of transformer at each stage
    global_window_size = (8, 4, 2, 1), # global window sizes at each stage
    local_window_size = 7,             # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)
NesT
This paper processes the image hierarchically, with attention only within local blocks, and aggregates the tokens as it moves up the hierarchy. The aggregation is done in the image plane and consists of a convolution followed by a max pool, which allows information to be passed across block boundaries.
You can use it with the following code (e.g., NesT-T):
import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (8, 4, 1),  # the number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)

pred = nest(img) # (1, 1000)
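As a sketch of the aggregation step described above (a convolution followed by a max pool in the image plane), the snippet below shows the general pattern; it is an illustration only, not the code inside vit_pytorch.nest.

import torch
from torch import nn

class Aggregate(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, padding = 1)  # lets information cross block boundaries
        self.pool = nn.MaxPool2d(3, stride = 2, padding = 1)    # halves the spatial resolution

    def forward(self, x):                 # x: (B, dim_in, H, W) blocks laid out in the image plane
        return self.pool(self.conv(x))    # (B, dim_out, H/2, W/2)

agg = Aggregate(96, 192)
x = torch.randn(1, 96, 56, 56)
y = agg(x)                                # (1, 192, 28, 28)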
MobileViT
This paper introduces MobileViT, a lightweight, general-purpose vision transformer for mobile devices. MobileViT offers a different perspective on processing information globally with transformers.
You can use it with the following code (e.g., mobilevit_xs):
import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)

pred = mbvit_xs(img) # (1, 1000)
Simple Masked Image Modeling
This paper proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection from the masked tokens to pixel space, followed by an L1 loss against the pixel values of the masked patches. The results are competitive with other, more complicated approaches.
You can use it as follows:
import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5  # they found 50% to yield the best results
)

images = torch.randn(8, 3, 256, 256)

loss = mim(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

torch.save(v.state_dict(), './trained-vit.pt')
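After pretraining, the saved weights can be loaded back into a plain ViT and fine-tuned for classification. A minimal sketch, assuming the same hyperparameters as the pretraining snippet (the file name simply matches the torch.save call above):

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)
v.load_state_dict(torch.load('./trained-vit.pt'))

img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000) -- fine-tune on labeled data from here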
Masked Autoencoder
A new paper by Kaiming He proposes a simple autoencoder scheme in which the vision transformer attends to a set of unmasked patches, while a smaller decoder attempts to reconstruct the masked pixel values.
You can use it with the following code:
import torch
from vit_pytorch import ViT, MAE

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,  # the paper recommended 75% masked patches
    decoder_dim = 512,     # paper showed good results with just 512
    decoder_depth = 6      # anywhere from 1 to 8
)

images = torch.randn(8, 3, 256, 256)

loss = mae(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')