简单有效 | Transformer通过剪枝降低FLOPs以走向部署(文末获取论文)

简介: 简单有效 | Transformer通过剪枝降低FLOPs以走向部署(文末获取论文)

1 简介


Visual Transformer在各种计算机视觉应用中取得了具有竞争力的性能。然而,它们的存储、运行时的内存以及计算需求阻碍了在移动设备上的部署。在这里,本文提出了一种Visual Transformer剪枝方法,该方法可以识别每个层中通道的影响,然后执行相应的修剪。通过促使Transformer通道的稀疏性,来使得重要的通道自动得到体现。同时为了获得较高的剪枝率,可以丢弃大量系数较小的通道,而不会造成显著的损害。

Visual transformer修剪的流程如下:

  1. Training with sparsity regularization
  2. Pruning channels
  3. Finetuning

在ImageNet数据集上验证了该算法的有效性。


2 Approach


图1 Visual transformer Pruning

2.1 复杂度分析

其实大家都知道典型的ViT结构包括Multi-Head Self-Attention(MHSA)、Multi-Layer Perceptron(MLP)、 layer normalization、激活函数以及Shortcut。

MHSA是Transformer组件,在token之间进行信息的交互。具体来说,将输入X通过全连接层转换为query 、key 和value ,其中n为patches的数量,d为embedding维数。这里利用self-attention对patch之间的关系进行建模:

image.png

最后,利用线性变换生成MHSA的输出:

image.png

为了简化,忽略了layer normalization和激活函数。MHSA的参数量为,FLOPs为。对于双层MLP,可以写成:

image.png

Hidden Layer dimension通常设置为,其参数量为, FLOPs为。与MHSA和MLP相比,layer normalization、激活函数和Shortcut的参数或FLOPs可以忽略。所以一个Transformer block约有的参数量和的FLOPs,其中MHSA和MLP占绝大多数计算量。

2.2 ViT剪枝

其实通过前面对于复杂度的分析可以看出来,绝大多数的计算量都被消耗再MHSA和MLP上了,所以为了实现Transformer架构的精简,作者着重于减少MHSA和MLP的FLOPs。

本文提出通过学习每个维度的重要性得分来减少特征的维度。对于特征,其中n表示待剪枝的通道数量,d表示每个通道的维度,而目标是保留重要的特征,去除无用的特征。假设最优的重要度评分为,即重要特征的评分为1,无用特征的评分为0。利用重要度分数可以得到剪枝后的特征:

image.png

然而,由于其是离散值导致很难通过反向传播算法优化神经网络中的。因此,作者提出使用松弛为real value 。得到的soft pruned特征为:

image.png

然后,relaxed importance scores 可以和transformer网络的端到端一起学习。

为了加强importance scores的稀疏性,对系数应用L1正则化:,并通过添加训练目标来优化它,其中是稀疏超参数。经过稀疏惩罚训练后,得到一些重要值接近于零的transformer。对transformer中的所有正则化系数值进行排序,并根据预先定义的剪枝率获得阈值。在阈值下,通过将阈值以下的值设为0,较高的值设为1得到离散的:

image.png

在根据importance scores 进行修剪后,被修剪的总transformer将被微调以减少精度下降。以上修剪过程记为:

image.png

如图1所示,我们对所有MHSA和MLP块应用剪枝操作。它们的修剪过程可以表述为:

image.png

所提出的visual transformer pruning(VTP)方法为slim visual transformer提供了一种简单而有效的方法。


3 Experiments


3.1 ImageNet-100

image.png

如表1所示从结果来看,剪枝率的大小与参数量和FLOPs的比例相匹配。例如,当修剪40%的通道的模型训练0.0001稀疏率,参数saving是45.3%,FLOPs saving是43.0%。可以看到在精度保持不变的情况下,参数和FLOPs下降了。此外,稀疏比对剪枝方法的有效性影响不大。

image.png

在表2中比较了Baseline模型和2种VTP模型,即20% pruned和40% pruned模型。精度会随着较大的下降而略有下降。当删除20%的通道时,22.0%的FLOPs被保存,准确率下降了0.96%。当删除40%的通道时,节省了45.3%的FLOPs,准确率也下降了1.92%。

3.2 ImageNet-1K

结果如表3所示。可以看出,与原始DeiT-B相比,在对40%的通道进行修剪后,VTP的准确率仅降低了1.1%。可以看出VTP的有效性可以推广到大规模数据集。


4 参考


[1].Visual Transformer Pruning


5 推荐阅读


又改ResNet | 重新思考ResNet:采用高阶方案的改进堆叠策略(附论文下载)

VariFocalNet | IoU-aware同V-Focal Loss全面提升密集目标检测(附YOLOV5测试代码)

最强Vision Trabsformer | 87.7%准确率!CvT:将卷积引入视觉Transformer(文末附论文下载)

全新FPN | 通道增强特征金字塔网络(CE-FPN)提升大中小目标检测的鲁棒性(文末附论文)

经典Transformer | CoaT为Transformer提供Light多尺度的上下文建模能力(附论文下载)

相关文章
|
7月前
|
机器学习/深度学习 编解码 人工智能
ICLR 2024:泛化递归Transformer,降低超分辨率复杂度
【2月更文挑战第16天】ICLR 2024:泛化递归Transformer,降低超分辨率复杂度
254 1
ICLR 2024:泛化递归Transformer,降低超分辨率复杂度
|
3月前
英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强
【9月更文挑战第10天】《通过剪枝和知识蒸馏实现紧凑型语言模型》由英伟达研究人员撰写,介绍了一种创新方法,通过剪枝和知识蒸馏技术将大型语言模型参数数量减半,同时保持甚至提升性能。该方法首先利用剪枝技术去除冗余参数,再通过知识蒸馏从更大模型转移知识以优化性能。实验结果显示,该方法能显著减少模型参数并提升性能,但可能需大量计算资源且效果因模型和任务而异。
84 8
|
5月前
|
机器学习/深度学习 计算机视觉 异构计算
【YOLOv8改进 - Backbone主干】FasterNet:基于PConv(部分卷积)的神经网络,提升精度与速度,降低参数量。
【YOLOv8改进 - Backbone主干】FasterNet:基于PConv(部分卷积)的神经网络,提升精度与速度,降低参数量。
|
5月前
|
计算机视觉 异构计算
【YOLOv8改进-SPPF】 AIFI : 基于注意力的尺度内特征交互,保持高准确度的同时减少计算成本
YOLOv8专栏介绍了该系列目标检测框架的最新改进与实战应用。文章提出RT-DETR,首个实时端到端检测器,解决了速度与精度问题。通过高效混合编码器和不确定性最小化查询选择,RT-DETR在COCO数据集上实现高AP并保持高帧率,优于其他YOLO版本。论文和代码已开源。核心代码展示了AIFI Transformer层,用于位置嵌入。更多详情见[YOLOv8专栏](https://blog.csdn.net/shangyanaf/category_12303415.html)。
|
7月前
|
自然语言处理 算法 网络架构
DeepMind升级Transformer,前向通过FLOPs最多可降一半
【4月更文挑战第25天】DeepMind提出的新Transformer变体MoD,通过动态分配计算资源降低前向计算复杂度,旨在优化效率并保持性能。MoD模型采用动态路由机制,集中计算资源处理关键token,减少不必要的计算,从而提高效率和速度。实验显示,MoD模型能减半FLOPs,降低成本。然而,它面临动态计算分配的复杂性、路由算法的准确性及自回归采样中的非因果性挑战。[论文链接](https://arxiv.org/pdf/2404.02258.pdf)
64 5
|
7月前
|
缓存 并行计算 算法
【译】Based:简单线性注意力语言模型平衡召回-吞吐量权衡
【译】Based:简单线性注意力语言模型平衡召回-吞吐量权衡
61 3
|
7月前
|
机器学习/深度学习 网络架构
YOLOv8改进 | 2023主干篇 | 利用RT-DETR特征提取网络PPHGNetV2改进YOLOv8(超级轻量化精度更高)
YOLOv8改进 | 2023主干篇 | 利用RT-DETR特征提取网络PPHGNetV2改进YOLOv8(超级轻量化精度更高)
498 1
|
机器学习/深度学习 数据可视化
DHVT:在小数据集上降低VIT与卷积神经网络之间差距,解决从零开始训练的问题
VIT在归纳偏置方面存在空间相关性和信道表示的多样性两大缺陷。所以论文提出了动态混合视觉变压器(DHVT)来增强这两种感应偏差。
260 0
|
存储 编解码 算法
全新剪枝框架 | YOLOv5模型缩减4倍,推理速度提升2倍(二)
全新剪枝框架 | YOLOv5模型缩减4倍,推理速度提升2倍(二)
493 0
|
机器学习/深度学习 传感器 编解码
全新剪枝框架 | YOLOv5模型缩减4倍,推理速度提升2倍(一)
全新剪枝框架 | YOLOv5模型缩减4倍,推理速度提升2倍(一)
474 0
下一篇
DataWorks