Pytorch中in-place操作相关错误解析及detach()方法说明

本文涉及的产品
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: Pytorch中in-place操作相关错误解析及detach()方法说明

0. 前言

*感谢荼靡,对本文的大力支持。

*感谢新星计划让我认识了优秀的博主。

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

最近在构建nn.RNN模型,及以nn.RNN为基础的nn.LSTM模型遇到了下面这个让人非常头疼的good luck报错:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 1]], which is output 0 of AsStridedBackward0, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

CSDN上关于上面错误提示的说明文章有很多,但是几乎都是直接说明解决方案(而且大部分还没用),欠缺对这个报错的机理解释。因此写作本篇博客记录本问题的解决过程及相关理解。

1. 背景问题描述

把基于nn.RNN的模型简化成以下代码:

import torch

rnn = torch.nn.RNN(input_size=1, hidden_size=1, num_layers=1)

train_set_x = torch.tensor([[[1]],[[2]],[[3]],[[4]],[[5]]], dtype=torch.float32)
train_set_y = torch.tensor([[[2]],[[4]],[[6]],[[8]],[[10]]], dtype=torch.float32)

h0 = torch.tensor([[0]], dtype=torch.float32)
h_cur = h0

loss = torch.nn.MSELoss()
opt = torch.optim.Adadelta(rnn.parameters(), lr = 0.01)

with torch.autograd.set_detect_anomaly(True):
    for i in range(5):
        opt.zero_grad()
        train_output, h_next = rnn(train_set_x[i], h_cur)
        rnn_loss = loss(train_output,train_set_y[i])
        rnn_loss.backward(retain_graph=True)
        opt.step()
        print(train_output)
        h_cur = h_next

原问题链接:Pytorch框架nn.RNN训练时反向传播报错

2. 报错解析:in-place(置位)操作相关理解&说明

上面的错误提示“one of the variables needed for gradient computation has been modified by an inplace operation”,直译就是过来“梯度计算需要的一个变量被一个置位操作更改了”

之前这个问题一直困扰我的原因就是对置位操作的理解不到位,原来我理解置位操作只有形如“x += 1”或“x -= 1”这种运算。但我的原代码中是没有这种运算的,却在报置位操作的错误。

其实置位操作是泛指直接更改内存中的值,而不是先复制一个值再更改复制后的这个值的操作

An in-place operation is an operation that changes directly the content of a given Tensor without making a copy. Inplace operations in pytorch are always postfixed with a , like .add() or .scatter_(). Python operations like += or *= are also inplace operations.

我们常用的赋值方法"a = b",虽然a和b的值是相同的,但是这个相同的值是在两个完全不同的物理地址中,这样如果更改a就不会对b造成任何影响,反之亦然。

但是如果上面的“a = b”是一个置位操作,那么如果改了a,对应b也会做同样的变更,因为他们完全共用一个物理地址,共享同一块内存。

所以置位操作应该慎用,因为可能共享的变量多了,在对这些变量做计算(变更)时,就可能带来一些非期望的变量变更。

那上面的问题代码置位操作在哪?

答:在 “h_cur = h_next” 。通过id()方法,可以看到这两个变量的地址一致。我们认为的赋值操作被Pytorch变成了置位操作

print(id(h_cur))
print(id(h_next))
输出-------------------------------------
2943427659952
2943427659952

回顾上面报错的后半句“[torch.FloatTensor [1, 1]], which is output 0 of AsStridedBackward0, is at version 3; expected version 2 instead. ” 有一个[1,1]的tensor(也就是RNN中的隐层输出h)已经是第3版(version)了,而期望的是第2版(version)。这里的版(version)我理解的就是置位操作的次数。

因为上面 “h_cur = h_next” 这次运算,导致h_cur多操作了一次(版),尽管实际上它的值没变,但是导致了版次不匹配(version mismatch),最终造成了上面的报错。

那为什么Pytorch要默认存在置位操作?

答:为了节省内存,提高运行速度。上面已经说过了,置位操作可以直接更改内存中的数据,而不用先复制一份数据。现在大型的神经网络动辄几万,几十万,甚至上百万个参数,如果不用置位操作,每次backward前都先复制一份参数,再在复制后的参数中进行计算,将会耗费大量的内存来存储这些参数。

隐层输出h如何影响梯度计算?

答:计算公式为:

具体过程请见我此前手推的RNN数学模型:基于Numpy构建RNN模块并进行实例应用(附代码)

当然,如果实在对数学推导过程很抗拒,直接了解这个结论也行:在RNN中隐层输出h是直接参与反向传播梯度计算过程的

最终如何解决in-place操作导致的错误?

答:取消in-place操作。给我们实际上要“赋值”的变量强制指派一块新的内存,可以实现这个目的的方法有detach()方法,clone()方法。由于detach()应用更广泛,下面仅说明detach()方法。

通过常规的引入一个中间变量的方法是不行的,比如:h_cur = mid , mid = h_next 。因为Pytorch仍会默认是置位操作,通过打印id可以看到,h_cur,mid,h_next仍是共用物理地址。

3. detach()方法的作用

①给变量指派一个新的内存

把上面的代码改成:

h_cur = h_next.detach()
print(id(h_cur))
print(id(h_next))
输出---------------------------------------------
3197060036944
3197060036864

两个变量彻底分开,本问题解决。

②把变量变成叶子节点

通过is_leaf()方法可以识别一个变量是否为叶子节点:

h_cur = h_next.detach()
print('h_next_requires_grad:',h_next.requires_grad)
print('h_cur_requires_grad:',h_cur.requires_grad)
print('h_next_is_leaf:',h_next.is_leaf)
print('h_cur_is_lear:',h_cur.is_leaf)
输出---------------------------------------------
h_next_requires_grad: True
h_cur_requires_grad: False
h_next_is_leaf: False
h_cur_is_lear: True

可见detach()方法中断了h_cur的反向传播,把requires_grad设定成了False,且把h_cur设成了叶子节点。

关于叶子节点/非叶子节点的定义及作用,非本文说明对象,推荐一篇写的非常好的博客:Pytorch 叶子张量 leaf tensor (叶子节点) (detach)

4. 更正后代码

更正后完整代码如下:

import torch

rnn = torch.nn.RNN(input_size=1, hidden_size=1, num_layers=1)

train_set_x = torch.tensor([[[[1]]],[[[2]]],[[[3]]],[[[4]]],[[[5]]]], dtype=torch.float32)
train_set_y = torch.tensor([[[[2]]],[[[4]]],[[[6]]],[[[8]]],[[[10]]]], dtype=torch.float32)

h0 = torch.tensor([[[0]]], dtype=torch.float32)
h_cur = h0

loss = torch.nn.MSELoss()
opt = torch.optim.Adadelta(rnn.parameters(), lr = 0.01)


for i in range(5):
    opt.zero_grad()
    train_output, h_next = rnn(train_set_x[0], h_cur)
    rnn_loss = loss(train_output,train_set_y[0])
    h_cur = h_next.detach()
    rnn_loss.backward()
    opt.step()
    print(train_output)


# print(id(h_cur))
# print(id(h_next))
# print('h_next_requires_grad:',h_next.requires_grad)
# print('h_cur_requires_grad:',h_cur.requires_grad)
# print('h_next_is_leaf:',h_next.is_leaf)
# print('h_cur_is_lear:',h_cur.is_leaf)

另外说明,在有些版本中,即使不用detach(),最开始的代码也是可以运行的,例如下面:

推测这可能是较早期的Pytorch中没有默认in-place操作,当然这就会导致上面说的内存消耗变多的问题。


相关文章
|
3月前
|
人工智能
歌词结构的巧妙安排:写歌词的方法与技巧解析,妙笔生词AI智能写歌词软件
歌词创作是一门艺术,关键在于巧妙的结构安排。开头需迅速吸引听众,主体部分要坚实且富有逻辑,结尾则应留下深刻印象。《妙笔生词智能写歌词软件》提供多种 AI 功能,帮助创作者找到灵感,优化歌词结构,写出打动人心的作品。
|
3月前
|
存储 算法 Java
解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用
在Java中,Set接口以其独特的“无重复”特性脱颖而出。本文通过解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用。
61 3
|
22天前
|
安全 Ubuntu Shell
深入解析 vsftpd 2.3.4 的笑脸漏洞及其检测方法
本文详细解析了 vsftpd 2.3.4 版本中的“笑脸漏洞”,该漏洞允许攻击者通过特定用户名和密码触发后门,获取远程代码执行权限。文章提供了漏洞概述、影响范围及一个 Python 脚本,用于检测目标服务器是否受此漏洞影响。通过连接至目标服务器并尝试登录特定用户名,脚本能够判断服务器是否存在该漏洞,并给出相应的警告信息。
143 84
|
3月前
|
人工智能
写歌词的技巧和方法全解析:开启你的音乐创作之旅,妙笔生词智能写歌词软件
怀揣音乐梦想,渴望用歌词抒发情感?掌握关键技巧,你也能踏上创作之旅。灵感来自生活点滴,主题明确,语言简洁,韵律和谐。借助“妙笔生词智能写歌词软件”,AI辅助创作,轻松写出动人歌词,实现音乐梦想。
|
12天前
|
机器学习/深度学习 人工智能 PyTorch
使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
本文将系统阐述DPO的工作原理、实现机制,以及其与传统RLHF和SFT方法的本质区别。
68 22
使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
|
3天前
|
数据可视化 项目管理
个人和团队都好用的年度复盘工具:看板与KPT方法解析
本文带你了解高效方法KPT复盘法(Keep、Problem、Try),结合看板工具,帮助你理清头绪,快速完成年度复盘。
32 7
个人和团队都好用的年度复盘工具:看板与KPT方法解析
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
50 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
21天前
|
存储 Java 开发者
浅析JVM方法解析、创建和链接
上一篇文章《你知道Java类是如何被加载的吗?》分析了HotSpot是如何加载Java类的,本文再来分析下Hotspot又是如何解析、创建和链接类方法的。
|
29天前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
1月前
|
负载均衡 网络协议 算法
Docker容器环境中服务发现与负载均衡的技术与方法,涵盖环境变量、DNS、集中式服务发现系统等方式
本文探讨了Docker容器环境中服务发现与负载均衡的技术与方法,涵盖环境变量、DNS、集中式服务发现系统等方式,以及软件负载均衡器、云服务负载均衡、容器编排工具等实现手段,强调两者结合的重要性及面临挑战的应对措施。
77 3

热门文章

最新文章

推荐镜像

更多