本文已收录于Pytorch系列专栏: Pytorch入门与实践 专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下坚实的基础。免费订阅,持续更新。
计算图
计算图是用来描述运算的有向无环图
计算图有两个主要元素:
- 结点 Node
- 边 Edge
结点表示数据:如向量,矩阵,张量
边表示运算:如加减乘除卷积等
用计算图表示:y = (x+ w) * (w+1)
a = x + w
b = w + 1
y = a * b
计算图与梯度求导
y = (x+ w) * (w+1)
a = x + w
b = w + 1
y = a * b
$\begin{aligned} \frac{\partial y}{\partial w} &=\frac{\partial y}{\partial a} \frac{\partial a}{\partial w}+\frac{\partial y}{\partial b} \frac{\partial b}{\partial w} \\ &=b * 1+a * 1 \\ &=b+a \\ &=(w+1)+(x+w) \\ &=2 * w+x+1 \\ &=2 * 1+2+1=5 \end{aligned}$
可见,对于变量w的求导过程就是寻找它在计算图中的所有路径的求导之和。
code:
import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x) # retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward()
print(w.grad)
tensor([5.])
计算图与梯度求导
y = (x+ w) * (w+1)
叶子结点 :用户创建的结点称为叶子结点,如 X 与 W
is_leaf: 指示张量是否为叶子结点
叶子节点的作用是标志存储叶子节点的梯度,而清除在反向传播过程中的变量的梯度,以达到节省内存的目的。当然,如果想要保存过程中变量的梯度值,可以采用retain_grad()
grad_fn: 记录创建该张量时所用的方法(函数)
- y.grad_fn= \<MulBackward0>
- a.grad_fn= \<AddBackward0>
- b.grad_fn= \<AddBackward0>
PyTorch的动态图机制
根据计算图搭建方式,可将计算图分为动态图和静态图
- 动态图
运算与搭建同时进行
灵活 易调节
例如动态图 PyTorch:
- 静态
先搭建图, 后运算
高效 不灵活。
静态图 TensorFlow