ResNet压缩20倍,Facebook提出新型无监督模型压缩量化方法

简介: 怎样用量化方法解决模型压缩问题?Facebook 近日提出了一个基于向量的量化方法,无需标注数据即可对 ResNet 模型进行20倍压缩,还能够获得很高的准确率。
怎样用量化方法解决模型压缩问题? Facebook 近日提出了一个基于向量的量化方法,无需标注数据即可对 ResNet 模型进行20倍压缩,还能够获得很高的准确率。


概述


本文旨在解决类似 ResNet [1] 这类卷积网络的模型过大,推理速度慢的问题。相比较之前的量化方法,本文提出来一个向量量化方法,其主要思想是保存网络重建之后的输出而不是原始无压缩网络的权重。本文提出的方法无需标注数据,并且使用对 CPU 推理友好的字节对齐码本。实验证明,使用本文方法可对 ResNet 模型进行20倍压缩,并在ImageNet 上获得 76.1% 的 top-1准确率。与此同时,可将 Mask R-cnn 压缩至6MB的大小并保持不错的准确率。


论文地址:https://arxiv.org/abs/1907.05686


压缩方法的介绍和比较


神经网络压缩一直是个热门的研究方向,目前的主流方法主要有以下几种:

  • 低精度训练[2]。这类方法用低bit权重,优点是可以加速推理过程,利用位操作代替复杂的逻辑操作,但是同时也会带来一个比较大的精度下降。
  • 量化。向量量化(VQ[3])和乘积量化(PQ[4])是其中的典型。这种量化方法是将原始的高维向量分解至低维笛卡尔积,然后对分解得到的低维向量做量化,这种方法的缺点是对于深度网络会产生一个笛卡尔漂移。
  • 剪枝。根据一些特定的规则去除部分冗余的连接或者结构,剪枝方法的训练时间较长,且需要剪枝和再微调的反复迭代进行。
  • 调整结构。类似 SqueezeNet[5],NASnet,shuffleNet,MobileNet等,主要依赖的 DepthWise 和 PointWise 卷积,或者通道分类和通道打乱等等。
  • 蒸馏[6-7]。这种方法通常利用大模型或模型组的知识(概率分布)来训练小模型。


相比较以上的压缩方法,本文提出的量化方法注重于恢复网络的输出。跟PQ方法相比,本文是注重恢复压缩后的权重。同时,在训练的过程中,本文也运用了蒸馏的思想,用于提升压缩网络的精度。


算法介绍


全连接层量化



  • 背景介绍--PQ 方法介绍


我们知道,PQ 算法(Product Quantization)一开始是由法国 INRIA 实验室团队提出的用于图像压缩的一种算法,通常被用于信息压缩和检索方向。PQ 算法中,以全连接层量化为例。处理全连接层时,我们关注全连接权重,忽略偏差。首先将 W 的每一列分割成 m 个连续的子向量,然后对m*的子向量学习一个码本。
对于包含 k 个聚类中心的模板 C={c1,c2,,,ck},任意 W 的列向量 Wj 可以映射为量化的版本 q(wj)= (ci1,,,cim),其中下标 i1 表示的是 wj 的第一个子向量归属的码本的索引,以此类推其余的下标。通过量化如下的目标函数(1)来学习码本。


微信图片_20211202135411.jpg


其中,w^ 表示量化的权重。这个目标函数可以通过 k-means 来最小化。当 m=1 时,PQ等价于 VQ(Vector Quantization),当 m=Cin 时,PQ  等价于标量 k-means。最后,PQ 生成了一个 k^m 大小的隐式码本。
通过对该算法的一个简要的回顾,我们不难看出,PQ 算法的重点点在于对权重的恢复,旨在减小压缩前后的权重信息损失。那么,这个方法在使用时有何缺点呢?我们可以看看如下的图(1)。



  • 本文算法以及算法求解


在这之前,我们一直强调本文的主要目的是恢复网络之后的输出值,而不是权重值。这其中的原因可以追溯到模型对样本的拟合能力,具体如下图所示:  


微信图片_20211202135522.jpg

图1:量化保存权重和本文方法的比较


在图1中,灰色线代表真实的分类器,红色线表示量化保存权重的标准方法训练得到的分类器,而绿色线是本文提出的方法。在in-domain中,可以看出,本文的方法可以比较好的拟合真实的分类器,而标准方法为了拟合out-of-domin领域的数据,在in-domain中反而带来了错误的分类。
笔者认为,该图从本质上显示了神经网络模型在拟合样布分布的时候,有一定的容错能力,也就是说对于噪声点或者异常信息有忽所忽略,而关注于正确的样本的分布,可以避免模型的过拟合。因此,在量化的时候,算法应该同时学习到模型的拟合能力以及泛化信息,而不是仅仅学习模型的参数信息。因此,引申出本文算法的压缩目的,恢复压缩输出信息,而不是PQ的恢复压缩权重信息。
因此,本文提出直接最小化重建输出值误差,在给定的输入值x的前提下,本文旨在缩小输出和重建输出之间的误差。改写目标函数(1)为如下函数(2)。


微信图片_20211202135542.jpg


其中,y=xW表示原始网络的输出,


微信图片_20211202135551.jpg


表示重建输出。


  • EM求解
  • E步:分配向量至最近的聚类中心
  • M步:根据E步组成的集合,更新码本C。
  • 复杂度分析


此类方法的本质上和 k-means 一致,每一列 m 个维度为 d 的子向量选择 k 个聚类中心,其算法的时间复杂度为 mkd。因此一个全连接矩阵的 PQ 的时间复杂为C_inC_out*k 。
对于一个常见的网络结果,卷积层是必不可少的部分,接下来,笔者将介绍如何对卷积层进行压缩处理的。


卷积层量化


  • 4D卷积分离求解

全连接的权重是个4D的矩阵,首先将这个矩阵reshape成一个二维的矩阵(C_inKK)Cout。然后将转换之后的矩阵每一维分离成C_in个大小为(kk)大小的子向量。具体如图(2)所示:


微信图片_20211202135626.jpg

图2:4D卷积的reshape示意图


为了保证结果的一致性,同样的将输入 X 也 reshape 一下。然后运用目标函数中量化权重。


  • 算法复杂度分析


上面说到,对卷积运用算法压缩时,首先将 4D 的矩阵 reshape 成(C_inKK)Cout的矩阵,该矩阵可以看为全连接层的权重。接着,对矩阵运用 PQ压缩。结合3.1中PQ算法的时间复杂度,本文算法的时间复杂度可以表示为(C_inKK)C_out*k(可类比k-means算法)。
其中,K 为卷积核大小,k 为聚类中心的个数。


整个网络的量化


  • 自底向上的逐层量化


本文输入一个 batch 的图像,并从底层向上逐层量化网络。需要注意的是,在量化的过程中,使用的是当前层的输出值而不是非压缩网络的输出值。因为,在实验过程中,使用非压缩网络的输出值会带来一定的误差。


  • 微调 codebook


进行了逐层的量化学习之后,需要对整个网络进行微调。在这个过程中,本文使用非压缩网络作为 teacher 网络来指导压缩网络学习。蒸馏学习中,用 Kl 散度来作为蒸馏的损失,并用 SGD 来更新码本。


  • 整体微调 BN 参数


跟前一个过程相比较,这个过程将BN设置为训练模式,重复上述微调码本的过程。


实验介绍


微信图片_20211202135725.jpg

图(3):resnet-18和resnet-50的压缩结果


图(3)表示的是在resnet-18和resnet-50的压缩大小和TOP-1之间的关系图。对比其他的算法,本文的量化方法在更高的压缩率上保证模型的精度。


微信图片_20211202135757.jpg

表(1) 在给定大小的前提下的模型准确率对比


表(1)表示的是,在限定模型大小的前提下,本文的方法对比目前最优结构的 top-1 准确率,可以发现,本文的方法在半监督的情况下,准确率较高。


Detection实验


微信图片_20211202135816.jpg

表(2)Mask R-CNN 实验


表(2)展示的是在 k=256(8bits) 的情况下,模型压缩因子大概 26 的情况下,量化模型下降约 4 个 AP 。


论文总结和分析


本文作为一篇网络压缩方向的论文,从本质上提出来了独特的量化方法,其提出的保存输出结果而不是保存权重的思想,从而可以拟合in-domain数据并且忽略out-of-domain的数据,这是很值得借鉴和思考的。从这个思想引申出来的量化方法,延续PQ方法的优点。

  1. 本文首要值得借鉴的是对压缩本质的思考,脱离传统的压缩权重的思想,另辟蹊径恢复输出。
  2. 从操作上来说,PQ等方法都是经典的算法,本文方法延续其内容,因此实现上难度不大。


但是,作为一个算法类的研究论文,笔者认为以下的方面还值得继续研究和探讨。

  1. 在训练过程中,k-means的思想是其重点,那么对于算法的复杂度分析部分,是否需要更具体的讨论。
  2. 在本文中,对于样本的采样和计算过程,也是值得继续研究的一个课题。


总体来说,本文的实验结果很充分地说明该方法的有效性。但是,这种带有训练机制的量化方法,从样本的采样,算法复杂度,训练过程上来说,都是一个耗时且不太可控的过程。


作者介绍:


立早,工学硕士,研究方向为模式识别。目前从事人脸识别、检测和神经网络压缩方向的工作。希望能够一直学习,多多交流,不断进步。


参考文献[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image

recognition. CoRR, 2015.

[2] Shuchang Zhou, Zekun Ni, Xinyu Zhou, He Wen, Yuxin Wu, and Yuheng Zou. Dorefa-net:

Training low bitwidth convolutional neural networks with low bitwidth gradients. CoRR, 2016.

[3] Yunchao Gong, Liu Liu, Ming Yang, and Lubomir Bourdev. Compressing deep convolutional

networks using vector quantization. arXiv preprint arXiv:1412.6115, 2014.

[4] Herv´ e J´ egou, Matthijs Douze, and Cordelia Schmid. Product Quantization for Nearest Neigh-

bor Search. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2011.

[5]Jie Hu, Li Shen, and Gang Sun. Squeeze-and-excitation networks. In Conference on Computer

Vision and Pattern Recognition, 2018.

[6] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network.

NIPS Deep Learning Workshop, 2014.

[7] Yu Cheng, Duo Wang, Pan Zhou, and Tao Zhang. A survey of model compression and accel-

eration for deep neural networks. CoRR, 2017.

相关文章
|
计算机视觉
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(二)
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(二)
276 0
|
3月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
124 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
|
机器学习/深度学习 数据库
【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型
【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型
|
机器学习/深度学习 存储 人工智能
模型推理加速系列 | 03:Pytorch模型量化实践并以ResNet18模型量化为例(附代码)
本文主要简要介绍Pytorch模型量化相关,并以ResNet18模型为例进行量化实践。
|
计算机视觉
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(二)
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(二)
152 0
|
机器学习/深度学习 计算机视觉 索引
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(一)
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(一)
573 0
|
机器学习/深度学习 编解码 数据可视化
超越 Swin、ConvNeXt | Facebook提出Neighborhood Attention Transformer
超越 Swin、ConvNeXt | Facebook提出Neighborhood Attention Transformer
171 0
|
编解码 数据可视化 计算机视觉
全新池化方法AdaPool | 让ResNet、DenseNet、ResNeXt等在所有下游任务轻松涨点(二)
全新池化方法AdaPool | 让ResNet、DenseNet、ResNeXt等在所有下游任务轻松涨点(二)
289 0
|
机器学习/深度学习 编解码 算法
全新池化方法AdaPool | 让ResNet、DenseNet、ResNeXt等在所有下游任务轻松涨点(一)
全新池化方法AdaPool | 让ResNet、DenseNet、ResNeXt等在所有下游任务轻松涨点(一)
267 0

热门文章

最新文章