在 PyTorch 中使用梯度检查点在GPU 上训练更大的模型

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,100CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: 本文将介绍解梯度检查点(Gradient Checkpointing),这是一种可以让你以增加训练时间为代价在 GPU 中训练大模型的技术。 我们将在 PyTorch 中实现它并训练分类器模型。

作为机器学习从业者,我们经常会遇到这样的情况,想要训练一个比较大的模型,而 GPU 却因为内存不足而无法训练它。当我们在出于安全原因不允许在云计算的环境中工作时,这个问题经常会出现。在这样的环境中,我们无法足够快地扩展或切换到功能强大的硬件并训练模型。并且由于梯度下降算法的性质,通常较大的批次在大多数模型中会产生更好的结果,但在大多数情况下,由于内存限制,我们必须使用适应GPU显存的批次大小。

本文将介绍解梯度检查点(Gradient Checkpointing),这是一种可以让你以增加训练时间为代价在 GPU 中训练大模型的技术。 我们将在 PyTorch 中实现它并训练分类器模型。

梯度检查点

在反向传播算法中,梯度计算从损失函数开始,计算后更新模型权重。图中每一步计算的所有导数或梯度都会被存储,直到计算出最终的更新梯度。这样做会消耗大量 GPU 内存。梯度检查点通过在需要时重新计算这些值和丢弃在进一步计算中不需要的先前值来节省内存。

让我们用下面的虚拟图来解释。

上面是一个计算图,每个叶节点上的数字相加得到最终输出。假设这个图表示反向传播期间发生的计算,那么每个节点的值都会被存储,这使得执行求和所需的总内存为7,因为有7个节点。但是我们可以用更少的内存。假设我们将1和2相加,并在下一个节点中将它们的值存储为3,然后删除这两个值。我们可以对4和5做同样的操作,将9作为加法的结果存储。3和9也可以用同样的方式操作,存储结果后删除它们。通过执行这些操作,在计算过程中所需的内存从7减少到3。

在没有梯度检查点的情况下,使用PyTorch训练分类模型

我们将使用PyTorch构建一个分类模型,并在不使用梯度检查点的情况下训练它。记录模型的不同指标,如训练所用的时间、内存消耗、准确性等。

由于我们主要关注GPU的内存消耗,所以在训练时需要检测每批的内存消耗。这里使用nvidia-ml-py3库,该库使用nvidia-smi命令来获取内存信息。

 pip install nvidia-ml-py3

为了简单起见,我们使用简单的狗和猫分类数据集的子集。

 git clone https://github.com/laxmimerit/dog-cat-full-dataset.git

执行上述命令后会在dog-cat-full-dataset的文件夹中得到完整的数据集。

导入所需的包并初始化nvdia-smi

 importtorch
 importtorch.nnasnn
 importtorch.optimasoptim
 importnumpyasnp
 fromtorchvisionimportdatasets, models, transforms
 importmatplotlib.pyplotasplt
 importtime
 importos
 importcv2
 importnvidia_smi
 importcopy
 fromPILimportImage
 fromtorch.utils.dataimportDataset,DataLoader
 importtorch.utils.checkpointascheckpoint
 fromtqdmimporttqdm
 importshutil
 fromtorch.utils.checkpointimportcheckpoint_sequential
 device="cuda"iftorch.cuda.is_available() else"cpu"
 %matplotlibinline
 importrandom
 
 nvidia_smi.nvmlInit()

导入训练和测试模型所需的所有包。我们还初始化nvidia-smi。

定义数据集和数据加载器

 #Define the dataset and the dataloader.
 train_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/train",
                             transform=transforms.Compose([
                                 transforms.RandomRotation(30),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.RandomResizedCrop(224, scale=(0.96, 1.0), ratio=(0.95, 1.05)),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                             ]))
 
 val_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/test",
                             transform=transforms.Compose([
                                 transforms.Resize([224, 224]),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                             ]))
 
 train_dataloader=DataLoader(train_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=2)
 
 val_dataloader=DataLoader(val_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=2)

这里我们用torchvision数据集的ImageFolder类定义数据集。还在数据集上定义了某些转换,如RandomRotation, RandomHorizontalFlip等。最后对图片进行归一化,并且设置batch_size=64

定义训练和测试函数

 deftrain_model(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):
 
     model.train()
     #Training loop.
     forepochinrange(epochs):
       model.train()
       forimages, targetintqdm(train_dataloader):
           images, target=images.to(device), target.to(device)
           images.requires_grad=True
           optimizer.zero_grad()
           output=model(images)
           loss=loss_func(output, target)
           loss.backward()
           optimizer.step()
       ifos.path.exists('grad_checkpoints/') isFalse:
         os.mkdir('grad_checkpoints')
       torch.save(model.state_dict(), 'grad_checkpoints/epoch_'+str(epoch)+'.pt')
 
 
       #Test the model on validation data.
       train_acc,train_loss=test_model(model,train_dataloader)
       val_acc,val_loss=test_model(model,val_dataloader)
 
       #Check memory usage.
       handle=nvidia_smi.nvmlDeviceGetHandleByIndex(0)
       info=nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
       memory_used=info.used
       memory_used=(memory_used/1024)/1024
 
       print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")
 
 
 
 deftest_model(model,val_dataloader):
   model.eval()
   test_loss=0
   correct=0
   withtorch.no_grad():
       forimages, targetinval_dataloader:
           images, target=images.to(device), target.to(device)
           output=model(images)
           test_loss+=loss_func(output, target).data.item()
           _, predicted=torch.max(output, 1)
           correct+= (predicted==target).sum().item()
   
   test_loss/=len(val_dataloader.dataset)
 
   returnint(correct/len(val_dataloader.dataset) *100),test_loss

上面创建了一个简单的训练和测试循环来训练模型。最后还通过调用nvidia-smi计算内存使用。

训练

 torch.manual_seed(0)
 
 #Learning rate.
 lr=0.003
 
 #Defining the VGG16 sequential model.
 vgg16=models.vgg16()
 vgg_layers_list=list(vgg16.children())[:-1]
 vgg_layers_list.append(nn.Flatten())
 vgg_layers_list.append(nn.Linear(25088,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,2))
 model=nn.Sequential(*vgg_layers_list)
 model=model.to(device)
 
 
 
 #Num of epochs to train
 num_epochs=10
 
 #Loss
 loss_func=nn.CrossEntropyLoss()
 
 # Optimizer 
 # optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
 optimizer=optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
 
 
 #Training the model.
 model=train_model(model, loss_func, optimizer,
                        train_dataloader,val_dataloader,num_epochs)

我们使用VGG16模型进行分类。下面是模型的训练日志。

可以从上面的日志中看到,在没有检查点的情况下,训练64个批大小的模型大约需要5分钟,占用内存为14222.125 mb。

使用带有梯度检查点的PyTorch训练分类模型

为了用梯度检查点训练模型,只需要编辑train_model函数。

 deftrain_with_grad_checkpointing(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):
 
 
     #Training loop.
     forepochinrange(epochs):
       model.train()
       forimages, targetintqdm(train_dataloader):
           images, target=images.to(device), target.to(device)
           images.requires_grad=True
           optimizer.zero_grad()
           #Applying gradient checkpointing
           segments=2
 
           # get the modules in the model. These modules should be in the order
           # the model should be executed
           modules= [modulefork, moduleinmodel._modules.items()]
 
           # now call the checkpoint API and get the output
           output=checkpoint_sequential(modules, segments, images)
           loss=loss_func(output, target)
           loss.backward()
           optimizer.step()
       ifos.path.exists('checkpoints/') isFalse:
         os.mkdir('checkpoints')
       torch.save(model.state_dict(), 'checkpoints/epoch_'+str(epoch)+'.pt')
 
 
       #Test the model on validation data.
       train_acc,train_loss=test_model(model,train_dataloader)
       val_acc,val_loss=test_model(model,val_dataloader)
 
       #Check memory.
       handle=nvidia_smi.nvmlDeviceGetHandleByIndex(0)
       info=nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
       memory_used=info.used
       memory_used=(memory_used/1024)/1024
 
       print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")
 
 deftest_model(model,val_dataloader):
   model.eval()
   test_loss=0
   correct=0
   withtorch.no_grad():
       forimages, targetinval_dataloader:
           images, target=images.to(device), target.to(device)
           output=model(images)
           test_loss+=loss_func(output, target).data.item()
           _, predicted=torch.max(output, 1)
           correct+= (predicted==target).sum().item()
   
   test_loss/=len(val_dataloader.dataset)
 
   returnint(correct/len(val_dataloader.dataset) *100),test_lossdeftest_model(model,val_dataloader)

我们将函数名修改为train_with_grad_checkpointing。也就是不通过模型(图)运行训练,而是使用checkpoint_sequential函数进行训练,该函数有三个输入:modules, segments, input。modules是神经网络层的列表,按它们执行的顺序排列。segments是在序列中创建的段的个数,使用梯度检查点进行训练以段为单位将输出用于重新计算反向传播期间的梯度。本文设置segments=2。input是模型的输入,在我们的例子中是图像。这里的checkpoint_sequential仅用于顺序模型,对于其他一些模型将产生错误。

使用梯度检查点进行训练,如果你在notebook上执行所有的代码。建议重新启动,因为nvidia-smi可能会获得以前代码中的内存消耗。

 torch.manual_seed(0)
 
 lr=0.003
 
 # model = models.resnet50()
 # model=model.to(device)
 
 vgg16=models.vgg16()
 vgg_layers_list=list(vgg16.children())[:-1]
 vgg_layers_list.append(nn.Flatten())
 vgg_layers_list.append(nn.Linear(25088,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,2))
 model=nn.Sequential(*vgg_layers_list)
 model=model.to(device)
 
 
 
 
 num_epochs=10
 
 #Loss
 loss_func=nn.CrossEntropyLoss()
 
 # Optimizer 
 # optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
 optimizer=optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
 
 
 #Fitting the model.
 model=train_with_grad_checkpointing(model, loss_func, optimizer,
                        train_dataloader,val_dataloader,num_epochs)

输出如下:

从上面的输出可以看到,每个epoch的训练大约需要6分45秒。但只需要10550.125 mb的内存,也就是说我们用时间换取了空间,并且这两种情况下的精度都是79,因为在梯度检查点的情况下模型的精度没有损失。

总结

梯度检查点是一个非常好的技术,它可以帮助在小显存的情况下完整模型的训练。经过我们的测试,一般情况下梯度检查点会将训练时间延长20%左右,但是时间长点总比不能用要好,对吧。

本文的源代码:

https://avoid.overfit.cn/post/a13e29c312c741ac94d4a5079fb9f8af

作者:Vikas Kumar Ojha

相关实践学习
在云上部署ChatGLM2-6B大模型(GPU版)
ChatGLM2-6B是由智谱AI及清华KEG实验室于2023年6月发布的中英双语对话开源大模型。通过本实验,可以学习如何配置AIGC开发环境,如何部署ChatGLM2-6B大模型。
目录
相关文章
|
3月前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
198 1
|
5月前
|
机器学习/深度学习 PyTorch 测试技术
从训练到推理:Intel Extension for PyTorch混合精度优化完整指南
PyTorch作为主流深度学习框架,凭借动态计算图和异构计算支持,广泛应用于视觉与自然语言处理。Intel Extension for PyTorch针对Intel硬件深度优化,尤其在GPU上通过自动混合精度(AMP)提升训练与推理性能。本文以ResNet-50在CIFAR-10上的实验为例,详解如何利用该扩展实现高效深度学习优化。
274 0
|
3月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
161 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
2月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
4月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
259 9
|
6月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
260 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
本文将深入探讨L1、L2和ElasticNet正则化技术,重点关注其在PyTorch框架中的具体实现。关于这些技术的理论基础,建议读者参考相关理论文献以获得更深入的理解。
182 4
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
|
5月前
|
机器学习/深度学习 数据可视化 PyTorch
Flow Matching生成模型:从理论基础到Pytorch代码实现
本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
1950 0
Flow Matching生成模型:从理论基础到Pytorch代码实现
|
7月前
|
机器学习/深度学习 PyTorch 编译器
深入解析torch.compile:提升PyTorch模型性能、高效解决常见问题
PyTorch 2.0推出的`torch.compile`功能为深度学习模型带来了显著的性能优化能力。本文从实用角度出发,详细介绍了`torch.compile`的核心技巧与应用场景,涵盖模型复杂度评估、可编译组件分析、系统化调试策略及性能优化高级技巧等内容。通过解决图断裂、重编译频繁等问题,并结合分布式训练和NCCL通信优化,开发者可以有效提升日常开发效率与模型性能。文章为PyTorch用户提供了全面的指导,助力充分挖掘`torch.compile`的潜力。
782 17
|
7月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。

热门文章

最新文章

推荐镜像

更多