持续学习中避免灾难性遗忘的Elastic Weight Consolidation Loss数学原理及代码实现

本文涉及的产品
实时计算 Flink 版,1000CU*H 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
简介: 在本文中,我们将探讨一种方法来解决这个问题,称为Elastic Weight Consolidation。EWC提供了一种很有前途的方法来减轻灾难性遗忘,使神经网络在获得新技能的同时保留先前学习任务的知识。

训练人工神经网络最重要的挑战之一是灾难性遗忘。神经网络的灾难性遗忘(catastrophic forgetting)是指在神经网络学习新任务时,可能会忘记之前学习的任务。这种现象特别常见于传统的反向传播算法和深度学习模型中。主要原因是网络在学习新数据时,会调整权重以适应新任务,这可能会导致之前学到的知识被覆盖或忘记,尤其是当新任务与旧任务有重叠时。

在本文中,我们将探讨一种方法来解决这个问题,称为Elastic Weight Consolidation。EWC提供了一种很有前途的方法来减轻灾难性遗忘,使神经网络在获得新技能的同时保留先前学习任务的知识。

在任务a和任务B的灰色和黄色区域中,存在许多具有期望的低误差的最优参数配置。假设我们为任务A找到了一个这样的配置θꭺ*,当继续从这样的配置训练模型到新的任务B时,会出现三种不同的场景:

蓝色箭头:简单地继续在任务B上进行训练而不受惩罚,将在任务B的低水平区域结束,但在任务A上的表现低于预期的准确性。

绿色箭头:使用任务A的权重的L2约束可能太强,使得模型在任务A上表现良好,但在任务B上表现不佳。

红色箭头:这是EWC是提出的解决方案,它将在模型在两个任务上都表现良好的区域(两个区域之间的交叉点)中找到参数。

下面我们将解释这是如何完成的。

费雪信息矩阵(FIM)

EWC方法所基于的FIM(Fisher Information Matrix)。FIM是一种统计度量,用于量化给定数据提供的关于我们要估计的未知参数θ的信息量。在持续学习的背景下,FIM将有助于识别神经网络参数,这些参数从以前的任务中获取的数据信息较少。通过更新这些参数,网络可以学习新的任务,而不会删除存储在参数中的重要信息,这些信息是关于先前学习任务的非常有用的信息。

假设X是一个随机变量,其概率密度函数f(X |θ)参数化为θ。样本x的似然函数(仅在数据固定的情况下为参数函数)为:

和对数拟然:

将FIM定义为:

这表明对数似然函数对参数的微小变化有多敏感。我们可以将FIM视为似然函数二阶导数的负期望:

当求二阶导数时,基本上是在看似然函数的曲率。

可以考虑下面的两个绘制的似然函数的图表。蓝色曲线表示在峰值附近非常窄的分布,表明数据更有可能在θ附近,并且随着远离θ而迅速减少。相反,黑色曲线代表一个更广泛的分布,即使远离θ,数据也保持相似的可能性。

FIM量化了这个概念——数据是多么严格地限制在某个θ值上。较大的FIM(如蓝色曲线所示)意味着参数值的微小变化将导致数据在这些参数下的可能性显著下降。相反,较小的FIM(如黑色曲线所示)意味着参数值的较小变化将导致可能性的较小降低。

事实证明,费雪信息矩阵与数据的方差(或多变量情况下的协方差)成反比。在上面的图表中,如果假设曲线分别代表均值θ 0和方差σ²ᵦₗᵤₑ和σ²ᵦₗ꜀ₖ的两个高斯分布,其中σ²ᵦₗᵤₑ< σ²ᵦₗ꜀ₖFIM等于1/σ²,因此蓝色曲线包含更多信息。

弹性重量固结

给定数据D和一个参数为θ的神经网络,我们的目标是在给定数据p(θ|D)的情况下最大化参数的概率。根据贝叶斯规则,我们得到:

弹性权重保持

弹性权重保持(Elastic Weight Consolidation,EWC)是一种用于减轻神经网络灾难性遗忘问题的方法。它的基本思想是在学习新任务时保护先前任务的关键权重。

给定数据D和一个参数为θ的神经网络,我们的目标是在给定数据p(θ|D)的情况下最大化参数的概率。根据贝叶斯规则,我们得到:

对两边应用对数并不改变最大化的目标,因为对数是一个单调变换。因此目标变成:

假设两个独立任务D = {A, B},我们有:

最后一个是独立于A和B的。这里log(p(B|θ))是任务B的损失,log(p(B))是B的可能性,它可以作为优化的常数,因为它不依赖于θ, log(p(θ| a))是任务a的后验分布,它包含了任务a重要参数的所有信息。

估计log(p(θ|A))比较复杂的,因为计算它将涉及在整个参数空间上对高维函数进行积分。但是它近似为正态分布,其均值为任务a - θꭺ的最优参数,方差为费雪信息矩阵。这种近似是有意义的,因为我们可以假设A和B任务的新参数θ与任务A的最优参数相差不远。在所有θꭺ的参数中,会有一些参数对任务A的良好表现更重要,并且不希望它们改变太多,这就是FIM的作用,FIM的值表明在这种情况下,改变某个参数将如何影响任务A的损失。因此,FIM中值越高的参数变化受到的惩罚越大。

现在,我们对任务A的最优权值进行泰勒展开直到第二项

其中log(p(θꭺ|A))是一个常数,我们可以在优化中忽略它。我们也可以忽略第二项,因为在最优θꭺ处,梯度为零。这样就找到了log(p(θ|A))的表达式,把它代回到图8的原始公式中:

第二项的二阶导数为Hessian,可以根据图5的定义用费雪信息矩阵近似。log(p(B|θ))是新任务B的损失,例如交叉熵,我们记为Lᵦ(θ)

我们不需要进行二阶导数,只需根据图4中等价于图5的定义,即对数似然梯度的外积,用一阶导数近似FIM即可:

这样优化L(θ)的总损失为:

λ是一个超参数,表示在前一个任务a上保持精度的重要性。

上面涉及梯度向量的外积的定义捕获了梯度的协方差结构。而FIM的对角线近似通常由梯度的平方给出,它只计算参数的方差,但计算成本较低,足以完成任务:

Pytorch实现

上面我们介绍了弹性权重保持的数学原理,下面我们来看看Pytorch的代码实现

让我们首先导入一些库以及分别代表任务A和任务B的MNIST和Fashion MNIST数据集。我们还定义了一个简单的神经网络:

 importtorch
 importtorch.nnasnn
 importtorch.nn.functionalasF
 importtorch.optimasoptim
 fromtorchimportautograd
 importnumpyasnp
 fromtorch.utils.dataimportDataLoader

 fromtorch.utils.dataimportDataset, DataLoader
 fromtorchvisionimportdatasets, transforms
 fromtqdmimporttqdm

 defget_accuracy(model, dataloader):
     model=model.eval()
     acc=0
     forinput, targetindataloader:
         o=model(input.to(device))
         acc+= (o.argmax(dim=1).long() ==target.to(device)).float().mean()
     returnacc/len(dataloader)

 classLinearLayer(nn.Module):
     # from https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py
     def__init__(self, input_dim, output_dim, act='relu', use_bn=False):
         super(LinearLayer, self).__init__()
         self.use_bn=use_bn
         self.lin=nn.Linear(input_dim, output_dim)
         self.act=nn.ReLU() ifact=='relu'elseact
         ifuse_bn:
             self.bn=nn.BatchNorm1d(output_dim)
     defforward(self, x):
         ifself.use_bn:
             returnself.bn(self.act(self.lin(x)))
         returnself.act(self.lin(x))

 classFlatten(nn.Module):

     defforward(self, x):
         returnx.view(x.shape[0], -1)

 classModel(nn.Module):

     def__init__(self, num_inputs, num_hidden, num_outputs):
         super(Model, self).__init__()
         self.f1=Flatten()
         self.lin1=LinearLayer(num_inputs, num_hidden, use_bn=True)
         self.lin2=LinearLayer(num_hidden, num_hidden, use_bn=True)
         self.lin3=nn.Linear(num_hidden, num_outputs)

     defforward(self, x):
         returnself.lin3(self.lin2(self.lin1(self.f1(x))))

 # Load MNIST dataset, representint task A
 mnist_train=datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
 mnist_test=datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
 train_loader=DataLoader(mnist_train, batch_size=100, shuffle=True)
 test_loader=DataLoader(mnist_test, batch_size=100, shuffle=False)

 # FashiomMNIST is task B
 f_mnist_train=datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
 f_mnist_test=datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
 f_train_loader=DataLoader(f_mnist_train, batch_size=100, shuffle=True)
 f_test_loader=DataLoader(f_mnist_test, batch_size=100, shuffle=False)

现在让我们在MNIST任务上训练模型:

 # parameters
 EPOCHS=4
 lr=0.001
 weight=100000
 accuracies= {}

 device='cuda:1'

 criterion=nn.CrossEntropyLoss()

 # train model on task A 
 model=Model(28*28, 100, 10).to(device)
 optimizer=optim.Adam(model.parameters(), lr)

 for_inrange(EPOCHS):
     forinput, targetintqdm(train_loader):
         output=model(input.to(device))
         loss=criterion(output, target.to(device))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()

 accuracies['mnist_initial'] =get_accuracy(model, test_loader)

现在可以定义函数来估计FIM和EWC损失中使用的先前参数:

 defewc_loss(model, weight, estimated_fishers, estimated_means):
     losses= []
     forparam_name, paraminmodel.named_parameters():
         estimated_mean=estimated_means[param_name]
         estimated_fisher=estimated_fishers[param_name]
         losses.append((estimated_fisher* (param-estimated_mean) **2).sum())

     return (weight/2) *sum(losses)

 defestimate_ewc_params(model, train_ds, batch_size=100, num_batch=300, estimate_type='true'):
     estimated_mean= {}

     forparam_name, paraminmodel.named_parameters():
         estimated_mean[param_name] =param.data.clone()

     estimated_fisher= {}
     dl=DataLoader(train_ds, batch_size, shuffle=True)

     forn, pinmodel.named_parameters():
         estimated_fisher[n] =torch.zeros_like(p)

     model.eval()
     fori, (input, target) inenumerate(dl):
         ifi>num_batch:
             break
         model.zero_grad()

         output=model(input.to(device))
         # https://www.inference.vc/on-empirical-fisher-information/ - more on this here
         ifESTIMATE_TYPE=='empirical':
             # empirical
             label=target.to(device)
         else:
             # true estimate
             label=output.max(1)[1]

         loss=F.nll_loss(F.log_softmax(output, dim=1), label)
         loss.backward()

         # accumulate all the gradients
         forn, pinmodel.named_parameters():
             estimated_fisher[n].data+=p.grad.data**2/len(dl)

     estimated_fisher= {n: pforn, pinestimated_fisher.items()}
     returnestimated_mean, estimated_fisher

然后继续在任务B上训练EWC损失的网络:

 # compute fisher and mean parameters for EWC loss
 estimated_mean, estimated_fisher=estimate_ewc_params(model, mnist_train)

 # Train task B fashion mnist
 for_inrange(EPOCHS):
     forinput, targetintqdm(f_train_loader):
         output=model(input.to(device))
         loss=ewc_loss(model, weight, estimated_fisher, estimated_mean) +criterion(output, target.to(device))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()

 accuracies['mnist_EWC'] =get_accuracy(model, test_loader)
 accuracies['f_mnist_EWC'] =get_accuracy(model, f_test_loader)

可以得到以下精度:

 {'mnist_initial': tensor(0.9772, device='cuda:1'),
  'mnist_AB': tensor(0.9717, device='cuda:1'),
  'f_mnist': tensor(0.8312, device='cuda:1')}

最后将这些与没有EWC损失的模型进行比较:

 {'mnist_initial': tensor(0.9762, device='cuda:1'),
  'mnist_AB': tensor(0.1769, device='cuda:1'),
  'f_mnist': tensor(0.8672, device='cuda:1')}

可以看到EWC损失有助于保持任务A的准确率几乎不变,而学习任务B的准确率几乎与没有EWC损失的情况相同。

总结

我们看到了一种允许神经网络在继续学习新任务的同时保留其先前学习的知识的技术,虽然EWC在解决灾难性遗忘方面效果显著,但仍有一些挑战,例如对费雪信息矩阵的计算和存储需求较高,以及在复杂的深度神经网络结构中的实施复杂性。

还有还有其他方法可以使模型进行持续学习,比如:

重播记忆(Replay Memory):保存旧数据以便周期性地重训练。

联合训练(Joint Training):同时训练网络以处理旧任务和新任务。

元学习方法(Meta-learning Approaches):通过元学习算法来优化模型,以便快速适应新任务而不会忘记旧任务。

这些方法有助于减轻灾难性遗忘的影响,使神经网络能够持续学习和适应多个任务。

https://avoid.overfit.cn/post/56aee34117764e89a1a707c316fa305f

目录
相关文章
|
机器学习/深度学习 存储 数据采集
使用GANs生成时间序列数据:DoppelGANger论文详解(一)
使用GANs生成时间序列数据:DoppelGANger论文详解
1623 0
使用GANs生成时间序列数据:DoppelGANger论文详解(一)
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
揭秘深度学习中的微调难题:如何运用弹性权重巩固(EWC)策略巧妙应对灾难性遗忘,附带实战代码详解助你轻松掌握技巧
【10月更文挑战第1天】深度学习中,模型微调虽能提升性能,但常导致“灾难性遗忘”,即模型在新任务上训练后遗忘旧知识。本文介绍弹性权重巩固(EWC)方法,通过在损失函数中加入正则项来惩罚对重要参数的更改,从而缓解此问题。提供了一个基于PyTorch的实现示例,展示如何在训练过程中引入EWC损失,适用于终身学习和在线学习等场景。
939 4
揭秘深度学习中的微调难题:如何运用弹性权重巩固(EWC)策略巧妙应对灾难性遗忘,附带实战代码详解助你轻松掌握技巧
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
彻底告别微调噩梦:手把手教你击退灾难性遗忘,让模型记忆永不褪色的秘密武器!
【10月更文挑战第5天】深度学习中,模型微调虽能提升性能,但也常导致灾难性遗忘,即学习新任务时遗忘旧知识。本文介绍几种有效解决方案,重点讲解弹性权重巩固(EWC)方法,通过在损失函数中添加正则项来防止重要权重被更新,保护模型记忆。文中提供了基于PyTorch的代码示例,包括构建神经网络、计算Fisher信息矩阵和带EWC正则化的训练过程。此外,还介绍了其他缓解灾难性遗忘的方法,如LwF、在线记忆回放及多任务学习,以适应不同应用场景。
1229 8
|
机器学习/深度学习 算法 前端开发
阿里面试官分享+真实面经+笔试模拟题 | 面试充电,就看这篇
阿里面试官分享+真实面经+笔试模拟题+招聘信息汇总,太全了!这篇合辑一定要看,不然就亏大啦!
阿里面试官分享+真实面经+笔试模拟题 | 面试充电,就看这篇
|
3月前
|
机器学习/深度学习 人工智能 机器人
面向人机协作任务的具身智能系统感知-决策-执行链条建模
本文探讨了面向人机协作任务的具身智能系统建模,涵盖感知、决策与执行链条。具身智能强调智能体通过“身体”与环境互动,实现学习与适应,推动机器人技术升级。文章分析了其关键组成(感知、控制与决策系统)、挑战(高维状态空间、模拟鸿沟等)及机遇(仿真训练加速、多模态感知融合等)。通过代码示例展示了基于PyBullet的强化学习训练框架,并展望了通用具身智能的未来,包括多任务泛化、跨模态理解及Sim2Real迁移技术,为智能制造、家庭服务等领域提供新可能。
面向人机协作任务的具身智能系统感知-决策-执行链条建模
|
机器学习/深度学习 存储 算法
【博士每天一篇论文-算法】Continual Learning Through Synaptic Intelligence,SI算法
本文介绍了一种名为"Synaptic Intelligence"(SI)的持续学习方法,通过模拟生物神经网络的智能突触机制,解决了人工神经网络在学习新任务时的灾难性遗忘问题,并保持了计算效率。
617 1
【博士每天一篇论文-算法】Continual Learning Through Synaptic Intelligence,SI算法
|
计算机视觉
增量学习中Task incremental、Domain incremental、Class incremental 三种学习模式的概念及代表性数据集?
本文介绍了增量学习中的三种主要模式:任务增量学习(Task-incremental)、域增量学习(Domain-incremental)和类别增量学习(Class-incremental),它们分别关注任务序列、数据分布变化和类别更新对学习器性能的影响,并列举了每种模式下的代表性数据集。
1780 4
增量学习中Task incremental、Domain incremental、Class incremental 三种学习模式的概念及代表性数据集?
|
机器学习/深度学习 算法 Python
【博士每天一篇文献-算法】Overcoming catastrophic forgetting in neural networks
本文介绍了一种名为弹性权重合并(EWC)的方法,用于解决神经网络在学习新任务时遭受的灾难性遗忘问题,通过选择性地降低对旧任务重要权重的更新速度,成功地在多个任务上保持了高性能,且实验结果表明EWC在连续学习环境中的有效性。
659 2
【博士每天一篇文献-算法】Overcoming catastrophic forgetting in neural networks
|
11月前
|
机器学习/深度学习 存储 监控
揭秘微调‘失忆’之谜:如何运用低秩适应与多任务学习等策略,快速破解灾难性遗忘难题?
【10月更文挑战第13天】本文介绍了几种有效解决微调灾难性遗忘问题的方法,包括低秩适应(LoRA)、持续学习和增量学习策略、记忆增强方法、多任务学习框架、正则化技术和适时停止训练。通过示例代码和具体策略,帮助读者优化微调过程,提高模型的稳定性和效能。
452 5
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
907 9