CVPR 2021首次!王言治教授和色拉布团队把GAN压缩22倍,性能比原始模型还高

简介: 王言治教授团队与美国色拉布公司(Snap Inc.)首次提出了一种GAN剪枝的方法,除了使压缩时间减少了四个数量级以外,还在远低于原始计算量的条件下,获得来比原有模型更好的性能,并且实现了更高的生成图片质量。论文已被CVPR 2021收录。

神经网络结构搜索有助于得到在计算机视觉任务中效果更好的深度神经网络,同时可以减小模型尺寸, 提高运行效率,实现移动端高速处理。


近年来,深度神经网络在图像、语音、文本等领域的进展使得其广泛应用在不同功能的系统中,包括图像分类、目标识别、语义分割、语音处理等。


不同于判别式模型只需要得到比较简单的判断结果(如分类结果),生成式模型需要生成更加复杂的图像结构。


相比于前者,后者通常需要更大的计算量和更大规模的模型,这使得将生成式模型压缩以提高运行效率面临巨大挑战。


为此,美国东北大学王言治教授研究团队与美国色拉布公司(Snap Inc.)的创意视觉研究组共同提出了压缩与教学技术。论文已经被CVPR 2021会议收录。

20.jpg论文地址:https://arxiv.org/abs/2103.03467

项目地址:https://dejqk.github.io/GAN_CAT/


通过将inception模块引入生成模型并进行神经网络结构搜索,从而使搜索空间扩展至包含多种不同核尺寸的卷积模块。


并且利用知识蒸馏用搜索过程中训练的大模型指导搜索出的小模型的训练过程,在远低于原始超大规模生成模型计算量的条件下实现优于原始超大规模生成模型的生成图片质量。21.jpg与原有的巨型生成模型相比,论文的方法得到的模型在压缩的同时可以生成更高质量的图片(FID越低图片质量越好),并且实现了SOTA的性能-效率取舍。

网络模型

实现高效率网络模型主要包括网络压缩和模型结构搜索两种方式。

相比于前者,后者通常获得的网络结构更多样,效果也更优,并且现代压缩算法通常也包含搜索步骤。

然而,直接将传统的用于压缩或搜索的方法用在生成模型中,通常会导致模型性能具有较大损失,特别是生成的图像画质通常较差,容易产生额外的噪点和花斑。

此外,生成模型因其计算量庞大,通常训练时间较长,直接使用网络搜索一般不容易得到最优解,使得网络结构优化面临更多的挑战。

而且,对于高复杂度的大型网络(如GauGAN),传统方法通常导致性能损失更为明显。

因此,研究出高速有效的网络结构搜索方法和训练方法,对于提高生成模型的性能-效率取舍具有重要意义。

为了保证压缩后的生成模型产生出高质量的图像,需要解决的几个主要问题是:

网络搜索空间需要足够广,使得搜索过程的自由度足够高;

网络搜索的过程需要足够快,使得搜索过程中遍历的备选模型尽可能多,迭代过程也尽可能快(如超参调优等);

搜索出的模型的在训练时需要充分利用已有信息,尽量保证模型得到充分训练。

为了扩展网络搜索空间,传统方法通过在不同类型操作之间进行选择来实现网络结构搜索。

与之相比,近年来提出的AtomNAS算法通过引入Inception模块,将多种不同类型的神经层同时使用,在提升模型性能的同时,将搜索过程和训练过程合并,显著降低了模型搜索所需的额外计算开销。

受此启发,作者将多种不同核尺寸的卷积模块同时使用,并同时包含普通卷积模块与depthwise卷积模块,实现网络搜索空间的扩充。

所用的模块包含1x1、3x3、5x5三种不同核尺寸的卷积模块,并且同时使用了普通卷积模块与depthwise卷积模块。22.jpg用在生成模型中的Inception残差模块


该模块使用不同核尺寸的卷积模块,并且同时使用普通卷积模块与depthwise卷积模块,在搜索过程中有助于扩充搜索空间。


作者将这一模块用在大型网络GauGAN中,用来替代其中主干中的卷积层和第一个归一化层中使用的SPADE模块中的卷积网络。23.jpg将Inception模块用在GauGAN的SPADE模块中


此外,主干中的第二个卷积层和分支中的卷积层可以使用普通的归一化层,而不需要使用计算量很大的SPADE模块。


网络空间的扩展不仅使得搜索过程简单高效,而且可以提高网络的灵活度,使得模型在相同计算量下能实现更高性能。


网络搜索



在网络空间扩展的同时,提高网络搜索效率成为网络搜索的主要问题。参照传统的网络搜索办法,作者选择使用归一化层的权重模大小作为搜索依据。


论文提出的搜索过程直接参考目标计算量,使用半分法来确定网络压缩所需的权重阈值。


24.jpg使用半分法根据目标计算量确定压缩阈值对网络进行压缩


首先根据训练好的网络中归一化层的权重大小预设搜索上界和下界,由此算出一个权重阈值对网络进行预压缩,根据预压缩所得网络的大小与目标大小的相对关系,调整上下界,直至所得网络大小满足要求。


相比于文献中提出的生成模型压缩方法,论文提出的方法可以使得压缩过程所需时间减少至少四个数量级。

25.png不同压缩方法在不同数据集和不同生成模型上所需压缩时间比较


搜索出模型结构后,通常原有模型的权重无法直接使用,需要重新训练。由于模型较小,训练过程中可能会出现较难优化甚至不收敛的问题。


为使得训练结果较好,文献中提出先额外训练一个较大的模型作为导师,再使用此模型训练搜索出的小模型。然而这种方法增加了额外的训练开销,使得搜索和训练过程的更加冗长。


为此,作者提出使用用于搜索的原有大模型作为导师模型,相当于将用于搜索的模型再次利用,进行知识蒸馏。


这样,大模型不仅用来作为小模型的导师指导训练,也因为其本身的结构特征用作网络结构搜索。这种方法可以最大限度地利用大模型,减少训练开销和时间。


知识蒸馏


知识蒸馏技术通常包含直接蒸馏间接蒸馏。


前者一般只利用网络的最终输出进行比较实现蒸馏的目的,后者则利用网络内部卷积层的中间结果进行比较,作为指导原则。者则选取后者对搜索的结构进行训练。


然而,由于作为导师的大模型的中间层特征与经过搜索压缩后得到的作为学生的小模型的中间层特征的通道数存在差异,无法通过直接比较完成蒸馏的目的。


文献中引入一个额外的可训练的线性层,将学生模型的特征映射到导师模型特征的空间中。


这样做不仅会导致引入额外的训练层,增加训练复杂度,而且蒸馏办法较为间接,可能效果并非最优。


为此,作者采用一种更直接的办法,通过比较导师模型特征和学生模型特征,通过最大化二者相似度,实现蒸馏目的。


Hinton等人于2019年通过详细分析,比较了不同的相似度判断标准,并且提出一种称为中心化核对齐(CKA)的指标。


作者采用类似的核对齐(KA)指标,并且发现中心化对最终结果不具有决定性的影响。


如下图所示,作者通过计算导师模型特征与学生模型特征的核对齐指标并将其最大化作为损失函数进行训练,实现知识蒸馏的目的

26.jpg与传统的通过引入额外的训练线性层进行知识蒸馏相比(左图),论文提出一种直接比较特征相似度的方法进行蒸馏(右图)。


结果分析


作者在多个数据集和多种类型的网络上验证了论文提出的方法,并且与原有的大模型和文献中已有的生成模型进行了比较。


论文提出的方法在将生成模型计算量压缩数十倍的基础上,仍然可以获得比原有模型更好的性能(高mIoU或低FID)并且与文献中的方法相比,实现了SOTA的性能-效率取舍。


27.jpg不同压缩方法在不同数据集和不同生成模型上性能比较

 

为了更直观地展示结果,作者在不同数据集和模型上将压缩模型生成的图片和原有模型生成的图片进行对比。


可以看到,论文提出的模型在远低于原有模型计算量的条件下,可以生成更高质量的图片。

28.jpg                                Horse2Zebra数据集上压缩CycleGAN模型29.jpg                                      Map2Aerial数据集上压缩Pix2pix模型30.jpg31.jpg                                 Cityscapes数据集上压缩GauGAN模型

作者介绍


第一作者金庆,美国东北大学ECE系PhD一年级学生。


主要研究领域为Deep Learning algorithm,研究内容已经在发表在CVPR,AAAI等机器学习和计算机视觉会议中。





相关文章
|
JavaScript
js 使用fetch来上传文件 formdata()
js 使用fetch来上传文件 formdata()
【全是精华】Token的获取和使用-FastApi版
【全是精华】Token的获取和使用-FastApi版
1490 0
|
编解码 自然语言处理 算法
开源版图生视频I2VGen-XL:单张图片生成高质量视频
VGen是由阿里巴巴通义实验室开发的开源视频生成模型和代码系列,具备非常先进和完善的视频生成系列能力
|
8月前
|
监控 网络安全
网页显示HTTP错误503怎么办?HTTP错误503解决方法
HTTP 503错误表示服务器暂时无法处理请求,通常是由于服务器过载或维护导致。常见解决方法包括:1. 等待一段时间再刷新页面;2. 检查服务器负载;3. 确认服务器是否在维护;4. 检查配置错误;5. 联系服务提供商。通过这些步骤,用户和管理员可以有效排查并解决该问题。
9796 3
|
11月前
|
监控 关系型数据库 MySQL
Ubuntu24.04安装Librenms
此指南介绍了在Linux系统上安装和配置LibreNMS网络监控系统的步骤。主要内容包括:安装所需软件包、创建用户、克隆LibreNMS仓库、设置文件权限、安装PHP依赖、配置时区、设置MariaDB数据库、调整PHP-FPM与Nginx配置、配置SNMP及防火墙、启用命令补全、设置Cron任务和日志配置,最后通过网页完成安装。整个过程确保LibreNMS能稳定运行并提供有效的网络监控功能。
|
11月前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
1434 30
|
负载均衡 监控 算法
每个程序员都应该知道的 6 种负载均衡算法
每个程序员都应该知道的 6 种负载均衡算法
1455 2
|
网络安全 虚拟化 Windows
windows 11安装openSSH server 遇到的"kex_exchange_identification: read: Connection reset"问题
windows 11安装openSSH server 遇到的"kex_exchange_identification: read: Connection reset"问题
1981 60
|
SQL 运维 数据库
MSSQL性能调优实战:索引策略、查询优化与并发控制的精细操作
在Microsoft SQL Server(MSSQL)的日常运维与优化中,实现高效、稳定的数据库性能是每位数据库管理员和开发者的核心任务
1014 1
|
机器学习/深度学习 算法
机器学习中最常见的四种分类模型
机器学习中最常见的四种分类模型
1278 10