一文读懂—Pytorch混合精度训练

简介: 一文读懂—Pytiorch混合精度训练

> 复现代码时遇到了自动混合精度。查阅资料得知,Pytorch从1.60开始支持自动混合精度训练。其中自动、混合精度是两个关键词,那么代表什么意思呢?一起来看看吧!

✨1 混合精度训练简介

目前,Pytorch一共支持10种数据类型

  • torch.FloatTensor # 另一种表述:FP32**
  • torch.DoubleTensor # 64-bit floating point
  • torch.HalfTensor # 另一种表述:FP16
  • torch.ByteTensor
  • torch.CharTensor
  • torch.ShortTensor
  • torch.IntTensor
  • torch.LongTensor

默认使用的是32位浮点型精度的Tensor,即torch.FloatTensor 。因此,默认情况下我们训练的是一个FP32的模型。但不是所有数据都需要FP32那么大的内存。

此时,采用自动混合精度(Automatic Mixed Precision, AMP)训练,一部分算子数值精度为FP16,其余算子的数值精度是FP32,而哪些算子用FP16,哪些用FP32,由amp自动安排。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更多的 batch size、更大模型和尺寸更大的输入进行训练。

✨2 自动混合精度训练的使用

结合一段示例代码来看:

# amp依赖Tensor core架构,所以模型必须在cuda设备下使用
model = Model()
model.to("cuda")  # 必须!!!
optimizer = optim.SGD(model.parameters(), ...)
# (新增)创建GradScaler对象
scaler = GradScaler(enabled=True)  # 虽然默认为True,体验一下过程
for epoch in epochs:
    for img, target in data:
        optimizer.zero_grad()
        # (新增)启动autocast上下文管理器
        with autocast(enabled=True):
            # (不变)上下文管理器下,model前向传播,以及loss计算自动切换数值精度
            output = model(img)
            loss = loss_fn(output, target)
        # (修改)反向传播
        scaler.scale(loss).backward()
        # (修改)梯度计算
        scaler.step(optimizer)
        # (新增)scaler更新
        scaler.update()

使用自动化精度时,只有在模型以及损失计算,反向传播,梯度更新时作出一定的改变,具体有:

scaler = GradScaler():创建对象GradScaler,并赋予变量scaler

with autocast()::启动autocast上下文管理器,内含需要做精度放缩的计算(必须包含模型计算以及损失计算)

  1. scaler.scale(loss).backward():利用scaler做反向传播
  2. scaler.step(optimizer):梯度更新
  3. scaler.update():scaler更新

🎃 2.2 GradScaler

构造:

torch.cuda.amp.GradScaler(
  init_scale=65536.0,
  growth_factor=2.0,
  backoff_factor=0.5,
  growth_interval=2000,
  enabled=True,
)

这里形式很固定,只有一个参数enabled根据自己的需求进行改变

参数:

  1. enabled:是否做scale。如果为False,则返回原数据。如果为True,则进行一次精度转换。

原谅我,涉及到了原理,水平有限,真的看不懂,欢迎大家交流

🎉 2.2 autocast

autocast(
  enable=True  # 同上
)

上面的示例中,autocast是在训练脚本中使用的,除此之外还有两种方式:

  1. 作为装饰器,在forward函数中使用

==============================================================

class Model(nn.Module):
  def __init__(self):
    pass
  @torch.cuda.amp.autocast()  # autocast导入路径
  def forward():
    pass

==============================================================


在forward中使用上下文管理器

==============================================================

class Model(nn.Module):
  def __init__(self):
    pass
  def forward():
    with torch.cuda.amp.autocast():  # 上下文管理器
      pass


相关文章
|
3月前
|
存储 人工智能 PyTorch
基于PyTorch/XLA的高效分布式训练框架
基于PyTorch/XLA的高效分布式训练框架
255 2
|
3月前
|
机器学习/深度学习 存储 PyTorch
【AMP实操】解放你的GPU运行内存!在pytorch中使用自动混合精度训练
【AMP实操】解放你的GPU运行内存!在pytorch中使用自动混合精度训练
172 0
|
3月前
|
机器学习/深度学习 数据采集 PyTorch
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
|
3月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
323 0
|
3月前
|
PyTorch 算法框架/工具
Automatic mixed precision for Pytorch 自动混合精度训练
Automatic mixed precision for Pytorch 自动混合精度训练
44 0
|
3月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
310 0
|
2月前
|
机器学习/深度学习 并行计算 PyTorch
使用PyTorch Profiler进行模型性能分析,改善并加速PyTorch训练
加速机器学习模型训练是工程师的关键需求。PyTorch Profiler提供了一种分析工具,用于测量CPU和CUDA时间,以及内存使用情况。通过在训练代码中嵌入分析器并使用tensorboard查看结果,工程师可以识别性能瓶颈。Profiler的`record_function`功能允许为特定操作命名,便于跟踪。优化策略包括使用FlashAttention或FSDP减少内存使用,以及通过torch.compile提升速度。监控CUDA内核执行和内存分配,尤其是避免频繁的cudaMalloc,能有效提升GPU效率。内存历史记录分析有助于检测内存泄漏和优化批处理大小。
134 1
|
1月前
|
机器学习/深度学习 PyTorch TensorFlow
在深度学习中,数据增强是一种常用的技术,用于通过增加训练数据的多样性来提高模型的泛化能力。`albumentations`是一个强大的Python库,用于图像增强,支持多种图像变换操作,并且可以与深度学习框架(如PyTorch、TensorFlow等)无缝集成。
在深度学习中,数据增强是一种常用的技术,用于通过增加训练数据的多样性来提高模型的泛化能力。`albumentations`是一个强大的Python库,用于图像增强,支持多种图像变换操作,并且可以与深度学习框架(如PyTorch、TensorFlow等)无缝集成。
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
|
2月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】