彻底告别微调噩梦:手把手教你击退灾难性遗忘,让模型记忆永不褪色的秘密武器!

简介: 【10月更文挑战第5天】深度学习中,模型微调虽能提升性能,但也常导致灾难性遗忘,即学习新任务时遗忘旧知识。本文介绍几种有效解决方案,重点讲解弹性权重巩固(EWC)方法,通过在损失函数中添加正则项来防止重要权重被更新,保护模型记忆。文中提供了基于PyTorch的代码示例,包括构建神经网络、计算Fisher信息矩阵和带EWC正则化的训练过程。此外,还介绍了其他缓解灾难性遗忘的方法,如LwF、在线记忆回放及多任务学习,以适应不同应用场景。

快速解决微调灾难性遗忘问题

随着深度学习的发展,模型微调已成为提高模型性能的重要手段之一。然而,在对预训练模型进行微调时,经常会出现灾难性遗忘的问题,即模型在学习新任务的同时,忘记了之前学到的知识。这不仅影响了模型在旧任务上的表现,也限制了其在多任务学习中的应用潜力。为了解决这一难题,研究者们提出了多种策略和技术,本文将介绍几种有效的解决方案,并提供相应的代码示例。

一种常用的缓解灾难性遗忘的方法是使用弹性权重巩固(Elastic Weight Consolidation,EWC)。EWC通过在损失函数中添加一个正则项来惩罚对重要权重的更新,从而保护模型不忘记先前学习到的信息。具体实现时,我们需要估计每个权重的重要性,并在微调过程中使用这些信息来引导优化方向。

首先,导入必要的库:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

然后,定义一个简单的神经网络模型:

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc(x)
        return x

接下来,定义EWC的计算方法:

def fisher_matrix_diag(model, loader):
    log_softmax = nn.LogSoftmax(dim=1)
    model.eval()
    fisher = {
   }
    for param_name, _ in model.named_parameters():
        fisher[param_name] = torch.zeros_like(model.state_dict()[param_name])

    for data, target in loader:
        output = model(data)
        log_probs = log_softmax(output)
        probs = torch.exp(log_probs)
        for c in range(log_probs.shape[1]):
            pseudo_counts = probs[:, c]
            log_pseudo_counts = log_probs[:, c]
            if pseudo_counts.requires_grad is not True:
                pseudo_counts.requires_grad = True
                log_pseudo_counts.requires_grad = True
            (pseudo_counts * log_pseudo_counts).sum().backward(retain_graph=True)
            for name, param in model.named_parameters():
                fisher[name] += param.grad.pow(2) / len(loader)

    for param_name in fisher.keys():
        fisher[param_name] /= len(loader)
    return fisher

定义带有EWC正则化的训练函数:

def ewc_train(model, loader, optimizer, criterion, fisher, prev_task_params, lamda=1000):
    model.train()
    for batch_idx, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        ewc_loss = 0
        for name, param in model.named_parameters():
            _loss = fisher[name] * (prev_task_params[name] - param).pow(2)
            ewc_loss += _loss.sum()
        loss += lamda * ewc_loss
        loss.backward()
        optimizer.step()

最后,准备数据并执行训练:

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(datasets.MNIST('data', train=True, download=True, transform=transform), batch_size=64, shuffle=True)
test_loader = DataLoader(datasets.MNIST('data', train=False, transform=transform), batch_size=1000, shuffle=True)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# 假设我们已经有了第一个任务的训练结果
initial_params = {
   name: param.clone().detach() for name, param in model.named_parameters()}
fisher_diags = fisher_matrix_diag(model, train_loader)

for epoch in range(5):  # loop over the dataset multiple times
    ewc_train(model, train_loader, optimizer, criterion, fisher_diags, initial_params)

以上就是使用EWC技术来缓解微调过程中灾难性遗忘问题的一种实现方式。除了EWC之外,还有其他方法如LwF(Learning without Forgetting)、在线记忆回放(Online Memory Replay)、多任务学习(Multi-task Learning)等,它们各有特点,在不同的场景下可能表现出不同的效果。选择哪种方法取决于具体的应用场景和个人需求。希望上述示例能够帮助你在实际项目中解决类似的问题。

相关文章
|
文字识别 前端开发
CodeFuse-VLM 开源,支持多模态多任务预训练/微调
随着huggingface开源社区的不断更新,会有更多的vision encoder 和 LLM 底座发布,这些vision encoder 和 LLM底座都有各自的强项,例如 code-llama 适合生成代码类任务,但是不适合生成中文类的任务,因此用户常常需要根据vision encoder和LLM的特长来搭建自己的多模态大语言模型。针对多模态大语言模型种类繁多的落地场景,我们搭建了CodeFuse-VLM 框架,支持多种视觉模型和语言大模型,使得MFT-VLM可以适应不同种类的任务。
1177 0
|
机器学习/深度学习 人工智能 自然语言处理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
揭秘深度学习中的微调难题:如何运用弹性权重巩固(EWC)策略巧妙应对灾难性遗忘,附带实战代码详解助你轻松掌握技巧
【10月更文挑战第1天】深度学习中,模型微调虽能提升性能,但常导致“灾难性遗忘”,即模型在新任务上训练后遗忘旧知识。本文介绍弹性权重巩固(EWC)方法,通过在损失函数中加入正则项来惩罚对重要参数的更改,从而缓解此问题。提供了一个基于PyTorch的实现示例,展示如何在训练过程中引入EWC损失,适用于终身学习和在线学习等场景。
939 4
揭秘深度学习中的微调难题:如何运用弹性权重巩固(EWC)策略巧妙应对灾难性遗忘,附带实战代码详解助你轻松掌握技巧
|
11月前
|
机器学习/深度学习 存储 监控
揭秘微调‘失忆’之谜:如何运用低秩适应与多任务学习等策略,快速破解灾难性遗忘难题?
【10月更文挑战第13天】本文介绍了几种有效解决微调灾难性遗忘问题的方法,包括低秩适应(LoRA)、持续学习和增量学习策略、记忆增强方法、多任务学习框架、正则化技术和适时停止训练。通过示例代码和具体策略,帮助读者优化微调过程,提高模型的稳定性和效能。
452 5
|
机器学习/深度学习 存储 算法
持续学习中避免灾难性遗忘的Elastic Weight Consolidation Loss数学原理及代码实现
在本文中,我们将探讨一种方法来解决这个问题,称为Elastic Weight Consolidation。EWC提供了一种很有前途的方法来减轻灾难性遗忘,使神经网络在获得新技能的同时保留先前学习任务的知识。
848 1
|
9月前
|
搜索推荐 物联网 PyTorch
Qwen2.5-7B-Instruct Lora 微调
本教程介绍如何基于Transformers和PEFT框架对Qwen2.5-7B-Instruct模型进行LoRA微调。
10416 34
Qwen2.5-7B-Instruct Lora 微调
|
9月前
|
数据采集 前端开发 物联网
【项目实战】通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型
本文介绍了一个基于多模态大模型的医疗图像诊断项目。项目旨在通过训练一个医疗领域的多模态大模型,提高医生处理医学图像的效率,辅助诊断和治疗。作者以家中老人的脑部CT为例,展示了如何利用MedTrinity-25M数据集训练模型,经过数据准备、环境搭建、模型训练及微调、最终验证等步骤,成功使模型能够识别CT图像并给出具体的诊断意见,与专业医生的诊断结果高度吻合。
17654 7
【项目实战】通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型
|
10月前
|
人工智能 JSON 监控
Qwen2.5-Coder-7B-Instruct Lora 微调 SwanLab 可视化记录版
本节我们简要介绍如何基于 transformers、peft 等框架,对Qwen2.5-Coder-7B-Instruct 模型进行Lora微调。使用的数据集是中文法律问答数据集 DISC-Law-SFT,同时使用 SwanLab 监控训练过程与评估模型效果。
1127 4
|
机器学习/深度学习 算法 Python
【博士每天一篇文献-算法】Overcoming catastrophic forgetting in neural networks
本文介绍了一种名为弹性权重合并(EWC)的方法,用于解决神经网络在学习新任务时遭受的灾难性遗忘问题,通过选择性地降低对旧任务重要权重的更新速度,成功地在多个任务上保持了高性能,且实验结果表明EWC在连续学习环境中的有效性。
659 2
【博士每天一篇文献-算法】Overcoming catastrophic forgetting in neural networks
|
文字识别 自然语言处理 数据可视化
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!
在 Qwen2 发布后的过去三个月里,许多开发者基于 Qwen2 语言模型构建了新的模型,并提供了宝贵的反馈。在这段时间里,通义千问团队专注于创建更智能、更博学的语言模型。今天,Qwen 家族的最新成员:Qwen2.5系列正式开源
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!