【博士每天一篇文献-算法】Zero-Shot Machine Unlearning

简介: 这篇论文提出了零样本机器遗忘的概念,介绍了两种新方法——错误最小化-最大化噪声(Error Maximization-Minimization, M-M)和门控知识传输(Gated Knowledge Transfer, GKT),以实现在不访问原始训练数据的情况下从机器学习模型中删除特定数据,同时引入了Anamnesis指数来评估遗忘质量,旨在帮助企业有效遵守数据隐私法规。

阅读时间:2023-12-3

1 介绍

年份:2023
作者:Vikram S. Chundawat,Ayush K. Tarun,新加坡国立大学
期刊: IEEE Transactions on Information Forensics and Security
引用量:5
由于现代隐私法规中,有些个人要求产品和服务(包括机器学习模型)不能保留他们的信息。机器遗忘涉及从训练好的机器学习模型中移除特定数据,而无需重新进行昂贵且耗时的训练。文档介绍了零样本机器遗忘的概念,即无需原始训练数据就可以实现遗忘。提出了两种新颖的方法,使用错误最小化-最大化噪声和门控知识传输来实现零样本机器遗忘。这些方法旨在从模型中删除已删除的数据,同时保持对保留数据的准确性,并防止隐私攻击。文中引入了用于评估遗忘质量的Anamnesis指数度量标准。提出的零样本遗忘方法为企业有效遵守数据隐私法规提供了解决方案,而无需访问原始训练数据。对深度学习模型在基准视觉数据集上进行的实验结果显示了遗忘的有希望的结果。

2 创新点

  1. 提出了零样本机器遗忘的问题,着眼于现代隐私法规赋予个人被产品和服务遗忘的权利,包括机器学习(ML)模型。零样本机器遗忘的概念是指在没有原始训练数据的情况下实现机器遗忘。
  2. 提出了两种方法来实现零样本机器遗忘,包括使用误差最小化-最大化噪声和门控知识传递。
  3. 引入了评估遗忘质量的指标——回忆指数(Anamnesis)。
  4. 这些零射遗忘方法为企业提供了一种有效遵守数据隐私法规的解决方案,而不需要访问原始训练数据。

3 相关研究

(1)基本概念
Machine unlearning是指要求 ML 模型所有者从用于构建 ML 模型的训练集中删除数据所有者的数据。
Kullback-Leibler (KL) 散度 用于作衡量两个概率分布之间相似性。
(2)机器学习中的遗忘
机器遗忘可以广泛分为两类:精确遗忘和近似遗忘方法,在精确遗忘中,通过从头开始重新训练模型并从训练集中排除该数据,移除数据点(待遗忘)对模型的影响。
【L. Bourtoule et al., “Machine unlearning,” in Proc. IEEE Symp. Secur. Privacy (SP), May 2021, pp. 141–159.】提出了一种SISA方法,通过将数据集分成一组不重叠的碎片,来训练模型。这减少了对完全重训练的需求,因为模型可以在其中一个碎片上进行重新训练。 近似遗忘方法的目标是近似参数,这些参数可以通过不使用待遗忘数据进行训练而获得。通常的方法是通过将参数调整到一个从头开始训练的模型(从未见过待遗忘数据)附近,在相对较低的更新次数比较精确遗忘方法的情况下获得。
【L. Graves, V. Nagisetty, and V. Ganesh, “Amnesiac machine learning,” in Proc. AAAI Conf. Artif. Intell., 2021, vol. 35, no. 13, pp. 11516–11524】在训练过程中在参数空间中存储由每个数据点做出的更新。在遗忘过程中,将这些相应的更新从模型的最终参数中减去。
【J. Brophy and D. Lowd, “Machine unlearning for random forests,” in Proc. 38th Int. Conf. Mach. Learn., 2021, pp. 1092–1104.】提出了随机森林来支持数据遗忘。
(3)深度学习中的遗忘
【“Eternal sunshine of the spotless net: Selective forgetting in deep networks, 2020】提出了一种信息论方法,用于从使用SGD训练的深度网络的中间层中擦除信息。
【Forgetting outside the box: Scrubbing deep networks of information accessible from input-output observations,2020】提出了一种基于神经切线核(NTK)的训练过程近似,并使用它来估计遗忘后的网络权重。然而,即使对于小型数据集,近似的准确性会迅速下降,计算成本也会迅速增加。
【Mixed-privacy forgetting in deep networks,2021】直接训练线性化网络并将其用于遗忘。他们训练两个单独的网络:核心模型和用于遗忘目的的混合线性模型。然而,为每个深度架构设计混合线性网络是一种低效的方法,并且根据原始模型的网络结构需要人工干预。
【Fast yet effective machine unlearning, 2023】提出了一种基于错误最大化的方法,用于学习类别遗忘的噪声矩阵。他们使用这样的反样本来诱导深度网络中的类级遗忘。该方法需要一小部分保留数据来进行遗忘。
(4)遗忘设置和数据隐私
近似遗忘方法旨在通过对训练信息的可用性(如训练数据、使用的优化技术)进行某些假设,并在一定程度上减轻模型的有效性,从而提供更高的效率。大多数现有的遗忘方法可以临时分为三类。

  • 第一类:使用要遗忘的数据来更新ML模型以进行遗忘。在这些方法中,忘记数据对模型的影响被近似为Newton算法步骤,并向训练目标函数注入随机噪声。【Certified data removal from machine learning models, 2020】【Approximate data deletion from machine learning models,2021】
  • 第二类,需要访问剩余的训练数据(不包括要遗忘的数据)。这些方法使用Fisher信息并向模型权重注入最佳噪声以实现遗忘。然而,它们对训练过程实施了一定的限制,比如只允许SGD优化而不是训练。【Eternal sunshine of the spotless net: Selective forgetting in deep networks, 2020】【Forgetting outside the box: Scrubbing deep networks of information accessible from input-output observations, 2020】【Mixed-privacy forgetting in deep networks,2021】。论文【Fast yet effective machine unlearning, 2023】在提出时消除了某些限制。所有上述方法都需要访问要遗忘的数据、剩余数据或在训练期间存储的某些元数据。这些假设不反映实际情况中的遗忘场景。
  • 不需要数据的遗忘方法。【When machine unlearning jeopardizes privacy, 2021】

(4)相关概念
Newton步骤是一种近似算法,用于在机器学习模型中执行数据删除操作。在一种模型反演攻击中,将输入向量初始化为零,并添加一些小的噪音,然后通过梯度下降使用与目标类别相关的损失进行优化。在每个梯度下降步骤之后执行图像处理步骤,以帮助识别生成的图像。这个过程经过多次迭代,最终得到反转的图像。Newton步骤是用于模型反转攻击的关键部分,旨在通过生成特定类别的图像来获取关于该类别的表示信息。

4 算法

4.1 定义零样本遗忘问题

D 是一个数据集,
Cf 表示我们希望机器学习(ML)模型忘记的类别集合
Df 是对应于要忘记的类别的数据
Cr 表示我们希望模型记住的保留类别
Dr 是对应于要保留的类别的数据

4.2 算法一:Error最大化-最小化(M-M)

截屏2024-04-10 上午9.57.00.png
第一步:通过隐私攻击的方法生成分别保留类别和遗忘类别的伪样本,其中对于遗忘样本,随机初始化噪声矩阵N_f,计算与遗忘类别最大预测误差,通过梯度更新N_f,最终获得攻击得到的矩阵N_f。
对于保留样本,随机初始化噪声矩阵N_r输入到模型,计算模型的输出与保留类别的最小预测误差,通过梯度更新N_r,最终获得攻击得到矩阵N_r。
第二步:用以上攻击得到的样本构成数据集(N_f+N_r),用来重新训练原始模型,更新模型权重。

4.3 算法二:门控知识传输(GKT)

截屏2024-04-19 下午1.48.51.png
KL散度是一种衡量两个概率分布相似度的方法。越大越相似。
第一步:通过教师网络生成训练样本,生成的样本要是保留类别的话,用于学生网络训练。要是是遗忘类别的话,就不使用该样本。教师网络和生成器之间的损失是用KL散度。
第二步:训练学生网络,模仿教师网络,损失函数为由KL散度和注意力损失组合的损失函数。
其中的KL散度是用于评价教师网络和学生网络输出概率分布之间的相似度,
$ DKL(T(x)||S(x)) = \sum_{i} t_p^i \log\left(\frac{t_p^i}{s_p^o}\right) $

其中, $ t_p^i$是教师网络输出的概率,$ s_p^i $​是学生网络输出的概率
注意力损失用于确保学生网络学习到教师网络的重要特征。
$ L_{at}=\sum_{l\in NL}\left|\left|\frac{f(A_l^{(t)})}{||f(A_l^{(t)})||_2}-\frac{f(A_l^{(s)})}{||f(A_l^{(s)})||_2}\right|\right| $

其中,
$ A_t^l 和 A_s^l $ 分别是教师和学生网络在第l层的输出,
f(⋅)是一个函数(最大池化来获取最显著的特征或者通过平均池化来获取整体特征)
$ ∣∣*∣∣_2 $​表示L2范数,用于计算向量或矩阵的欧几里得范数,即长度或大小
NL是选定的网络层集合

4.4 评价指标

(1)遗忘集(Forget Set, D_f)上的准确度
期望在被忘却的数据集上的准确度接近于重新训练(retrained)模型的准确度,因为忘却模型的行为应该与在没有观察到忘却数据集的情况下重新训练的模型相似。
(2)保留集(Retain Set, Dr)上的准确度
期望在保留的数据集上的准确度接近于原始模型的准确度,这意味着模型在执行忘却操作后,对于保留的数据集仍然保持较高的性能。
(3)记忆恢复指数(Anamnesis Index, AIN)
这是一个新提出的度量,用于更有效地评估忘却操作的效果。AIN是通过计算忘却模型重新学习到原始模型准确度一定百分比(α%)范围内所需的时间(步数或epoch数)与重新训练模型达到相同准确度所需的时间的比值来计算的。AIN的值越接近1,表示忘却效果更好。
$ AIN = \frac{rt(M_u, M_{orig}, \alpha)}{rt(M_s, M_{orig}, \alpha)} $
$ M_u $​:经过忘却操作后的模型
$ M_{orig} $:原始的、未经忘却操作的模型
$ M_s $​:在忘却数据集
$D_r$​:上从零开始重新训练的模型。
$ rt(M, M_{orig}, \alpha) $:模型M重新学习到原始模型 $M_{orig}$ 准确度的 $\alpha $范围内所需的最小批次(或epoch)数。
$ \alpha \% $:一个边缘值,通常设为5%到10%,表示在计算重新学习时间时接受的性能差异范围。
(4)模型反演攻击(Model Inversion Attack)
通过模拟攻击来检查忘却模型的隐私保护能力。在模型反演攻击中,攻击者尝试从模型中恢复训练数据。如果忘却模型能够有效地抵抗这种攻击,即攻击无法从忘却模型中恢复出有关忘却类别的信息,那么这表明忘却方法是有效的。
(5)成员推断攻击(Membership Inference Attack)
这种攻击旨在推断某些数据点是否被用于训练模型。在忘却操作后,忘却模型对于忘却类别的数据点的成员推断概率应该低于原始模型,表明忘却模型在保护隐私方面更为健壮。

5 实验分析

(1)M-M和GKT算法的有效性
不同数据集上的Acc、AIN比较
截屏2024-04-19 下午3.24.06.png
M-M方法在零样本忘却中较差。由于这种方法需要对忘却类和保留类都进行噪声优化,噪声样本的质量决定了忘却性能。M-M方法无法保持在Dr上的准确度。
GKT方法在Dr上实现了接近原始模型的准确度。

(2)不同算法比较
截屏2024-04-19 下午3.25.37.png
GKT方法的Df为0,Dr也比较高,即实现了类别忘却也实现了对于其他类别预测的准确性。
(3)模型反演攻击(Model Inversion Attack)和成员推断攻击(Membership Inference Attack)
截屏2024-04-19 下午3.27.36.png

  • 图左一,原始模型:未忘却
  • 图中间,中间重新训练的模型:表示在删除了特定类别(例如类别0)后,仅使用剩余数据重新训练的模型。
  • 图右一,GKT忘却模型

AllCNN模型在MNIST数据集上进行的模型反演攻击的效果。图中是类别0的可视化。
模型反演攻击无法从我们的忘却模型中提取任何信息。

6 思考

(1)多种算法原理理解
Bad Teacher是对遗忘数据集构造伪标签
M-M算法是对遗忘数据集,构造伪图片
GKT算法是模型蒸馏+成员推断攻击
Fisher算法是用计算数据点的重要性,更新权重
Amnesiac算法是根据训练时的权重更新记录,逆向更新权重

  1. GKT算法
    • 缺陷:GKT方法在大型模型上的效果可能不够理想。
    • 改进角度
    • 模型大小适应性:研究如何调整GKT算法以更好地适应大型模型。
    • 训练策略优化:可能需要更精细的训练策略,例如动态调整学习率或使用更高级的优化器。
  2. M-M算法
    • 缺陷:M-M方法在零样本设置中表现不佳,无法有效保持对保留类别(Dr)的准确性,并且在忘却类别(Df)上的准确性也不理想。
    • 改进角度
      • 噪声生成策略:改进噪声矩阵的生成方法,以便更有效地模拟训练数据的分布。
      • 数据代理:寻找更好的数据代理方法,以在没有实际数据的情况下模拟训练样本。
  3. 基于Fisher信息的Unlearning:
    • 原理:利用Fisher信息来估计模型参数对数据点的敏感度,并据此进行参数更新以实现忘却。
    • 优点:可以针对特定数据点进行精确的忘却操作。
    • 缺点:计算Fisher信息可能需要大量计算资源。
    • 创新思路:开发一种近似计算Fisher信息的方法,减少计算成本,或者利用迁移学习技术将Fisher信息从一个模型迁移到另一个模型。
  4. Amnesiac Unlearning
    • 原理:通过存储训练过程中的中间信息(如梯度和权重更新),在需要忘却时回退到先前的状态。
    • 优点:可以较为精确地实现忘却操作。
    • 缺点:需要存储大量中间信息,可能导致内存开销大。
    • 创新思路:设计一种压缩机制来减少存储需求,或者使用元学习(Meta-Learning)技术来学习忘却过程,减少所需存储的信息。
  5. Bad Teacher Unlearning
    • 原理:通过一个“不称职”的教师模型来教导学生模型忘却特定信息。
    • 优点:可以在没有访问训练数据的情况下进行忘却。
    • 缺点:可能需要精心设计教师模型的教学策略。
    • 创新思路:利用对抗性网络(如GANs)生成“误导性”数据,这些数据能够促使学生模型忘却特定类别;或者开发一种自适应机制,自动调整教师模型的教学策略以提高忘却效率。

(2)可创新点

  1. 混合方法:结合GKT和Fisher信息,创建一个混合算法,利用Fisher信息来指导GKT中的噪声生成过程,从而提高忘却的精确性和效率。
  2. 自适应阈值确定:开发一种自适应机制来动态确定GKT中band-pass filter的阈值ϵ,以便更有效地过滤掉忘却类别的信息。
  3. 元学习忘却框架:利用元学习框架来学习忘却过程,通过少量的样本训练一个忘却模型,该模型能够快速适应新的忘却任务而无需从头开始训练。

7 代码

https://github.com/ayushkumartarun/zero-shot-unlearning
相关论文代码
【When Machine Unlearning Jeopardizes Privacy】https://github.com/MinChen00/UnlearningLeaks

目录
相关文章
|
3月前
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
65 12
|
3月前
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
194 9
|
3月前
|
机器学习/深度学习 算法 机器人
【博士每天一篇文献-算法】改进的PNN架构Lifelong learning with dynamically expandable networks
本文介绍了一种名为Dynamically Expandable Network(DEN)的深度神经网络架构,它能够在学习新任务的同时保持对旧任务的记忆,并通过动态扩展网络容量和选择性重训练机制,有效防止语义漂移,实现终身学习。
56 9
|
3月前
|
机器学习/深度学习 算法 文件存储
【博士每天一篇文献-算法】 PNN网络启发的神经网络结构搜索算法Progressive neural architecture search
本文提出了一种名为渐进式神经架构搜索(Progressive Neural Architecture Search, PNAS)的方法,它使用顺序模型优化策略和替代模型来逐步搜索并优化卷积神经网络结构,从而提高了搜索效率并减少了训练成本。
52 9
|
3月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之HNet:Continual learning with hypernetworks
本文提出了一种基于任务条件超网络(Hypernetworks)的持续学习模型,通过超网络生成目标网络权重并结合正则化技术减少灾难性遗忘,实现有效的任务顺序学习与长期记忆保持。
42 4
|
3月前
|
机器学习/深度学习 存储 人工智能
【博士每天一篇文献-算法】改进的PNN架构Progressive learning A deep learning framework for continual learning
本文提出了一种名为“Progressive learning”的深度学习框架,通过结合课程选择、渐进式模型容量增长和剪枝机制来解决持续学习问题,有效避免了灾难性遗忘并提高了学习效率。
53 4
|
3月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding
RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。
84 3
|
12天前
|
算法 安全 数据安全/隐私保护
基于game-based算法的动态频谱访问matlab仿真
本算法展示了在认知无线电网络中,通过游戏理论优化动态频谱访问,提高频谱利用率和物理层安全性。程序运行效果包括负载因子、传输功率、信噪比对用户效用和保密率的影响分析。软件版本:Matlab 2022a。完整代码包含详细中文注释和操作视频。
|
9天前
|
人工智能 算法 数据安全/隐私保护
基于遗传优化的SVD水印嵌入提取算法matlab仿真
该算法基于遗传优化的SVD水印嵌入与提取技术,通过遗传算法优化水印嵌入参数,提高水印的鲁棒性和隐蔽性。在MATLAB2022a环境下测试,展示了优化前后的性能对比及不同干扰下的水印提取效果。核心程序实现了SVD分解、遗传算法流程及其参数优化,有效提升了水印技术的应用价值。
|
10天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。