由于数据隐私限制,多个中心之间的数据共享受到限制,这就影响了联邦学习架构下多中心合作开发高性能深度学习模型的效果。持续学习(Continual Learning)作为点对点联合学习的一种方法,可以通过共享中间模型而不是训练数据来绕过数据隐私的限制,从而促进多中心协作开发深度学习算法。近期不断有研究人员探索联邦持续学习方法(Federated Continual Learning,FCL),即,研究持续学习在联邦学习架构下多中心协作的可行性。
1、背景回顾
1.1 持续学习 (Continual Learning)
首先,我们来回顾一下什么是持续学习。当前,一般认为持续学习 (Continual Learning) 和增量学习(Incremental Learning)、终身学习 (Lifelong Learning) 是等价表述,它们都是在连续的数据流中训练模型,随着时间的推移,更多的数据逐渐可用,同时旧数据可能由于存储限制或隐私保护等原因而逐渐不可用,并且学习任务的类型和数量没有预定义 (例如分类任务中的类别数)。
当一个模型在新的数据集或任务上被重新训练时,深度学习会遇到灾难性遗忘的问题,即深度学习模型会灾难性地遗忘已经学到的旧知识。持续学习技术的目的是让机器学习模型通过新的数据进行更新,同时保留以前学到的知识。持续学习具有两大优点:1) 不需要保存之前任务上学习过的训练数据,从而在节约内存的同时解决了由于物理设备 (例如机器内存 ) 或学习策略 (例如隐私保护 ) 的限制所导致的数据不能被长期存储问题;2) 模型能够保存之前任务所学到的知识,并且能够极大程度地将之前任务学习到的知识运用到未来任务的学习中,从而提高学习效率。
目前,持续学习的方法仍在不断发展中,还没有严格的数学定义。韩等在文章 [1] 中给出了一幅持续学习的示意图,如图 1 所示,“在连续学习过程中,智能体逐个对每个连续的非独立均匀分布流数据示例进行学习,并且该智能体对每个示例只进行一次访问。这种学 习方式与动物学习过程更为接近。如果我们忽略各个任务的先后次序问题,单独训练每个任务,这将导致灾难性遗忘,这也是连续学习一直以来所面临的最大问题。因此,连续学习的本质是通过各种手段高效地转化和利用已经学过的知识来完成新任务的学习,并且能够极大程度地降低遗忘带来的问题。[1]”。
图 1. 连续学习示意图 [1]
到目前为止,已经有许多持续学习算法,主要分为三种类型:回放方法(Memory Reply)、动态结构模型(Dynamic Structural Models)和正则化方法(Regularization Model)。1)回放方法从以前的数据集中选择有代表性的样本,以保留所学的知识。该方法的研究重点是:「要保留旧任务的哪部分数据,以及如何利用旧数据与新数据一起训练模型」,这对于克服数据存储的限制是可行的,但对于多中心合作来说是不可行的,因为由于数据隐私问题,其他中心的样本是不可用的 [6-8]。2)动态结构模型为多任务场景设计动态网络架构或动态参数,网络的各个部分(如某些权重或某些神经元连接)负责对应的各个任务 [9][10]。3)正则化方法使用相同的传统神经网络,但在损失函数中加入了新的正则化项,以保留学习知识的重要参数。该方法的主要思想是「通过给新任务的损失函数施加约束的方法来保护旧知识不被新知识覆盖」[11][12]。
1.2 联邦持续学习 (Federated Continual Learning)
联邦学习的主要思想是去中心化,将模型下放到各个参与联合训练的客户端本地,基于本地客户端的数据进行模型训练,而不需要将用户数据上传到中央服务器,从而保护各个客户端中的隐私。然而,大多数现有方法假设整个联邦学习框架的数据类别随着时间的推移是固定不变的。实际情况中,已经参与联邦学习的客户端经常可能收集到新类别的数据,但考虑到每个客户端本地设备的存储空间非常有限,很难对其收集到的所有类别都保存下足够数量的数据,这种情况下,现实世界中的联邦学习模型很可能对于旧类数据的性能遇到严重的灾难性遗忘。此外,联邦学习框架往往还会有持续不断的新用户的参与,这些新用户往往有着大量的新数据类别,这样会进一步加剧全局模型的灾难性遗忘。
近年来,有研究人员陆续提出将联邦学习和持续学习的思想结合起来构建联邦持续学习框架。然而,直接简单的将联邦学习和持续学习相结合会带来新的问题。首先是,联邦持续学习仍然面临灾难性遗忘,此外还会带来来自其他客户端潜在的干扰。因此我们需要有选择地利用来自其他客户端的知识,以最小化客户端间的干扰,最大化进行客户端间的知识转移。第二个问题是联邦学习之间进行通信交换知识时,可能会造成通信成本过大,“通信代价” 成为了一个不可忽视的问题。
我们通过四篇近期发表的文章概览联邦持续学习的最新研究进展。
- 第一篇文章提出了一个新的联邦持续学习框架 —— 联邦加权客户端间的传输(FedWeIT)。FedWeIT 将各个客户端的本地模型参数分解为稠密基参数(a dense base parameter)和稀疏的任务自适应参数( sparse task-adaptive parameters),以便进行更高效地通信 [2]。
- 第二篇文章提出了一种全新的全局 - 本地遗忘补偿 (GLFC) 模型,即同时从全局和本地两个角度出发,尽可能地减弱灾难性遗忘,使得联邦学习最终可训练一个全局增量模型 [3]。
- 第三篇文章提出了一种联邦互相关和持续学习方法。对于异构性问题,该方法利用未标记的公共数据进行通信,并构造互相关矩阵来学习域偏移(domain shift)下的可概括性的表示。同时,对于灾难性遗忘,在本地更新中利用跨域和本域信息进行知识蒸馏,有效地提供域间和域内知识而不泄露参与者的隐私 [4]。
- 第四篇文章提出了一个联邦学习架构,称为联邦多语者 TTS 系统 Fed-Speech。该架构使用渐进式修剪掩码来分离参数,以保留说话人的语调。此外,应用选择性掩码来有效地重用任务中的知识。最后,引入 private speaker embedding 以保持用户的隐私 [5]。
2、Federated Continual Learning with Weighted Inter-client Transfer
持续学习和联邦学习在现实世界的深度神经网络中都很重要。然而,对于每个客户端从私有的循环数据流中学习一连串任务的情况,却很少有人进行研究。这种联邦学习的问题给持续学习带来了新的挑战,比如,如何有效利用其他客户端的知识同时防止不相关知识的干扰?为了解决这些问题,本文提出了一个新颖的联邦持续学习框架,即联邦加权客户端间传输(Federated Weighted Inter-client Transfer,FedWeIT),该框架将网络工作权重分解为全局的联邦参数( global federated parameters)和稀疏的特定任务参数( sparse task-specific parameters),每个客户端可以通过对其特定任务参数进行加权组合,从其他客户端那里获得选择性知识。具体是通过中央服务端获得其他客户端的 task-specific parameters,再对这些参数进行加权聚合得到 selective knowledge,从而最大化相似任务之间共识知识的传递。FedWeIT 最大限度地减少了不兼容任务之间的干扰,并且在学习过程中允许客户端之间的积极知识转移。
作者将 Fed-WeIT 与现有的联邦学习和连续学习方法在不同程度的客户端之间的任务相似性进行了验证,本文模型明显较优,通信成本大大降低。代码已公布在 https://github.com/wyjeong/FedWeIT。
2.1 方法介绍
受人类从间接经验中学习过程的启发,作者引入了一种新的联邦学习环境下的持续学习 --- 联邦持续学习(Federated Continual Learning,FCL)。FCL 假设多个客户端在私有数据流中的任务序列上进行训练,同时与中央服务器交互所学到的参数。在标准的持续学习中(在一台机器上),模型从一连串的任务 {T (1),T (2),...,T (T)} 中反复学习,其中 T (t) 是第 t 个任务的标记数据集。假设现实情况如下,任务序列是一个具有未知到达顺序的任务流,这样,模型只能在任务 t 的训练期访问 T (t),之后就无法访问了。给定 T (t) 和到目前为止学到的模型,任务 t 的学习目标如下:
然后,将传统的持续学习扩展到有多个客户端和一个中央服务器的联邦学习环境。假设有 C 个客户端,每个客户端 c_c∈{c_1, . . . , c_C } 在一个私人可访问的任务序列 {T^(1)_c , T^(2)_c , ..., T^(t)_c }⊆ T 上训练一个模型。需要注意的是,在步骤 t 收到的跨客户端任务之间没有关系。
现在的目标是通过与中央服务器沟通模型参数,有效地在他们自己的私有任务流上训练 C 类持续学习模型,中央服务器汇总每个客户端发送的参数,并将它们重新分配给客户端。在联邦持续学习框架中,将参数汇总为一个全局参数 θ_G,允许客户端间的知识转移,因为客户端 c_i 在第 q 轮学到的任务可能与客户端 c_j 在第 r 轮学到的任务相似或相关。然而,作者分析使用单一的综合参数 θ_G 可能是实现这一目标的次优选择,因为来自不相关任务的知识可能没有用处,甚至可能会通过将其参数改变到不正确的方向来阻碍每个客户端的训练,作者将此描述为客户端间的干扰。
另一个实际上也很重要的问题是通信效率。从客户端到中央服务器,以及从中央服务器到客户端的参数传输都会产生很大的通信成本,这对于持续学习环境来说是有问题的,因为客户端可能会在无限的任务流上进行训练。如前所述,造成这些问题的主要原因是,在多个客户端学到的所有任务的知识被存储在一组参数 θ_G 中。然而,为了使知识转移有效,每个客户端应该有选择地只利用在其他客户端训练的相关任务的知识。这种选择性转移也是最小化客户端间干扰的关键,因为它不考虑可能干扰学习的不相关任务的知识。
作者通过分解参数来解决这个问题,这些参数分为三种不同的类型,具有不同的作用:全局参数(θ_G),捕获所有客户端的全局和通用知识;本地基础参数(B),捕获每个客户端的通用知识;任务适应性参数(A),用于每个客户端的每个具体任务。将一组在持续学习客户端 c_c 的任务 t 的模型参数 θ^(t)_c 定义如下:
其中,B^(t)_c 是第 c 个客户端的基本参数集,在客户端的所有任务中共享。m^(t)_c 是稀疏向量掩码的集合,它允许对任务 t 的 B^(t)_c 进行适应性转换,A^(t)_c 是客户端 c_c 的稀疏任务适应性参数集合。L 是神经网络中的层数,I_l、O_l 分别是第 l 层权重的输入和输出维度。
上式中的第一项允许有选择地利用全局知识。作者希望每个客户端的基础参数 B^(t)_c 能够捕获所有客户端的所有任务中的通用知识。如图 2(a),在每一轮 t 中用前一次迭代的全局参数 θ^(t-1)_G 来初始化,汇总从客户端发送的参数。这使得 B^(t)_c 也能从关于所有任务的全局知识中受益。然而,由于 θ^(t-1)_G 也包含与当前任务无关的知识,我们不是原封不动地使用它,而是学习稀疏掩码 m^(t)_c,只为给定的任务选择相关参数。这种稀疏的参数选择有助于最大限度地减少客户端之间的干扰,从而实现高效的通信。上式中的第二项是任务适应性参数 A^(t)_c。对参数进行加法分解处理后,能够学会捕捉第一项没有捕捉到的关于任务的知识,因此将捕捉到关于任务 T^(t)_c 的具体知识。上式中的最后一项描述了加权的客户端间知识转移。我们拥有一组从中央服务器传输的参数,其中包含了所有客户端的所有任务适应性参数。为了有选择地利用这些来自其他客户端的间接经验,进一步在这些参数上分配注意力 α^(t)_c,并采取加权组合的方式。通过学习这种注意力,每个客户端可以只选择有助于学习给定任务的相关任务适应性参数。尽管作者将 A^(j)_i 设计成高度稀疏的,在实践中使用大约 2-3% 的全参数内存,但发送所有的任务知识仍然是不可取的。因此,作者选择从知识库中传输所有时间步骤的随机抽样的任务适应性参数,根据经验,作者发现这种处理方式在实践中取得了良好的效果。
图 2. FedWeIT 更新。(a) 客户端发送稀疏化的联邦参数 B_c ⊙m^(t)_c 。之后,中央服务器将聚合的参数重新分配给客户端。(b) 知识库存储了客户端先前的任务适应性参数,每个客户端有选择地利用这些参数,并有一个注意力掩码
训练。我们通过优化以下目标函数来学习可分解参数 θ^(t)_c:
其中,L 是损失函数,Ω(・) 是所有任务自适应参数和掩码变量的稀疏性诱导正则化项,以使它们变得稀疏。第二个正则化项用于追溯更新过去的任务适应性参数,通过反映基础参数的变化,帮助任务适应性参数保持目标任务的原始解决方案。∆B^(t)_c 是指当前时间段和前一个时间段的基础参数之差。∆A^(i)_c 是任务 i 在当前和前一时间段的任务适应性参数之间的差异。这种正则化处理对于防止灾难性的遗忘至关重要。λ1 和 λ2 是控制两个正则化作用的超参数。
客户端。在每个轮次 r,每个客户端 c_c 用中央服务器发送的全局参数的非零分量部分更新其基础参数;也就是说,B_c (n) = θ_G (n),其中,n 是全局参数的非零元素。它为新任务获得一个稀疏的基础参数 ^Bb^(t)_c 和任务适应性参数 A^(t)_c,将这两个参数发送到中央服务器,与 FCL 基线方法相比,成本更低。FCL 基线方法需要 | C|×R×|θ| 的资源用于客户端到中央服务器的通信,而 FedWeIT 需要 | C|×(R×|Bb|+|A|),其中 R 是每个任务的通信轮数,|・| 是参数数量。
中央服务器。中央服务器首先对所有客户端发送的基础参数进行汇总,取其加权平均值 θ_G。然后,将 θ_G 广播给所有客户端。t-1 的任务适应性参数在训练任务 t 期间在每个客户端广播一次。FCL 基线需要 | C|×R×|θ| 的中央服务器 - 客户端通信成本,而 FedWeIT 需要 | C|×(R×|θG|+(|C|-1)×|A|),其中 θ_G、A 是高度稀疏的。算法 1 中描述了 FedWeIT 的算法。
2.2 实验情况介绍
作者验证了 FedWeIT 在不同的任务序列配置下与基线方法(Overlapped-CIFAR-100 和 NonIID-50)的对比。1) Overlapped-CIFAR-100:将 100 个 CIFAR-100 数据集类分组为 20 个 NonIID 超类任务。然后,从 20 个任务中随机抽取 10 个任务并拆分实例,为每个任务重叠的客户端创建一个任务序列。2) NonIID-50:使用以下八个基准数据集:MNIST、CIFAR-10/-100、SVHN、Fashion MNIST,Not MNIST 和 TrafficSigns。将 8 个数据集中的类划分为 50 个 NonIID 任务,每个任务由 5 个类组成,这些类与用于其他任务的类不相交。
实验中用到的对比模型如下:1)STL:单任务学习每个到达的任务。2) EWC:每个客户端进行个人持续学习。3) Stable-SGD:每个客户端持续学习 Stable-SGD。4) APD:每个客户端使用 APD 进行个人持续学习。5) FedProx:使用 FedProx 算法的 FCL。6) Scaffold :使用 Scaffold 算法的 FCL。7) FedCurv:使用 FedCurv 算法的 FCL。8) FedProx-[model]:使用带有 [model] 的 FedProx 算法进行训练的 FCL。9) FedWeIT:FedWeIT 算法。
表 1 给出了在两个数据集上完成(联邦)连续学习后,每项任务的最终平均性能。我们观察到,基于 FedProx 的 FCL 方法与没有联邦学习的相同方法相比,会降低连续学习(CL)方法的性能。这是因为在不相关的任务中学习的所有客户端参数的汇总导致了对每个任务学习的严重干扰,这导致了灾难性的遗忘和次优的任务适应性。Scaffold 在 FCL 上的表现很差,因为它对本地梯度的正则化处理对 FCL 是有害的,因为所有的客户端都是从不同的任务序列中学习的。虽然 FedCurv 减少了任务间的参数差异,但它不能最大限度地减少任务间的干扰,这导致它的表现不如单机 CL 方法。另一方面,FedWeIT 在两个数据集上的表现都明显优于单机 CL 基线和 FCL 基线。即使有更多的客户端(C = 100),FedWeIT 也一直优于所有基线(图 3)。这种改进主要归功于 FedWeIT 有选择地利用其他客户端的知识来迅速适应目标任务的能力,并获得更好的最终性能。
表 1. 5 个客户端在 FCL 期间对两个数据集的平均每任务表现(分数 = 1.0)。在完成所有学习阶段的 3 次单独试验后,作者测量了任务准确性和模型大小。作者还测量了训练每个任务的 C2S/S2C 通信成本
图 3. 训练最后两个(第 9 和第 10 个)任务时的平均任务适应性,有 5 个和 100 个客户端
对新任务的快速适应是客户端间知识转移的另一个明显优势。为了进一步证明本文方法在更大的网络中的实用性,作者在 ResNet-18 的 NonIID 数据集上进行了实验(表 2),FedWeIT 在使用较少参数的情况下仍然明显优于最强基线(FedProx-APD)。
表 2. 使用 ResNet-18 在 NonIID-50 数据集上的 FCL 结果
此外,作者研究了在持续学习过程中过去任务的表现如何变化,以了解每种方法的灾难性遗忘的严重程度。图 4 给出了 FedWeIT 和 FCL 基线在第 3、第 6 和第 8 个任务上的表现。我们观察到,FCL 基线比带有 EWC 的本地持续学习遭受了更严重的灾难性遗忘,这是因为客户端间的干扰,来自其他客户端的不相关任务的知识覆盖了过去的任务知识。与此相反,本文模型没有显示出灾难性遗忘的迹象。这主要是由于选择性地利用了通过全局 / 任务自适应参数从其他客户端那里学到的先验知识,这使得它能够有效地缓解客户端间的干扰。FedProx-APD 也不存在灾难性遗忘的问题,但由于知识转移的无效性,它们的性能较差。
图 4. 灾难性遗忘。在 NonIID-50 的联邦持续学习过程中,在第 3、6 和 8 个任务中关于当前任务适应性的性能比较