PyTorch 是一个非常流行的深度学习框架,它支持动态计算图,非常适合快速原型设计和研究。但随着模型规模的增长和数据集的扩大,如何充分利用 GPU 来加速训练过程变得尤为重要。本文将详细介绍 11 个实用的技巧,帮助你优化 PyTorch 代码性能。
技巧 1:使用 .to(device) 进行数据传输
在 PyTorch 中,可以通过 .to(device) 方法将张量和模型转移到 GPU 上。这一步骤是利用 GPU 计算能力的基础。
示例代码:
复制
import torch
创建设备对象
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
将张量移到 GPU 上
x = torch.tensor([1, 2, 3]).to(device)
y = torch.tensor([4, 5, 6], device=device) # 直接指定设备
将模型移到 GPU 上
model = torch.nn.Linear(3, 1).to(device)
print(x)
print(y)
print(next(model.parameters()).device)
输出结果:
复制
tensor([1, 2, 3], device='cuda:0')
tensor([4, 5, 6], device='cuda:0')
cuda:0
技巧 2:使用 torch.no_grad() 减少内存消耗
在训练过程中,torch.autograd 会自动记录所有操作以便计算梯度。但在评估模型时,我们可以关闭自动梯度计算以减少内存占用。
示例代码:
复制
with torch.no_grad():
predictions = model(x)
print(predictions)
输出结果:
复制
tensor([[12.]], device='cuda:0')
技巧 3:使用 torch.backends.cudnn.benchmark = True 加速卷积层
CuDNN 库提供了高度优化的卷积实现。通过设置 torch.backends.cudnn.benchmark = True,可以让 PyTorch 在每次运行前选择最适合当前输入大小的算法。
示例代码:
复制
torch.backends.cudnn.benchmark = True
conv_layer = torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1).to(device)
input_tensor = torch.randn(1, 3, 32, 32).to(device)
output = conv_layer(input_tensor)
print(output.shape)
输出结果:
复制
torch.Size([1, 32, 32, 32])
技巧 4:使用 torch.utils.data.DataLoader 并行加载数据
数据加载通常是训练过程中的瓶颈之一。DataLoader 可以多线程加载数据,从而加速这一过程。
示例代码:
复制
from torch.utils.data import DataLoader, TensorDataset
dataset = TensorDataset(x, y)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)
for inputs, labels in data_loader:
outputs = model(inputs)
print(outputs)
输出结果:
复制
tensor([[12.]], device='cuda:0')
技巧 5:使用混合精度训练
混合精度训练结合了单精度和半精度(FP16)浮点运算,可以显著减少内存消耗并加速训练过程。
示例代码:
复制
from torch.cuda.amp import autocast, GradScaler
model = torch.nn.Linear(3, 1).to(device)
scaler = GradScaler()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for i in range(10):
optimizer.zero_grad()
with autocast():
output = model(x)
loss = torch.nn.functional.mse_loss(output, y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
print(f"Iteration {i + 1}: Loss = {loss.item():.4f}")
复制
Iteration 1: Loss = 18.0000
Iteration 2: Loss = 17.8203
Iteration 3: Loss = 17.6406
...