2.2 实验分析
为了验证 De-Pois 的有效性,作者使用四个不同领域的数据集进行实验,包括手写数字识别(MNIST)、图像分类(CIFAR-10)、非线性二值分类(Fourclass dataset)和房屋销售价格预测(House pricing dataset)。作者引入一些主流的 attack-specific 的防御措施与 De-Pois 进行性能比对,所采用的评价指标包括:准确度、召回率和 F1 分数。为了评估 De-Pois 合成数据生成的性能,作者采用以下指标:Inception Score(IS)、Fréchet Inception Distance(FID)、Wasserstein Distance(WD)和平均欧氏距离(Average Euclidean Distance,AED)。
对于合成数据生成,作者使用固定结构的 cGAN 来生成数据。在生成器网络 Gc 中,将具有 100 维正态分布和类别标签的噪声先验 z_c 或具有 128 个小批量大小的可信数据集的回归值 y 的组合确定为 Gc 最底层的输入。Gc 使用 3 个 Leaky ReLu 激活的全连接层和一个带 sigmoid 激活的全连接层。鉴别器网络 Dc 由 3 个全连接层和一个 sigmoid 单元层构成。在 Dc 的中间层应用 Dropout,Dropout 值设置为 0.4。在 Gc 和 Dc 中,Leaky ReLu 中的泄漏斜率均设置为 0.2。对于验证器,作者在分类任务中采用特定的分类器模块(即图像的 CNN),并使用交叉熵函数计算其损失。对于回归,验证器调用特定的回归模块(即 LASSO),并使用 MSE 计算损失。对于模仿模型构造,作者通过将类别标签或回归值 y(如 cGAN 中的 y)作为附加输入到 Gw 和 Dw 中来改进 WGAN-GP。将具有 100 维正态分布和类别标签的噪声先验 z_c 或最小批量大小为 32 的扩充数据集的回归值 y 的组合确定为 Gw 最底层的输入,在 Gw 和 Dw 中都使用全连接层和激活层。特别的,作者为图像添加了卷积层以获得更好的模仿模型。作者使用学习率等于 0.00005 的 RMSprop 优化器,当 Gw 迭代一次时,Dw 迭代 5 次。
2.2.1 合成数据生成的有效性
首先,作者通过比较两个模型的准确性、召回率和 F1 分数来评估合成数据生成的有效性:在 S_aug 上训练的增强模型和在与 S_aug 大小相同的清洁数据集上训练的基线模型。在每个数据集上使用 cGAN 和验证器体系结构对这两个模型进行训练,并使用训练好的鉴别器来评估性能。为了统一分类和回归数据集的指标,作者使用清洁样本和污染样本对这两个模型进行了测试。对于回归数据集,我们可以获得通过模型处理的结果,然后计算分类数据集的准确度、召回率和 F1 分数。
表 6 给出了 CIFAR-10 数据集上的 IS 和 FID 结果、通过合成数据生成模型获得的 Fourclass 和 House pricing 数据集上的 WD 和 AED 结果,以及真实数据和 cGAN 模型上基线的比较结果。从表 6 给出的实验结果可以看出,本文模型可以达到真实数据的性能,并且优于 cGAN 模型,作者分析,这表明使用验证器有助于生成足够的有效数据。
表 6. 合成数据生成的评价结果
2.2.2 清洁数据量的影响
为了评估 St 大小的影响,作者将 Rp 设置为固定值(例如,Rp=20%)进行实验。结果如表 7 所示,在不同攻击下,De Pois 的准确度、召回率和 F1 分数平均高于 0.9,并且随着 St 大小的增加波动小于 10%,从而验证了 De Pois 是攻击不可知的,并且对于分类和回归任务都是有效的。
表 7. 固定污染率环境下准确度和 F1 评分的比较
2.2.3 De Pois 中不同组件的影响
作者评估了 De Pois 中不同组件的影响。作者将 De Pois 与其他三种不同的组合进行比较,即 cGAN+GAN、cGAN+cWGAN GP 和 cGAN_验证器 + GAN。在实验中,Rp=20%,| St |=20%| So |。结果如图 8 所示,本文提出的验证器在准确性、召回率和 F1 方面分别提高了至少 0.03、0.04 和 0.03。此外,本文提出的模仿模型构造部分在准确性、召回率和 F1 方面分别提高了至少 0.05、0.04 和 0.03。
图 8. De Pois 中不同组件的影响
3、Siren:基于主动报警的拜占庭鲁棒联邦学习 [13]
目前很多工作都是针对特定攻击进行防御,特别是在联邦学习(Federated Learning,FL)领域。但是在联邦学习的实际应用场景中,防御者并不能预先确定攻击者采用的是什么攻击方式,这就需要攻击不可知的防御措施了。与前几篇文章不同,这篇文章聚焦的是联邦学习的防御措施。
在联邦学习应用场景中,FL 通过松散连接的设备网络执行分布式机器学习,这些设备都是自主自发地参与训练过程,因此很难确定恶意客户端的具体数量。此外,客户端设备上的本地数据通常是非独立同分布的(Non-independent and identically distributed,Non-IID)。这些数据在参与的客户端设备之间的倾斜加剧了本地模型之间的差异,进一步混淆了恶意客户端和良性客户端之间的界限。因此,联邦学习很容易受到拜占庭式攻击:攻击者通过隐藏在联邦学习客户端之间的恶意客户端恶意更新模型,从而破坏联邦模型。
本文提出了一个拜占庭鲁棒的联邦学习系统 Siren,通过分析模型准确度和梯度协调客户端与 FL 服务器,以抵御 IID 和 Non-IID 数据上的各种攻击。作者设计了一个主动分布式报警系统,使客户端能够与 FL 服务器协作进行攻击检测。Siren 客户端保留一小部分本地数据集以测试全局模型的准确度并触发警报,FL 服务器联合分析客户端警报、模型权重更新和准确度以检测攻击。在分布式报警系统的基础上,作者提出了一个决策过程来检测恶意客户端并有效净化模型的聚合结果。
3.1 Siren 介绍
3.1.1 整体结构分析
在深度学习架构中,客户端和 FL 服务器之间的唯一通讯参数是权重更新,因此,恶意客户端只能通过修改这些权重更新来攻击全局模型。在这种情况下,大多数当前的只进行权重分析和只在服务器上部署防御措施的拜占庭鲁棒聚合规则使得 FL 系统极易受到攻击。
为了解决上述问题,本文提出了一种攻击不可知的防御方法 Siren。图 9 显示了 Siren 的整体结构。在客户端上有两个过程,即训练过程和报警过程。Siren 在每个客户端上保留本地数据集的一小部分作为本地测试数据。相比之下,在标准 FL 系统和使用其他聚合规则的系统中,客户端只执行一个训练过程。Siren 中的训练过程与标准 FL 中的训练过程相同,负责使用本地数据进行本地模型训练,而报警过程负责测试全局权重。在每一轮通信中,每个客户端上的报警过程都使用本地权重和本地测试数据集检查全局权重。如果客户端将全局权重视为污染权重,则会向 FL 服务器报警,FL 服务器会根据每个客户端的报警状态启动检测过程以排除恶意权重更新。
图 9. FL 架构。灰色块属于默认的 FL 组件,带点边框的红色块是 Siren 的组件
3.1.2 客户端侧分析
图 10 给出了 Siren 的客户端工作流程,即客户端执行报警流程以验证全局模型𝒈_𝑡是否污染,并在每一轮通信中向 FL 服务器上传报警状态 (𝐴_t)^(𝑖) 和模型权重更新Δ(𝒈_𝑡)^(𝑖)。Siren 要求每个客户端保留一个本地测试数据集和一份在上一轮生成的本地模型权重的副本。报警过程比较本地模型和全局模型在本地测试数据集上的准确度以证明全局模型是否可信,同时保证客户端可以在下一轮本地训练过程中使用该全局模型。
图 10、 Siren 客户端和 FL 服务器之间的交互说明
为了简单起见,我们用客户端𝑖来代表一个一般的参与客户端,它可能是恶意的,也可能是良性的。客户端算法的详细描述如下。
- 步骤 1:当第 (𝑡+1) 轮通信开始时,客户端𝑖收到 FL 服务器在上一轮(即第𝑡轮)通信中汇总的全局模型权重𝒈_𝑡。
- 步骤 2:与默认的 FedAvg 算法直接用全局模型𝒈_𝑡开始局部训练不同,每个 Siren 客户端首先启动报警程序,评估全局模型𝒈_𝑡和上一轮通信中训练的局部模型 (𝒈_t)^(𝑖)。全局模型𝒈_𝑡的准确度为𝜔_𝑡,局部模型(𝒈_t)^(𝑖) 的准确度为(𝜔_t)^(𝑖)。
- 步骤 3:为了论证全局模型𝒈_𝑡是否污染,报警过程进一步比较准确度𝜔_𝑡和 (𝜔_t)^(𝑖)。如果全局模型𝒈_𝑡比局部模型(𝒈_t)^(𝑖) 更准确,即𝜔_𝑡 - (𝜔_t)^(𝑖) ≥ 𝐶𝑐 ,其中𝐶𝑐 是预先定义的正阈值,则客户端𝑖在第 (𝑡 + 1) 轮通信训练中用𝒈_𝑡初始化本地模型𝒈^(𝑖)。客户端𝑖将报警状态 (A_t)^(𝑖) 设置为 0。相反,如果𝜔_𝑡 - (𝜔_t)^(𝑖) < 𝐶𝑐 ,由于全局模型的性能异常,则客户端𝑖用 (𝒈_t)^(𝑖) 而不是𝒈_𝑡初始化局部模型𝒈^(𝑖)。客户端𝑖将报警状态 (A_t)^(𝑖) 设置为 1。
- 步骤 4:客户端𝑖通过安全隧道(例如基于 Diffie-Hellman 算法的 IPsec 隧道)向 FL 服务器发送报警状态 (A_t)^(𝑖),这样可以防止报警状态在网络传输中被篡改。客户端𝑖通过在其本地数据上训练模型(𝒈_t)^(𝑖) 以获得新的模型 (𝒈_(t+1))^(𝑖),其中报警状态(A_t)^(𝑖) 决定了𝒈^(𝑖)在步骤 3 中是 (𝒈_t)^(𝑖) 或𝒈_𝑡。然后,客户端𝑖计算并发送本地权重更新Δ(𝒈_(t+1))^(𝑖) = (𝒈_(t+1))^(𝑖)- 𝒈_𝑡到 FL 服务器,并储存本地模型 (𝒈_(t+1))^(𝑖) 用于下一轮计算。
下面的 Algorithm 1 给出了上述客户端报警和训练过程的伪代码。这种客户端报警机制保证了污染的全局模型总是会触发良性的客户端报警。还应该注意的是,即使在模型没有污染的情况下,恶意的客户端也可以故意伪造警报来欺骗 FL 服务器。不过,Siren 能够识别来自 FL 服务器端的恶意客户端的这种报警。
3.1.3 服务器侧分析
由于联邦学习的固有脆弱性,FL 服务器既不相信本地模型的更新,也不相信任何参与客户端发出的警报。在汇总本地模型更新并像 FedAvg 那样更新全局模型之前,Siren 的 FL 服务器首先启动一个检测过程,分析警报状态并评估本地模型权重,以识别潜在的攻击。
在一个通信回合𝑡中,FL 服务器执行两个阶段的检测。1)检查前一轮聚合中生成的全局模型是否污染。2)测试在当前回合中收集的客户端模型更新是否污染。下面的步骤说明了 FL 服务器在每一轮通信中的两阶段检测过程。
- 步骤 1:在第 t 轮通信中,FL 服务器通过安全隧道从所有参与的客户端检索报警状态(A_t)^(𝑖),并收集客户端模型权重更新Δ(𝒈_(𝑡+1))^(𝑖),其中𝑖∈𝐾。
- 步骤 2:FL 服务器按照图 11 所示的决策过程分析所有客户端的报警。如果没有客户端报警,FL 服务器直接汇总来自客户端的模型权重更新并更新全局模型。然而,如果有任何报警,FL 服务器会进一步评估来自发出报警的客户端的模型更新{Δ(𝒈_(𝑡+1))^(𝑖) |𝑖 ∈𝑆𝑎}。其中𝑆𝑎⊆𝐾,𝑆𝑎是报警客户端的集合。
图 11. FL 服务器的决策过程。FL 服务器做出的决定以红色和斜体突出显示
- 步骤 3:FL 服务器使用 (𝒈_(𝑡+1))^(𝑖)= Δ(𝒈_(𝑡+1))^(𝑖) + 𝒈_𝑡恢复客户端模型权重,并使用根测试数据集(root test dataset)评估客户端模型(𝒈_(𝑡+1))^(𝑖) 是否污染以及客户端𝑖是否是恶意的,其中𝑖∈𝑆𝑎。如果在𝑆𝑎中没有发现恶意客户,FL 服务器将把模型权重评估扩展到所有参与的客户端。
- 步骤 4:FL 服务器在汇总模型权重更新时过滤掉被识别为污染的客户端模型更新,以更新第(𝑡-1)轮通信中的全局模型𝒈_(𝑡+1),而不是𝒈_𝑡,因为后者被识别为污染。因此,全局模型更新为:
其中,𝑆𝑏是被识别为良性的客户端集合。
- 步骤 5:FL 服务器丢弃𝒈_𝑡,并将模型权重𝒈_(𝑡+1)复制到𝒈_𝑡中。在全局模型𝒈_𝑡被推送给所有客户端后,开始启动第 (𝑡) 轮通信。作者还在 FL 服务器上制作了黑名单,以排除已经被认定为恶意的客户端参与训练。
下述 Algorithm 2 展示了 FL 服务器的检测和聚合过程的伪代码。
3.1.4 决策过程和安全性分析
作者进一步对 FL 服务器的决策过程进行推理分析。如果 FL 服务器在第 t 轮收到 0 个警报,存在如图 18 所示的两种情况:情况 1 ,𝒈_𝑡没有污染,{Δ(𝒈_(𝑡+1))^(𝑖) |𝑖 ∈𝐾}都是良性更新。情况 2,𝒈_(𝑡-1)没有污染,但在 {Δ(𝒈_(𝑡+1))^(𝑖) |𝑖∈𝐾} 中存在污染的模型更新。如果𝒈_𝑡污染,只要有一个良性的客户端存在客户端报警机制就会保证激活报警。此外,情况 2 只发生在恶意客户端在第一轮被攻击的通信中对全局模型𝒈_1 发送污染信息时。良性客户端将通过比较全局模型和本地模型在下一轮通信中的准确度来检测这种污染的更新。因此,FL 服务器选择在没有报警时直接汇总模型更新。
当出现激活报警时,FL 服务器首先测试发出激活报警的客户端(即𝑖∈𝑆𝑎)中模型的准确性,并在其中寻找准确度最大的客户端 max{(𝜔_t)^(𝑖) |𝑖∈𝑆𝑎}。作者使用用户定义的阈值𝐶𝑠来衡量最大准确度和每个报警客户端的准确度之间的差异。报警客户端要么与情况 3 有相似的准确度,即∀𝑖∈𝑆𝑎,max{(𝜔_t)^(𝑖) |𝑖∈𝑆𝑎} -(𝜔_t)^(𝑖) < 𝐶𝑠 ,或者具有发散的准确度,即∃𝑖∈ 𝑆𝑎,max{(𝜔_t)^(𝑖) |𝑖∈ 𝑆𝑎}-(𝜔_t)^(𝑖) ≥ 𝐶𝑠作为情况 4(如图 11)。
对于情况 3,如果没有攻击—无论是全局模型𝒈_(𝑡-1)还是客户端更新 {Δ(𝒈_t)^(𝑖) |𝑖 ∈ 𝐾} 都没有污染,那么激活的报警一定是恶意客户端故意产生的假信号。如果存在攻击,而报警客户端的模型更新具有类似的准确度,则应该测试沉默客户端的模型更新,以进一步验证报警客户端的模型更新是全部污染还是全部良性。因此,对于情况 3,应该始终测试所有沉默客户端的模型更新{Δ(𝒈_t)^(𝑖)|𝑖∈𝑆𝑠 },其中𝑆𝑠是沉默客户端集合。如果沉默客户端的最高准确度接近或甚至优于报警客户端的最高准确度:
则 FL 服务器可以保证良性客户端是沉默的,因此,所有发出报警的客户端的更新都是污染的。如果一个沉默客户端的更新准确度接近所有沉默客户端的最大准确度,则认为这个客户端是良性的。因此,当一个沉默客户端的准确度符合 max{(𝜔_t)^(𝑖) |𝑖 ∈ 𝑠 }-(𝜔_t)^(𝑖) < 𝐶𝑠 ,其中𝑖 ∈ 𝑠 ,则将该沉默客户端加入良性客户端集合𝑆𝑏。由于所有的良性客户端都是沉默的,因此恶意客户端发出的报警是假报警,而上一轮得到的全局模型𝒈_𝑡并没有污染。
相反,对于情况 3,如果沉默客户端的最大准确度低于报警客户端的最大准确度:max{(𝜔_t)^(𝑖) |𝑖 ∈ 𝑆𝑎}- max{(𝜔_t)^(𝑖) |𝑖 ∈ 𝑆𝑠 }> 𝐶𝑠,则所有沉默客户端的模型更新都污染,而发出报警的所有客户端由于其准确度相似,所以是良性的。
对于情况 4,报警客户端的准确度存在发散现象,表明良性和恶意客户端都在报警。同样,我们使用所有报警客户端的最大准确度来过滤掉报警的恶意客户端。如果一个报警客户端的准确度满足 max{𝜔_t)^(𝑖) |𝑖 ∈𝑆𝑎}-𝜔_t)^(𝑖) <𝐶𝑠,那么我们将其加入良性客户端集合𝑆𝑏。由于良性客户端在检测到污染𝒈_𝑡时总是会报警,所以在沉默的客户端中不存在良性客户端。因此,在这种情况下,我们将忽略所有的沉默客户端。
3.1.5 权重分析
由于大多数攻击的目的是产生恶意的权重更新,与良性的权重更新相比,它可以对全局模型产生反向影响。与良性更新相比,这些恶意的权重更新通常代表了模型的反向变化方向。本节作者对 Siren 进行权重分析。
然而,与 FL Trust 使用由服务器端数据训练的辅助模型不同(这意味着系统对全局模型有一个预定义的期望)[14],Siren 只使用来自客户端的信息。通过 Siren 的权重分析,服务器不仅要比较具有最大准确度的更新和其他更新之间的准确度,还要比较这些更新之间的角度。如果更新𝜔_𝑖与最大准确度 max{(𝜔_t)^(𝑖) |𝑖∈𝑆𝑎}之间的角度大于𝜋/2,那么𝜔_𝑖将被服务器视为一个恶意更新。否则,𝜔_𝑖被认为是良性更新,可以纳入到全局模型的计算过程中。通过权重分析,服务器可以从另一个角度检查客户端的更新,同时保持其客观性。
3.1.6 辅助机制
为了进一步改进 Siren 的性能,作者在本节中讨论向 Siren 结构中引入一些辅助机制。所有这些辅助机制都只在服务器端使用,因此并不会向客户端引入任何额外的计算负担。而服务器本身可以根据服务器上的计算资源以及从更好的安全性和性能的需求角度出发,灵活地决定是否使用这些辅助机制。
- 惩罚机制。作者设计了一个惩罚机制来提高 Siren 的稳定性。恶意客户端可以持续攻击服务器,而与之对应的持续检查会浪费大量的计算资源。通过惩罚机制,服务器记录每个客户端被视为恶意客户端的次数。如果一个客户端的计数大于阈值𝐶𝑝,那么服务器将不再接受来自该客户端的更新且不再检查,因为此时已经默认该客户端为恶意客户端。通过这种方法,服务器可以有效地节省计算资源,提高系统的稳定性。
- 奖励机制。由于每个客户端的数据存在差异,惩罚机制可能会将良性客户端误判为恶意客户端。因此,服务器可以利用奖励机制,让被禁止的客户端有机会重新加入到训练中来。在某一个通信回合中,如果一个被禁止的客户端被服务器视为良性客户端,则可以通过奖励机制将这个客户端的惩罚计数减少𝐶𝑎。如果这个被禁止的客户端的惩罚次数小于𝐶𝑝,那么可以令这个客户端再次参与训练过程。有了这个机制,服务器就可以减轻惩罚机制的副作用。