3 前向传播🥝
3.1 前情提要
上一篇:【一起撸个DL框架】2 节点与计算图的搭建
在上一节中,我们定义了加法节点和变量节点类,搭建计算图并实现了加法功能。但还有一个小问题,那就是节点类的定义中,只有父节点有值时,才能调用compute()方法计算本节点的值。而当存在多个节点串联时,就无法直接调用结果节点的compute()方法。因此,这一节我们将采用递归来解决这个问题。
3.2 前向传播:递归的forward方法
我们只需要修改Node类,在其中添加一个forword()
方法。当父节点的值为空时,递归地调用forward()
计算节点的值,然后再调用compute()
计算本节点的值。
class Node: def __init__(self, parent1=None, parent2=None) -> None: self.parent1 = parent1 self.parent2 = parent2 self.value = None def set_value(self, value): self.value = value def compute(self): pass def forward(self): for parent in [self.parent1, self.parent2]: if parent.value is None: parent.forward() self.compute() return self.value
然后,我们就可以使用修改过的节点类,搭建出图1中的计算图,并计算节点add2的值。
if __name__ == '__main__': # 搭建计算图 x1 = Varrible() x2 = Varrible() add1 = Add(x1, x2) x3 = Varrible() add2 = Add(add1, x3) # 输入 x1.set_value(int(input('请输入x1:'))) x2.set_value(int(input('请输入x2:'))) x3.set_value(int(input('请输入x3:'))) # 前向传播 y = add2.forward() print(y)
运行代码效果如下:
请输入x1:1 请输入x2:2 请输入x3:3 6
3.3 再添乘法节点:搭建函数y=2x+1
函数y = 2 x + 1 y=2x+1y=2x+1的计算图如图2所示,与图1很相似,只是其中一个加法节点换成了乘法节点。但不同之处是,在函数y = 2 x + 1 y=2x+1y=2x+1的计算图中,只有x一个自变量,其余变量节点称为参数。
乘法节点类的实现与加法节点差不多,如下所示:
class Mul(Node): def __init__(self, parent1=None, parent2=None) -> None: super().__init__(parent1, parent2) def compute(self): self.value = self.parent1.value * self.parent2.value
下面是图2中计算图的搭建:
if __name__ == '__main__': # 搭建计算图 w = Varrible() x = Varrible() mul = Mul(w, x) b = Varrible() add = Add(mul, b) # 输入 w.set_value(2) b.set_value(1) x.set_value(int(input('请输入x:'))) # 前向传播 y = add.forward() print(y)
请输入x:2 5
3.4 小结
这一节的内容比较简单,我们用递归实现了前向传播,并搭建了一个一次函数:y = 2 x + 1