模型量化显著降低了模型推理的复杂性,并已被广泛用于现实应用的部署。然而,大多数现有的量化方法主要是在卷积神经网络(CNNs)上开发的,当应用于全量化的Vision Transformer时,会出现严重的退化。
在这项工作中证明了这些困难中的许多是由于LayerNorm输入中的严重通道间变化而出现的,并且提出了Power-of-Two Factor(PTF),这是一种减少全量化Vision Transformer性能退化和推理复杂性的系统方法。此外,观察到注意力图中的极端非均匀分布,提出了Log Int Softmax(LIS)来维持这一点,并通过使用4位量化和BitShift算子来简化推理。
在各种基于Transformer的架构和基准测试上进行的综合实验表明,全量化Vision Transformer(FQ-ViT)在注意力图上使用更低的位宽的同时,也优于以前的工作。例如,在ImageNet上使用ViT-L达到84.89%的Top-1准确率,在COCO上使用Cascade Mask R-CNN(SwinS)达到50.8 mAP。据所知是第1个在全量化的Vision Transformer上实现无损精度下降(~1%)的算法。
Github地址:
1、简介
基于Transformer的架构在各种计算机视觉(CV)任务中取得了具有竞争力的性能,包括图像分类、目标检测、语义分割等。与CNN的同类架构相比,Transformer通常具有更多的参数和更高的计算成本。例如,ViT-L具有307M参数和190.7G FLOP,在经过大规模预训练的ImageNet中达到87.76%的准确率。然而,当部署到资源受限的硬件设备时,基于Transformer的架构的大量参数和计算开销带来了挑战。
为了便于部署,已经提出了几种技术,包括架构设计的量化、剪枝、蒸馏和自适应。在本文中重点关注量化技术,并注意到剪枝、蒸馏和架构自适应与本文的工作正交,并且可以组合。
大多数现有的量化方法都是在神经网络上设计和测试的,并且缺乏对转化子特异性构建的适当处理。先前的工作发现,在量化Vision Transformer的LayerNorm和Softmax时,精度显著下降。在这种情况下,模型没有完全量化,导致需要在硬件中保留浮点单元,这将带来巨大的消耗,并显著降低推理速度。
因此,重新审视了Vision Transformer的这2个专属模块,并发现了退化的原因:
- 首先,作者发现LayerNorm输入的通道间变化严重,有些通道范围甚至超过中值的40倍。传统方法无法处理如此大的激活波动,这将导致很大的量化误差。
- 其次,作者发现注意力图的值具有极端的不均匀分布,大多数值聚集在0~0.01之间,少数高注意力值接近1。
基于以上分析,作者提出了Power-of-Two Factor(PTF)来量化LayerNorm的输入。通过这种方式,量化误差大大降低,并且由于Bit-Shift算子,整体计算效率与分层量化的计算效率相同。
此外,还提出了Log Int Softmax(LIS),它为小值提供了更高的量化分辨率,并为Softmax提供了更有效的整数推理。结合这些方法,本文首次实现了全量化Vision Transformer的训练后量化。
如图1所示,本文的方法显著提高了全量化Vision Transformer的性能,并获得了与全精度对应算法相当的精度。
本文的贡献有4方面:
- 重新审视了完全量化的Vision Transformer,并将精度下降归因于LayerNorm输入的严重通道间变化。同时,观察到注意力图的极端不均匀分布,导致量化误差。
- 提出了Power-of-Two Factor(PTF),这是一种简单而有效的后训练方法,可以在只有一个分层量化尺度的情况下对LayerNorm输入实现精确量化。
- 提出了Log Int Softmax(LIS),这是一种可以对注意力图执行4-bit量化的新方法。使用LIS,可以将注意力映射存储在一个激进的低位上,并用Bit-Shift运算符代替乘法。在Softmax模块上实现了仅整数推理,显著降低了推理消耗。
- 使用各种基于Transformer的架构对图像分类和目标检测进行了广泛的实验。结果表明,全量化Vision Transformer具有8位权重/激活和4位注意力映射,可以实现与浮点版本相当的性能。
2、相关工作
2.1、Vision Transformer
最近,基于Transformer的体系结构在CV任务中显示出巨大的威力。基于ViT的新兴工作证明了分类、检测和分割等所有视觉任务的有效性。新提出的Swin Transformer在几乎传统的CV任务上甚至超过了最先进的神经网络,呈现出强大的Transformer表达和泛化能力。
然而,这些高性能的Vision Transformer归因于大量的参数和高计算开销,限制了它们的采用。因此,设计更小、更快的Vision Transformer成为一种新趋势。LeViT通过下采样、Patch描述符和注意力MLP块的重新设计,在更快的推理方面取得了进展。DynamicViT提出了一个动态Token稀疏化框架,以逐步动态地修剪冗余Token,实现竞争复杂性和准确性的权衡。Evo-ViT提出了一种快速更新机制,该机制可以保证信息流和空间结构,从而降低训练和推理的复杂性。虽然上述工作侧重于高效的模型设计,但本文在量化的思路上提高了压缩和加速。
2.2、模型量化
目前的量化方法可以分为两类:量化感知训练(QAT)和训练后量化(PTQ)。
QAT依赖于训练来实现低比特(例如2比特)量化和有希望的性能,而它通常需要高水平的专家知识和巨大的GPU资源来进行训练或微调。为了降低上述量化成本,无训练的PTQ受到了越来越广泛的关注,并出现了许多优秀的作品。OMSE建议通过最小化量化误差来确定激活的值范围。AdaRound提出了一种新的舍入机制来适应数据和任务损失。
除了上述针对神经网络的工作外,Liu等人还提出了一种具有相似性感知和秩感知策略的Vision Transformer训练后量化方法。然而,这项工作没有量化Softmax和LayerNorm模块,导致量化不完整。在本文的FQ-ViT中,目标是在PTQ范式下实现精确、完全量化的Vision Transformer。
3、本文方法
在本节中将详细介绍所提出的方法。首先,在第3.1节中,提出了网络量化的初步结果。然后在第3.2和3.3节中,分析了全量化Vision Transformer退化的原因,并提出了两种新的量化方法,Power-of-Two Factor(PTF)和Log-Int-Softmax(LIS)。
3.1、准备工作
在本节中,将解释网络量化的符号。假设量化位宽为b,量化器可以公式化为将浮点数 映射到最近的量化bin的函数:
有各种各样的量化器,其中通常使用和。
1、Uniform Quantization
大多数硬件平台都很支持量化。其量化器可以定义为:
其中和是由的下界和上界确定的量化参数,它们通常是最小值和最大值:
2、Log2 Quantization
量化化将量化过程从线性变化转换为指数变化。其量化器可定义为:
在本文中,为了实现一个全量化的Vision Transformer,本文量化了所有模块的量化,包括Conv、Linear、MatMul、LayerNorm、Softmax等。特别是对Conv、Linear和MatMul模块采用均匀的MinMax量化,对LayerNorm和Softmax采用以下方法。
3.2、LayerNorm量化的Power-of-Two Factor
在推理过程中,LayerNorm计算每个前向传播步骤中的统计数据,σ,并对输入进行归一化。然后,仿射参数γ,β将归一化输入重新缩放到另一个学习的分布。上述过程可以写成:
与神经网络中常用的BatchNorm不同,LayerNorm由于其动态计算特性,无法折叠到前一层,因此必须单独量化它。然而,在对其应用训练后量化时观察到显著的性能下降。查看LayerNorm层的输入,发现存在严重的通道间变化。
图2显示了最后一个LayerNorm层中的通道激活范围。此外,还展示了ResNets的案例进行比较。考虑到ResNets中没有LayerNorm,选择相同位置的激活(第4阶段的输出)来展示。
可以观察到,与ResNets相比,Vision Transformer中的通道范围波动更大。例如,ResNet152的最大范围/中值范围仅为21.6/4.2,而在Swin-B中则上升到622.5/15.5。基于这种极端的通道间变化,将相同的量化参数应用于所有通道的逐层量化将导致不可容忍的量化误差。一种可能的解决方案是使用分组量化或通道量化,其将不同的量化参数分配给不同的组或通道。然而,这些仍然会导致浮点域中的均值和方差的计算,从而导致较高的硬件开销。
在本文中提出了一种简单而有效的LayerNorm量化方法,即Power-of-Two Factor(PTF)。PTF的核心思想是为不同的通道配备不同的因子,而不是不同的量化参数。给定量化位宽b,输入激活,逐层量化参数,,以及PTF α,则量化激活可以公式化为:
注意,表示和α的通道索引。超参数可以满足不同的缩放要求。为了涵盖所有模型中不同的通道间变化,将K=3设置为默认值。
在这一点上,每个通道都有自己的Power-of-Two Factor α和逐层参数,。在推理过程中,可以提取逐层参数和,因此、σ的计算可以在整数域而不是浮点域中进行,这降低了能量和面积成本。同时,由于二次幂的性质,可以通过Bit-Shift算子将PTF α与分层量化有效地结合起来,避免了分组或通道量化的浮点计算。整个过程可以分为2个阶段:
- 阶段1:用Power-of-Two Factor Shift量化激活α:
- 阶段2:根据Shift激活计算平均值和方差:
1、LayerNorm的量化推理
LayerNorm已广泛应用于神经网络中,以加速训练过程中的收敛速度,可以表述为:
其中,γ,β为学习参数,,σ为需要根据LayerNorm的输入进行计算的统计量。
本文提出的Power-of-Two Factor是用于Layer Norm输入的量化,其量化后的值可以写为:
式中,为量化参数,α为Power-of-Two Factor。
根据LayerNorm的定义,应该计算输入的统计量。如本文所述,整个过程可以分为2个阶段。在第一阶段,用PTF α移动量化的激活:
然后,在第二阶段,需要计算基于移位激活的统计量。首先,测量和的平均值如下:
其中为中的通道数。其次,利用和计算σ:
因此,基于仅整数计算得到输入的统计量。
在统计数据的计算之后,需要对LayerNorm进行综合推理。将输入量化如下:
其中,为输入的比例,而为输出的scale和zero-point。
为了简化方程,将每一项融合如下:
为了得到仅限整数的推理,将A近似为:
其中b是目标位宽。最后,对LayerNorm的量化推理可以表述为:
2、LayerNorm的量化实现
class QIntLayerNorm(nn.LayerNorm): def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): super(QIntLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) assert isinstance(normalized_shape, int) self.mode = 'ln' def get_MN(self, x): bit = 8 N = torch.clamp(bit - 1 - torch.floor(torch.log2(x)), 0, 31) M = torch.clamp(torch.floor(x * torch.pow(2, N)), 0, 2 ** bit - 1) return M, N def forward(self, x, in_quantizer=None, out_quantizer=None, in_scale_expand=1): if self.mode == 'ln': x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) elif self.mode == 'int': in_scale = in_quantizer.scale if in_scale_expand != 1: in_scale = in_scale.unsqueeze(-1).expand( -1, in_scale_expand).T.reshape(-1) out_scale = out_quantizer.scale assert in_scale is not None and out_scale is not None channel_nums = x.shape[-1] in_scale = in_scale.reshape(1, 1, -1) out_scale = out_scale.reshape(1, 1, -1) x_q = (x / in_scale).round() in_scale1 = in_scale.min() in_scale_mask = (in_scale / in_scale1).round() x_q = x_q * in_scale_mask mean_x_q = x_q.mean(dim=-1) * in_scale1 std_x_q = (in_scale1 / channel_nums) * torch.sqrt( channel_nums * (x_q**2).sum(dim=-1) - x_q.sum(dim=-1)**2) A = (in_scale1 / std_x_q).unsqueeze(-1) * \ self.weight.reshape(1, 1, -1) / out_scale A_sign = A.sign() M, N = self.get_MN(A.abs()) B = ((self.bias.reshape(1, 1, -1) - (mean_x_q / std_x_q).unsqueeze(-1) * self.weight.reshape(1, 1, -1)) / out_scale * torch.pow(2, N)).round() x_q = ((A_sign * M * x_q + B) / torch.pow(2, N)).round() x = x_q * out_scale else: raise NotImplementedError return x