开发者社区> 问答> 正文

在牛顿方法的PyTorch实现中更新步骤

一码平川MACHEL 2019-01-23 17:00:06 523

我试图通过实现牛顿求解x = cos(x)的方法来了解PyTorch的工作原理。这是一个有效的版本:

x = Variable(DoubleTensor([1]), requires_grad=True)

for i in range(5):

y = x - torch.cos(x)
y.backward()
x = Variable(x.data - y.data/x.grad.data, requires_grad=True)

print(x.data) # tensor([0.7390851332151607], dtype=torch.float64) (correct)
这个代码对我来说似乎不优雅(低效?),因为它在for循环的每一步中重建整个计算图(对吧?)。我试图通过简单地更新每个变量所拥有的数据而不是重新创建它来避免这种情况:

x = Variable(DoubleTensor([1]), requires_grad=True)
y = x - torch.cos(x)
y.backward(retain_graph=True)

for i in range(5):

x.data = x.data - y.data/x.grad.data
y.data = x.data - torch.cos(x.data)
y.backward(retain_graph=True)

print(x.data) # tensor([0.7417889255761136], dtype=torch.float64) (wrong)
似乎,用DoubleTensors,我携带足够的精度数字来排除舍入误差。那么错误来自哪里?

可能相关:retain_graph=True如果for循环,上面的代码段会在每一步都没有设置标志的情况下中断。如果我在循环中省略它而得到的错误消息---但是保留在第3行---是: RuntimeError:尝试第二次向后遍历图形,但缓冲区已经被释放。第一次向后调用时指定retain_graph = True 。这似乎证明我误解了一些事情......

PyTorch 算法框架/工具
分享到
取消 提交回答
全部回答(1)
  • 一码平川MACHEL
    2019-07-17 23:26:45

    我认为你的第一个代码版本是最优的,这意味着它不会在每次运行时创建计算图。

    initial guess

    guess = torch.tensor([1], dtype=torch.float64, requires_grad = True)

    function to optimize

    def my_func(x):

    return x - torch.cos(x)
    

    def newton(func, guess, runs=5):

    for _ in range(runs): 
        # evaluate our function with current value of `guess`
        value = my_func(guess)
        value.backward()
        # update our `guess` based on the gradient
        guess.data -= (value / guess.grad).data
        # zero out current gradient to hold new gradients in next iteration 
        guess.grad.data.zero_() 
    return guess.data # return our final `guess` after 5 updates
    

    call starts

    result = newton(my_func, guess)

    output of result

    tensor([0.7391], dtype=torch.float64)
    在每次运行中,my_func()使用当前guess值评估定义计算图的函数。一旦返回结果,我们计算梯度(使用value.backward()调用)。使用这个渐变,我们现在更新我们guess的渐变并将其调零,以便在下次调用时重新保持渐变value.backward()(即它会停止累积渐变;不会将渐变归零,它会默认开始累积渐变渐变。但是,我们想避免这种行为)。

    0 0
+ 订阅

了解行业+人工智能最先进的技术和实践,参与行业+人工智能实践项目

推荐文章
相似问题
推荐课程