【一起撸个DL框架】3 前向传播

简介: 3 前向传播🥝3.1 前情提要上一篇:【一起撸个DL框架】2 节点与计算图的搭建

3 前向传播🥝

3.1 前情提要

上一篇:【一起撸个DL框架】2 节点与计算图的搭建

在上一节中,我们定义了加法节点和变量节点类,搭建计算图并实现了加法功能。但还有一个小问题,那就是节点类的定义中,只有父节点有值时,才能调用compute()方法计算本节点的值。而当存在多个节点串联时,就无法直接调用结果节点的compute()方法。因此,这一节我们将采用递归来解决这个问题。

2948636f553c458c88bf0bf91f20ff1e.png

3.2 前向传播:递归的forward方法

我们只需要修改Node类,在其中添加一个forword()方法。当父节点的值为空时,递归地调用forward()计算节点的值,然后再调用compute()计算本节点的值。

class Node:
    def __init__(self, parent1=None, parent2=None) -> None:
        self.parent1 = parent1
        self.parent2 = parent2
        self.value = None
    def set_value(self, value):
        self.value = value
    def compute(self):
        pass
    def forward(self):
        for parent in [self.parent1, self.parent2]:
            if parent.value is None:
                parent.forward()
        self.compute()
        return self.value

然后,我们就可以使用修改过的节点类,搭建出图1中的计算图,并计算节点add2的值。

if __name__ == '__main__':
    # 搭建计算图
    x1 = Varrible()
    x2 = Varrible()
    add1 = Add(x1, x2)
    x3 = Varrible()
    add2 = Add(add1, x3)
    # 输入
    x1.set_value(int(input('请输入x1:')))
    x2.set_value(int(input('请输入x2:')))
    x3.set_value(int(input('请输入x3:')))
    # 前向传播
    y = add2.forward()
    print(y)

运行代码效果如下:

请输入x1:1
请输入x2:2
请输入x3:3
6

3.3 再添乘法节点:搭建函数y=2x+1

函数y = 2 x + 1 y=2x+1y=2x+1的计算图如图2所示,与图1很相似,只是其中一个加法节点换成了乘法节点。但不同之处是,在函数y = 2 x + 1 y=2x+1y=2x+1的计算图中,只有x一个自变量,其余变量节点称为参数。

49b1b0684af74bfa826c48664ee94a39.png

乘法节点类的实现与加法节点差不多,如下所示:

class Mul(Node):
    def __init__(self, parent1=None, parent2=None) -> None:
        super().__init__(parent1, parent2)
    def compute(self):
        self.value = self.parent1.value * self.parent2.value

下面是图2中计算图的搭建:

if __name__ == '__main__':
    # 搭建计算图
    w = Varrible()
    x = Varrible()
    mul = Mul(w, x)
    b = Varrible()
    add = Add(mul, b)
    # 输入
    w.set_value(2)
    b.set_value(1)
    x.set_value(int(input('请输入x:')))
    # 前向传播
    y = add.forward()
    print(y)
请输入x:2
5

3.4 小结

这一节的内容比较简单,我们用递归实现了前向传播,并搭建了一个一次函数:y = 2 x + 1  


相关文章
|
4月前
|
机器学习/深度学习 Java 网络架构
YOLOv5改进 | TripletAttention三重注意力机制(附代码+机制原理+添加教程)
YOLOv5改进 | TripletAttention三重注意力机制(附代码+机制原理+添加教程)
297 0
|
4月前
|
机器学习/深度学习
【一起撸个DL框架】4 反向传播求梯度
4 反向传播求梯度🥥 4.1 简介 上一篇:【一起撸个DL框架】3 前向传播 前面我们已经介绍了前向传播,而本节即将介绍的反向传播中的自动微分机制,可以说是深度学习框架的一个核心功能。因为计算图中的参数正是按照着梯度的指引来更新的。
50 0
|
4月前
|
数据可视化 算法 数据挖掘
R语言SIR模型网络结构扩散过程模拟SIR模型(Susceptible Infected Recovered )代码实例
R语言SIR模型网络结构扩散过程模拟SIR模型(Susceptible Infected Recovered )代码实例
|
4月前
|
机器学习/深度学习 Java 网络架构
YOLOv8改进 | TripletAttention三重注意力机制(附代码+机制原理+添加教程)
YOLOv8改进 | TripletAttention三重注意力机制(附代码+机制原理+添加教程)
690 0
|
4月前
|
机器学习/深度学习 算法 网络安全
【一起撸个DL框架】5 实现:自适应线性单元
5 实现:自适应线性单元🍇 1 简介 上一篇:【一起撸个DL框架】4 反向传播求梯度 上一节我们实现了计算图的反向传播,可以求结果节点关于任意节点的梯度。下面我们将使用梯度来更新参数,实现一个简单的自适应线性单元。
48 0
|
10月前
|
机器学习/深度学习 并行计算 Go
YOLOv5 网络组件与激活函数 代码理解笔记
最近在看YOLOv5 第6个版本的代码,记录了一下笔记,分享一下。首先看了网络结构、网络组件,对应代码models\common.py。然后看了激活函数,对应代码utils\activations.py。
266 0
|
机器学习/深度学习 传感器 安全
【SIR传播】基于matlab模拟复杂网络SIR传播模型
【SIR传播】基于matlab模拟复杂网络SIR传播模型
|
机器学习/深度学习 PyTorch 算法框架/工具
Dropout的深入理解(基础介绍、模型描述、原理深入、代码实现以及变种)
Dropout的深入理解(基础介绍、模型描述、原理深入、代码实现以及变种)
|
机器学习/深度学习 数据可视化 PyTorch
【Pytorch神经网络实战案例】17 带W散度的WGAN-div模型生成Fashon-MNST模拟数据
W散度的损失函数GAN-dv模型使用了W散度来替换W距离的计算方式,将原有的真假样本采样操作换为基于分布层面的计算。
189 0
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch基于迁移学习的VGG卷积神经网络-手撕(可直接运行)-部分地方不懂的可以参考我上一篇手撕VGG神经网络的注释 两个基本一样 只是这个网络是迁移过来的
Pytorch基于迁移学习的VGG卷积神经网络-手撕(可直接运行)-部分地方不懂的可以参考我上一篇手撕VGG神经网络的注释 两个基本一样 只是这个网络是迁移过来的