PyTorch:常见错误 inplace operation

简介: `inplace` 操作是 PyTorch 里面一个比较常见的错误,有的时候会比较好发现,但是有的时候同样类似的报错,会比较不好发现。

inplace 操作是 PyTorch 里面一个比较常见的错误,有的时候会比较好发现,例如下面的代码:

import torch
w = torch.rand(4, requires_grad=True)
w += 1
loss = w.sum()
loss.backward()

执行 loss 对参数 w 进行求导,会出现报错:RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

导致这个报错的主要是第 3 行代码 w += 1,如果把这句改成 w = w + 1,再执行就不会报错了。这种写法导致的 inplace operation 是比较好发现的,但是有的时候同样类似的报错,会比较不好发现。例如下面的代码:

import torch
x = torch.zeros(4)
w = torch.rand(4, requires_grad=True)
x[0] = torch.rand(1) * w[0]
for i in range(3):
    x[i+1] = torch.sin(x[i]) * w[i]
loss = x.sum()
loss.backward()

执行之后会出现报错:

>>> RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: 
[torch.FloatTensor []], which is output 0 of SelectBackward, is at version 4; expected version 3 instead. 
Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

根据提示我们可以使用 with torch.autograd.set_detect_anomaly(True) 来帮助我们定位具体的出错位置(这个方法会花费比较长的时间)。

with torch.autograd.set_detect_anomaly(True):
    x = torch.zeros(4)
    w = torch.rand(4, requires_grad=True)
    x[0] = torch.rand(1) * w[0]
    for i in range(3):
        x[i+1] = torch.sin(x[i]) * w[i]
    loss = x.sum()
    loss.backward()

运行会增加这些报错:

>>> /Users/strongnine/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py:130: 
UserWarning: Error detected in SinBackward. Traceback of forward call that caused the error:

可以看到出现了 Error detected in SinBackward.,这句描述,我们可以猜测大概是 torch.sin() 这个函数出现了问题。实际上,这个报错的解决办法,就是将第 6 行代码 x[i+1] = torch.sin(x[i]) * w[i] 改成 x[i+1] = torch.sin(x[i].clone()) * w[i],就行了。

import torch
x = torch.zeros(4)
w = torch.rand(4, requires_grad=True)
x[0] = torch.rand(1) * w[0]
for i in range(3):
    x[i+1] = torch.sin(x[i].clone()) * w[i]
loss = x.sum()
loss.backward()

总结一下,遇到 inplace operation 的报错,一般可以通过:

  • x += 1 改成 x = x + 1
  • x[:, :, 0:3] = x[:, :, 0:3] + 1 改成 x[:, :, 0:3] = x[:, :, 0:3].clone() + 11
  • x[i+1] = torch.sin(x[i]) * w[i] 改成 x[i+1] = torch.sin(x[i].clone()) * w[i]

如果自己检查不出是哪里出现了问题,可以使用 with torch.autograd.set_detect_anomaly(True) 来帮助我们定位具体的出错位置,但是要注意的是这个方法一般会运行比较长的时间。

目录
相关文章
|
存储 机器学习/深度学习 PyTorch
|
16小时前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
8月前
|
机器学习/深度学习 自然语言处理 算法
【NLP】Pytorch构建神经网络
【NLP】Pytorch构建神经网络
|
16小时前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
102 2
|
16小时前
|
机器学习/深度学习 算法 PyTorch
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
|
16小时前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习中卷积神经网络(CNN)的讲解及图像处理实战(超详细 附源码)
PyTorch深度学习中卷积神经网络(CNN)的讲解及图像处理实战(超详细 附源码)
133 0
|
16小时前
|
机器学习/深度学习 搜索推荐 数据可视化
PyTorch搭建基于图神经网络(GCN)的天气推荐系统(附源码和数据集)
PyTorch搭建基于图神经网络(GCN)的天气推荐系统(附源码和数据集)
103 0
|
16小时前
|
机器学习/深度学习 数据采集 自然语言处理
PyTorch搭建LSTM神经网络实现文本情感分析实战(附源码和数据集)
PyTorch搭建LSTM神经网络实现文本情感分析实战(附源码和数据集)
187 0
|
16小时前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch搭建卷积神经网络(CNN)进行视频行为识别(附源码和数据集)
PyTorch搭建卷积神经网络(CNN)进行视频行为识别(附源码和数据集)
45 0