> 复现代码时遇到了自动混合精度。查阅资料得知,Pytorch从1.60开始支持自动混合精度训练。其中自动、混合精度是两个关键词,那么代表什么意思呢?一起来看看吧!
✨1 混合精度训练简介
目前,Pytorch一共支持10种数据类型:
- torch.FloatTensor # 另一种表述:FP32**
- torch.DoubleTensor # 64-bit floating point
- torch.HalfTensor # 另一种表述:FP16
- torch.ByteTensor
- torch.CharTensor
- torch.ShortTensor
- torch.IntTensor
- torch.LongTensor
默认使用的是32位浮点型精度的Tensor,即torch.FloatTensor 。因此,默认情况下我们训练的是一个FP32的模型。但不是所有数据都需要FP32那么大的内存。
此时,采用自动混合精度(Automatic Mixed Precision, AMP)训练,一部分算子数值精度为FP16,其余算子的数值精度是FP32,而哪些算子用FP16,哪些用FP32,由amp自动安排。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更多的 batch size、更大模型和尺寸更大的输入进行训练。
✨2 自动混合精度训练的使用
结合一段示例代码来看:
# amp依赖Tensor core架构,所以模型必须在cuda设备下使用 model = Model() model.to("cuda") # 必须!!! optimizer = optim.SGD(model.parameters(), ...) # (新增)创建GradScaler对象 scaler = GradScaler(enabled=True) # 虽然默认为True,体验一下过程 for epoch in epochs: for img, target in data: optimizer.zero_grad() # (新增)启动autocast上下文管理器 with autocast(enabled=True): # (不变)上下文管理器下,model前向传播,以及loss计算自动切换数值精度 output = model(img) loss = loss_fn(output, target) # (修改)反向传播 scaler.scale(loss).backward() # (修改)梯度计算 scaler.step(optimizer) # (新增)scaler更新 scaler.update()
使用自动化精度时,只有在模型以及损失计算,反向传播,梯度更新时作出一定的改变,具体有:
scaler = GradScaler():创建对象GradScaler,并赋予变量scaler
with autocast()::启动autocast上下文管理器,内含需要做精度放缩的计算(必须包含模型计算以及损失计算)
scaler.scale(loss).backward()
:利用scaler做反向传播scaler.step(optimizer)
:梯度更新scaler.update()
:scaler更新
🎃 2.2 GradScaler
构造:
torch.cuda.amp.GradScaler( init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True, )
这里形式很固定,只有一个参数enabled根据自己的需求进行改变
参数:
- enabled:是否做scale。如果为False,则返回原数据。如果为True,则进行一次精度转换。
原谅我,涉及到了原理,水平有限,真的看不懂,欢迎大家交流
🎉 2.2 autocast
autocast( enable=True # 同上 )
上面的示例中,autocast是在训练脚本中使用的,除此之外还有两种方式:
- 作为装饰器,在forward函数中使用
==============================================================
class Model(nn.Module): def __init__(self): pass @torch.cuda.amp.autocast() # autocast导入路径 def forward(): pass
==============================================================
在forward中使用上下文管理器
==============================================================
class Model(nn.Module): def __init__(self): pass def forward(): with torch.cuda.amp.autocast(): # 上下文管理器 pass