在 PyTorch 中使用梯度检查点在GPU 上训练更大的模型

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
文档理解,结构化解析 100页
简介: 本文将介绍解梯度检查点(Gradient Checkpointing),这是一种可以让你以增加训练时间为代价在 GPU 中训练大模型的技术。 我们将在 PyTorch 中实现它并训练分类器模型。

作为机器学习从业者,我们经常会遇到这样的情况,想要训练一个比较大的模型,而 GPU 却因为内存不足而无法训练它。当我们在出于安全原因不允许在云计算的环境中工作时,这个问题经常会出现。在这样的环境中,我们无法足够快地扩展或切换到功能强大的硬件并训练模型。并且由于梯度下降算法的性质,通常较大的批次在大多数模型中会产生更好的结果,但在大多数情况下,由于内存限制,我们必须使用适应GPU显存的批次大小。

本文将介绍解梯度检查点(Gradient Checkpointing),这是一种可以让你以增加训练时间为代价在 GPU 中训练大模型的技术。 我们将在 PyTorch 中实现它并训练分类器模型。

梯度检查点

在反向传播算法中,梯度计算从损失函数开始,计算后更新模型权重。图中每一步计算的所有导数或梯度都会被存储,直到计算出最终的更新梯度。这样做会消耗大量 GPU 内存。梯度检查点通过在需要时重新计算这些值和丢弃在进一步计算中不需要的先前值来节省内存。

让我们用下面的虚拟图来解释。

上面是一个计算图,每个叶节点上的数字相加得到最终输出。假设这个图表示反向传播期间发生的计算,那么每个节点的值都会被存储,这使得执行求和所需的总内存为7,因为有7个节点。但是我们可以用更少的内存。假设我们将1和2相加,并在下一个节点中将它们的值存储为3,然后删除这两个值。我们可以对4和5做同样的操作,将9作为加法的结果存储。3和9也可以用同样的方式操作,存储结果后删除它们。通过执行这些操作,在计算过程中所需的内存从7减少到3。

在没有梯度检查点的情况下,使用PyTorch训练分类模型

我们将使用PyTorch构建一个分类模型,并在不使用梯度检查点的情况下训练它。记录模型的不同指标,如训练所用的时间、内存消耗、准确性等。

由于我们主要关注GPU的内存消耗,所以在训练时需要检测每批的内存消耗。这里使用nvidia-ml-py3库,该库使用nvidia-smi命令来获取内存信息。

 pip install nvidia-ml-py3

为了简单起见,我们使用简单的狗和猫分类数据集的子集。

 git clone https://github.com/laxmimerit/dog-cat-full-dataset.git

执行上述命令后会在dog-cat-full-dataset的文件夹中得到完整的数据集。

导入所需的包并初始化nvdia-smi

 importtorch
 importtorch.nnasnn
 importtorch.optimasoptim
 importnumpyasnp
 fromtorchvisionimportdatasets, models, transforms
 importmatplotlib.pyplotasplt
 importtime
 importos
 importcv2
 importnvidia_smi
 importcopy
 fromPILimportImage
 fromtorch.utils.dataimportDataset,DataLoader
 importtorch.utils.checkpointascheckpoint
 fromtqdmimporttqdm
 importshutil
 fromtorch.utils.checkpointimportcheckpoint_sequential
 device="cuda"iftorch.cuda.is_available() else"cpu"
 %matplotlibinline
 importrandom
 
 nvidia_smi.nvmlInit()

导入训练和测试模型所需的所有包。我们还初始化nvidia-smi。

定义数据集和数据加载器

 #Define the dataset and the dataloader.
 train_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/train",
                             transform=transforms.Compose([
                                 transforms.RandomRotation(30),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.RandomResizedCrop(224, scale=(0.96, 1.0), ratio=(0.95, 1.05)),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                             ]))
 
 val_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/test",
                             transform=transforms.Compose([
                                 transforms.Resize([224, 224]),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                             ]))
 
 train_dataloader=DataLoader(train_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=2)
 
 val_dataloader=DataLoader(val_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=2)

这里我们用torchvision数据集的ImageFolder类定义数据集。还在数据集上定义了某些转换,如RandomRotation, RandomHorizontalFlip等。最后对图片进行归一化,并且设置batch_size=64

定义训练和测试函数

 deftrain_model(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):
 
     model.train()
     #Training loop.
     forepochinrange(epochs):
       model.train()
       forimages, targetintqdm(train_dataloader):
           images, target=images.to(device), target.to(device)
           images.requires_grad=True
           optimizer.zero_grad()
           output=model(images)
           loss=loss_func(output, target)
           loss.backward()
           optimizer.step()
       ifos.path.exists('grad_checkpoints/') isFalse:
         os.mkdir('grad_checkpoints')
       torch.save(model.state_dict(), 'grad_checkpoints/epoch_'+str(epoch)+'.pt')
 
 
       #Test the model on validation data.
       train_acc,train_loss=test_model(model,train_dataloader)
       val_acc,val_loss=test_model(model,val_dataloader)
 
       #Check memory usage.
       handle=nvidia_smi.nvmlDeviceGetHandleByIndex(0)
       info=nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
       memory_used=info.used
       memory_used=(memory_used/1024)/1024
 
       print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")
 
 
 
 deftest_model(model,val_dataloader):
   model.eval()
   test_loss=0
   correct=0
   withtorch.no_grad():
       forimages, targetinval_dataloader:
           images, target=images.to(device), target.to(device)
           output=model(images)
           test_loss+=loss_func(output, target).data.item()
           _, predicted=torch.max(output, 1)
           correct+= (predicted==target).sum().item()
   
   test_loss/=len(val_dataloader.dataset)
 
   returnint(correct/len(val_dataloader.dataset) *100),test_loss

上面创建了一个简单的训练和测试循环来训练模型。最后还通过调用nvidia-smi计算内存使用。

训练

 torch.manual_seed(0)
 
 #Learning rate.
 lr=0.003
 
 #Defining the VGG16 sequential model.
 vgg16=models.vgg16()
 vgg_layers_list=list(vgg16.children())[:-1]
 vgg_layers_list.append(nn.Flatten())
 vgg_layers_list.append(nn.Linear(25088,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,2))
 model=nn.Sequential(*vgg_layers_list)
 model=model.to(device)
 
 
 
 #Num of epochs to train
 num_epochs=10
 
 #Loss
 loss_func=nn.CrossEntropyLoss()
 
 # Optimizer 
 # optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
 optimizer=optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
 
 
 #Training the model.
 model=train_model(model, loss_func, optimizer,
                        train_dataloader,val_dataloader,num_epochs)

我们使用VGG16模型进行分类。下面是模型的训练日志。

可以从上面的日志中看到,在没有检查点的情况下,训练64个批大小的模型大约需要5分钟,占用内存为14222.125 mb。

使用带有梯度检查点的PyTorch训练分类模型

为了用梯度检查点训练模型,只需要编辑train_model函数。

 deftrain_with_grad_checkpointing(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):
 
 
     #Training loop.
     forepochinrange(epochs):
       model.train()
       forimages, targetintqdm(train_dataloader):
           images, target=images.to(device), target.to(device)
           images.requires_grad=True
           optimizer.zero_grad()
           #Applying gradient checkpointing
           segments=2
 
           # get the modules in the model. These modules should be in the order
           # the model should be executed
           modules= [modulefork, moduleinmodel._modules.items()]
 
           # now call the checkpoint API and get the output
           output=checkpoint_sequential(modules, segments, images)
           loss=loss_func(output, target)
           loss.backward()
           optimizer.step()
       ifos.path.exists('checkpoints/') isFalse:
         os.mkdir('checkpoints')
       torch.save(model.state_dict(), 'checkpoints/epoch_'+str(epoch)+'.pt')
 
 
       #Test the model on validation data.
       train_acc,train_loss=test_model(model,train_dataloader)
       val_acc,val_loss=test_model(model,val_dataloader)
 
       #Check memory.
       handle=nvidia_smi.nvmlDeviceGetHandleByIndex(0)
       info=nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
       memory_used=info.used
       memory_used=(memory_used/1024)/1024
 
       print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")
 
 deftest_model(model,val_dataloader):
   model.eval()
   test_loss=0
   correct=0
   withtorch.no_grad():
       forimages, targetinval_dataloader:
           images, target=images.to(device), target.to(device)
           output=model(images)
           test_loss+=loss_func(output, target).data.item()
           _, predicted=torch.max(output, 1)
           correct+= (predicted==target).sum().item()
   
   test_loss/=len(val_dataloader.dataset)
 
   returnint(correct/len(val_dataloader.dataset) *100),test_lossdeftest_model(model,val_dataloader)

我们将函数名修改为train_with_grad_checkpointing。也就是不通过模型(图)运行训练,而是使用checkpoint_sequential函数进行训练,该函数有三个输入:modules, segments, input。modules是神经网络层的列表,按它们执行的顺序排列。segments是在序列中创建的段的个数,使用梯度检查点进行训练以段为单位将输出用于重新计算反向传播期间的梯度。本文设置segments=2。input是模型的输入,在我们的例子中是图像。这里的checkpoint_sequential仅用于顺序模型,对于其他一些模型将产生错误。

使用梯度检查点进行训练,如果你在notebook上执行所有的代码。建议重新启动,因为nvidia-smi可能会获得以前代码中的内存消耗。

 torch.manual_seed(0)
 
 lr=0.003
 
 # model = models.resnet50()
 # model=model.to(device)
 
 vgg16=models.vgg16()
 vgg_layers_list=list(vgg16.children())[:-1]
 vgg_layers_list.append(nn.Flatten())
 vgg_layers_list.append(nn.Linear(25088,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,2))
 model=nn.Sequential(*vgg_layers_list)
 model=model.to(device)
 
 
 
 
 num_epochs=10
 
 #Loss
 loss_func=nn.CrossEntropyLoss()
 
 # Optimizer 
 # optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
 optimizer=optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
 
 
 #Fitting the model.
 model=train_with_grad_checkpointing(model, loss_func, optimizer,
                        train_dataloader,val_dataloader,num_epochs)

输出如下:

从上面的输出可以看到,每个epoch的训练大约需要6分45秒。但只需要10550.125 mb的内存,也就是说我们用时间换取了空间,并且这两种情况下的精度都是79,因为在梯度检查点的情况下模型的精度没有损失。

总结

梯度检查点是一个非常好的技术,它可以帮助在小显存的情况下完整模型的训练。经过我们的测试,一般情况下梯度检查点会将训练时间延长20%左右,但是时间长点总比不能用要好,对吧。

本文的源代码:

https://avoid.overfit.cn/post/a13e29c312c741ac94d4a5079fb9f8af

作者:Vikas Kumar Ojha

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
目录
相关文章
|
1月前
|
并行计算 Shell TensorFlow
Tensorflow-GPU训练MTCNN出现错误-Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
在使用TensorFlow-GPU训练MTCNN时,如果遇到“Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED”错误,通常是由于TensorFlow、CUDA和cuDNN版本不兼容或显存分配问题导致的,可以通过安装匹配的版本或在代码中设置动态显存分配来解决。
47 1
Tensorflow-GPU训练MTCNN出现错误-Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
|
30天前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
145 2
|
1月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
52 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
1月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
48 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
17天前
|
人工智能 语音技术 UED
仅用4块GPU、不到3天训练出开源版GPT-4o,这是国内团队最新研究
【10月更文挑战第19天】中国科学院计算技术研究所提出了一种名为LLaMA-Omni的新型模型架构,实现与大型语言模型(LLMs)的低延迟、高质量语音交互。该模型集成了预训练的语音编码器、语音适配器、LLM和流式语音解码器,能够在不进行语音转录的情况下直接生成文本和语音响应,显著提升了用户体验。实验结果显示,LLaMA-Omni的响应延迟低至226ms,具有创新性和实用性。
36 1
|
2月前
|
并行计算 PyTorch 算法框架/工具
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
文章介绍了如何在CUDA 12.1、CUDNN 8.9和PyTorch 2.3.1环境下实现自定义数据集的训练,包括环境配置、预览结果和核心步骤,以及遇到问题的解决方法和参考链接。
101 4
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
|
2月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
129 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
2月前
|
机器学习/深度学习 监控 PyTorch
PyTorch 模型调试与故障排除指南
在深度学习领域,PyTorch 成为开发和训练神经网络的主要框架之一。本文为 PyTorch 开发者提供全面的调试指南,涵盖从基础概念到高级技术的内容。目标读者包括初学者、中级开发者和高级工程师。本文探讨常见问题及解决方案,帮助读者理解 PyTorch 的核心概念、掌握调试策略、识别性能瓶颈,并通过实际案例获得实践经验。无论是在构建简单神经网络还是复杂模型,本文都将提供宝贵的洞察和实用技巧,帮助开发者更高效地开发和优化 PyTorch 模型。
40 3
PyTorch 模型调试与故障排除指南
|
1月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
3月前
|
机器学习/深度学习 边缘计算 PyTorch
PyTorch 与边缘计算:将深度学习模型部署到嵌入式设备
【8月更文第29天】随着物联网技术的发展,越来越多的数据处理任务开始在边缘设备上执行,以减少网络延迟、降低带宽成本并提高隐私保护水平。PyTorch 是一个广泛使用的深度学习框架,它不仅支持高效的模型训练,还提供了多种工具帮助开发者将模型部署到边缘设备。本文将探讨如何将PyTorch模型高效地部署到嵌入式设备上,并通过一个具体的示例来展示整个流程。
466 1
下一篇
无影云桌面