ImageNet-1K压缩20倍,Top-1精度首超60%:大规模数据集蒸馏转折点

简介: ImageNet-1K压缩20倍,Top-1精度首超60%:大规模数据集蒸馏转折点

该工作是目前唯一实现了大规模高分辨率数据集蒸馏的框架


过去几年,数据压缩或蒸馏任务引起了人们的广泛关注。通过将大规模数据集压缩成具有代表性的紧凑子集,数据压缩方法有助于实现模型的快速训练和数据的高效存储,同时保留原始数据集中的重要信息。数据压缩在研究和应用中的重要性不可低估,因为它在处理大量数据的过程中起着关键作用。通过采用先进的算法,数据压缩取得了显著的进展。然而,现有解决方案主要擅长压缩低分辨率的小数据集,这种局限性是因为在双层优化过程中执行大量未展开的迭代会导致计算开销巨大。


MBZUAI 和 CMU 团队的最新工作 SRe2L 致力于解决这一问题。该工作是目前唯一实现了大规模高分辨率数据集蒸馏的框架,可以将 Imagenet-1K 原始的 1.2M 数据样本压缩到 0.05M (压缩比 1:20),使用常用的 224x224 分辨率进行蒸馏,在 ImageNet-1K 标准验证集(val set)上取得了目前最高的 60.8% Top-1 精度远超之前所有 SOTA 方法,如 TESLA (ICML’23) 的 27.9% 的精度。


该工作目前已完全开源,包括蒸馏后的数据,蒸馏过程和训练代码。


 

论文:https://arxiv.org/abs/2306.13092

代码:https://github.com/VILA-Lab/SRe2L


数据集蒸馏 / 压缩任务的定义和难点


传统的模型蒸馏是为了得到一个更加紧凑的模型,同时保证模型性能尽可能得高。与之不同,数据集蒸馏任务关注于如何得到一个更紧凑同时更具表达能力的压缩后的数据集,数据样本相比原始数据集会少很多(节省从头训练模型的计算开销),同时模型在该压缩后的数据集上训练,在原始数据验证集上测试依然可以得到较好的精度。


数据集蒸馏任务的主要难点在于如何设计一个生成算法来高效可行地生成需要的样本,生成的样本需要包含 / 保留原始数据集中核心的信息。目前比较常用的方法包括梯度匹配、特征匹配、轨迹匹配等等,但是这些方法的一个共同缺点就是没法 scale-up 到大规模数据集上。比如,由于计算量和 GPU 显存的限制,无法蒸馏标准的 ImageNet-1K 或者更大的数据集。计算量和 GPU 显存需要过大的主要原因在于这些方法生成过程需要匹配和保存的信息过多,目前很多 GPU 显存没法容纳所有需要匹配的数据信息,导致这些方法大多数只适用于较小的数据集。


针对这些问题,新论文通过解耦数据生成和模型训练两个步骤,提出了一个三阶段数据集蒸馏算法,蒸馏生成新数据过程只依赖于在原始数据集上预训练好的模型,极大地降低了计算量和显存需求。


解决方案核心思路


之前很多数据集蒸馏方法都是围绕样本生成和模型训练的双层优化 (bi-level optimization) 来展开,或者根据模型参数轨迹匹配 (trajectory matching) 来生成压缩后的数据。这些方法最大的局限在于可扩展性不是很强,需要的显存消耗和计算量都很大,没法很好地扩展到完整的 ImageNet-1K 或者更大的数据集上。


针对这些问题,本文作者提出了解耦数据生成和模型训练的方法,让原始数据信息提取过程和生成数据过程相互独立,这样既避开了更多的内存需求,同时也避免了如果同时处理原始数据和生成数据导致原始数据中的噪声对生成数据造成偏差 (bias)。


具体来说,本文提出了一种新的数据集压缩框架,称为挤压、恢复和重新标记 (SRe2L),如下图所示,该框架在训练过程中解耦模型和合成数据双层优化为两个独立的操作,从而可以处理不同规模的数据集、不同模型架构和高图像分辨率,以实现有效的数据集压缩目的。

 

本文提出的方法展示了在不同数据集规模的灵活性,并在多个方面表现出多种优势:1)合成图像的任意分辨率,2)高分辨率下的低训练成本和内存消耗,以及 3)扩展到任意评估网络结构的能力。本文在 Tiny-ImageNet 和 ImageNet-1K 数据集上进行了大量实验,并展示出非常优异的性能。


三阶段数据集蒸馏框架


本文提出一个三阶段数据集蒸馏的框架:


第一步是将整个数据集的核心信息压缩进一个模型之中,通过模型参数来存储原始数据集中的信息,类似于我们通常进行的模型训练;第二步是将这些高度抽象化的信息从训练好的模型参数中恢复出来,本文讨论了多种不同损失和正则函数对于恢复后图像的质量以及对数据集蒸馏任务的影响;第三步也是提升最大的一步:对生成的数据进行类别标签重新校准。此处作者采用了 FKD 的方式,生成每个 crop 对应的 soft label,并作为数据集新的标签存储起来。


三阶段过程如下图所示:



性能及计算能效比


在 50 IPC 下 (每个类 50 张图),本文提出的方法在 Tiny-ImageNet 和 ImageNet-1K 上实现了目前最高的 42.5% 和 60.8% 的 Top-1 准确率,分别比之前最好方法高出 14.5% 和 32.9%。


此外,本文提出的方法在速度上也比 MTT 快大约 52 倍 (ConvNet-4) 和 16 倍 (ResNet-18),并且在数据合成过程中内存需求更少,相比 MTT 方法分别减少了 11.6 倍 (ConvNet-4) 和 6.4 倍 (ResNet-18),具体比较如下表所示:

 


实验结果


实验设置


该工作主要聚焦于大规模数据集蒸馏,因此选用了 ImageNet-Tiny 和 ImageNet-1K 两个相对较大的数据集进行实验。对于骨干网络,本文采用 ResNet-{18, 50, 101} 、ViT-Tiny 和自己构建的 BN-ViT-Tiny 作为目标模型结构。对于测试阶段,跟之前工作相同,文本通过从头开始训练模型来评估压缩后数据集的质量,并报告 ImageNet-Tiny 和 ImageNet-1K 原始验证集上的测试准确性。


在 full ImageNet-1K 数据集上的结果



可以看到,在相同 IPC 情况下,本文实验结果远超之前方法 TESLA。同时,对于该方法蒸馏得到的数据集,当模型结构越大,训练得到的精度越高,体现了很好的一致性和扩展能力。


下图是性能对比的可视化结果,可以看到:对于之前方法 TESLA 蒸馏得到的数据集,当模型越大,性能反而越低,这对于大规模数据集蒸馏是一个不好的情况。与之相反,本文提出的方法,模型越大,精度越高,更符合常理和实际应用需求。



压缩后的数据可视化

 

从上图可以看到,相比于 MTT 生成的数据(第一和第三行),本文生成的数据(第二和第四行)不管是质量、清晰度还是语义信息,都明显更高。


蒸馏过程图像生成动画


此外,包含 50、200 个 IPC(具有 4K 恢复预算)的压缩数据集文件可从以下链接获取:https://zeyuanyin.github.io/projects/SRe2L/


将该方法扩展到持续学习任务上的结果


 

上图展示了 5 步和 10 步的增量学习策略,将 200 个类别(Tiny-ImageNet)分为 5 个或 10 个学习步骤,每步分别容纳 40 个和 20 个类别。可以看到本文的结果明显优于基线(baseline)性能。


更多细节欢迎阅读其论文原文和代码。

相关文章
|
缓存 前端开发 JavaScript
【面试题】大厂面试官:你做过什么有亮点的项目吗?
【面试题】大厂面试官:你做过什么有亮点的项目吗?
277 0
|
人工智能 自然语言处理 运维
AIGC系列文章汇总
AIGC系列文章汇总(2024年3月8日更新)
3452 4
AIGC系列文章汇总
|
数据可视化 测试技术 PyTorch
智谱ChatGLM3魔搭最佳实践教程来了!
ChatGLM3-6B 是 ChatGLM 系列最新一代的开源模型,在保留了前两代模型对话流畅、部署门槛低等众多优秀特性的基础上
|
数据可视化 uml
UML图讲解(关联关系,单向关联,双向关联,自关联,组合关系,依赖关系,继承关系,实现关系)
UML图讲解,关联关系,单向关联,双向关联,自关联,组合关系,依赖关系,继承关系,实现关系。
6427 0
UML图讲解(关联关系,单向关联,双向关联,自关联,组合关系,依赖关系,继承关系,实现关系)
|
PHP 开发者
深入探索PHP的命名空间与自动加载机制
在现代PHP开发中,随着项目的不断扩大和复杂度的增加,如何有效地组织代码、避免命名冲突以及提高性能成为了开发者面临的重要问题。本文将详细探讨PHP中的命名空间和自动加载机制,解释它们是如何帮助解决这些问题的,并提供一些最佳实践建议。
275 20
|
人工智能 自然语言处理 前端开发
OpenAI 12天发布会全解析 | AI大咖说
OpenAI近日宣布将在12个工作日内每天进行一场直播,展示一系列新产品和样品。首日推出GPT-o1正式版,性能大幅提升;次日展示Reinforcement Fine-Tuning技术,提高模型决策质量;第三天推出Sora,实现高质量视频生成;第四天加强Canvas,提升多模态创作效率;第五天发布ChatGPT扩展功能,增强灵活性;第六天推出ChatGPT Vision,实现多模态互动;第七天推出ChatGPT Projects,优化项目管理。这些新技术正改变我们的生活和工作方式。
1698 9
|
关系型数据库 MySQL 数据库
【已解决】[图文步骤] message from server: “Host ‘172.17.0.1‘ is not allowed to connect to this MySQL server“
【已解决】[图文步骤] message from server: “Host ‘172.17.0.1‘ is not allowed to connect to this MySQL server“
579 0
MybatisPlus3---常用注解,驼峰转下滑线作为表明 cteateTime 数据表中的 cteate_time,@TableField,与数据库字段冲突要使用转义字符“`order`“,is
MybatisPlus3---常用注解,驼峰转下滑线作为表明 cteateTime 数据表中的 cteate_time,@TableField,与数据库字段冲突要使用转义字符“`order`“,is
|
C语言
C语言字符串、宏定义及主函数介绍
C语言字符串、宏定义及主函数介绍
545 0