【一起撸个DL框架】4 反向传播求梯度

简介: 4 反向传播求梯度🥥4.1 简介上一篇:【一起撸个DL框架】3 前向传播前面我们已经介绍了前向传播,而本节即将介绍的反向传播中的自动微分机制,可以说是深度学习框架的一个核心功能。因为计算图中的参数正是按照着梯度的指引来更新的。

4 反向传播求梯度🥥

4.1 简介

上一篇:【一起撸个DL框架】3 前向传播

前面我们已经介绍了前向传播,而本节即将介绍的反向传播中的自动微分机制,可以说是深度学习框架的一个核心功能。因为计算图中的参数正是按照着梯度的指引来更新的。

4.2 导数与梯度

说到“梯度”与“导数”这两个概念,有些同学可能已经有些模糊了。在一元函数的情况下,两者几乎可以混为一谈,然而在多元函数的情况下梯度与导数的概念是有区别的。例如二元函数f ( x , y ) f(x,y)f(x,y),

它沿着二维平面内的每一个方向都会有一个方向导数,方向导数的结果是一个数值,代表沿着该方向的变化率;

屏幕截图 2023-12-28 182244.png

4.3 链式法则

屏幕截图 2023-12-28 182326.png

problem: 梯度不是相对于函数而言的吗?例如▽ f ( w , x ) \triangledown f(w,x)▽f(w,x)。为什么会有“w的梯度”这种概念呢?


在计算图中求一个节点的梯度,只需要将结果节点对子节点的梯度与子节点对自己的梯度乘起来就可以了。

6e9993ab672e4ed0a7f6818c0035edd5.png

4.4 示例:y=2x+1的梯度

为了实现反向传播,我们需要在节点类中加入几个方法。

class Node:
    def __init__(self, parent1=None, parent2=None) -> None:
        self.parent1 = parent1
        self.parent2 = parent2
        self.value = None
        self.grad = None  # 在其它结点求梯度时可能再次用到本结点的梯度
        self.children = []
        parents = [self.parent1, self.parent2]
        for parent in parents:
            if parent is not None:
                parent.children.append(self)
    def set_value(self, value):
        self.value = value
    def compute(self):
        '''抽象方法,在具体的类中重写'''
        pass
    def forward(self):
        for parent in [self.parent1, self.parent2]:
            if parent.value is None:
                parent.forward()
        self.compute()
        return self.value
    def get_parent_grad(self, parent):
        '''求本节点对于父节点的梯度,抽象方法'''
        pass
    def get_grad(self):
        '''求结果节点对本节点的梯度'''
        # 结果结点返回单位值,而不是self.value
        if not self.children:
            return 1
        if self.grad is not None:
            return self.grad
        else:
            self.grad = 0
            for i in range(len(self.children)):
                grad1 = self.children[i].get_parent_grad(parent=self)  # 子节点对自己的梯度
                grad2 = self.children[i].get_grad()  # 结果节点对子节点的梯度
                self.grad += grad1 * grad2
            return self.grad
class Varrible(Node):
    def __init__(self) -> None:
        super().__init__()
class Add(Node):
    def __init__(self, parent1=None, parent2=None) -> None:
        super().__init__(parent1, parent2)
    def compute(self):
        self.value = self.parent1.value + self.parent2.value
    def get_parent_grad(self, parent):
        return 1
class Mul(Node):
    def __init__(self, parent1=None, parent2=None) -> None:
        super().__init__(parent1, parent2)
    def compute(self):
        self.value = self.parent1.value * self.parent2.value
    def get_parent_grad(self, parent):
        '''从parent1,2改成parents时,需要重写'''
        if parent == self.parent1:
            return self.parent2.value
        elif parent == self.parent2:
            return self.parent1.value
        else:
            raise "get which is not a parent of mul node"

在之前的基础上,我们在Node类中新增了两个属性self.grad和self.children,它们都将在get_grad()方法中发挥作用。


然后我们还在Node类中新增了两个方get_parent_grad()和get_grad(),分别用来求子节点对于子节点的父节点(即本节点)的梯度,和结果节点对于本节点的梯度。get_parent_grad()是一个抽象方法,还需在Add类和Mul类中进行具体的实现,实现的方式比较简单,大家阅读源代码即可。


关于get_grad()方法,在本节点没有子节点时,就判断本节点为结果节点,梯度设置为单位值1。在本节点梯度已经存在时,就不再进行递归求值,而直接返回已经保存的梯度值。为什么要保存梯度值?一个节点可以有多个需要求梯度的父节点,这种情况下就会多次用到同一个节点的梯度,每次都递归到结果节点来求值显然是不必要的,于是我们使用空间来换时间。

在下面的示例代码中,我们求w ww和x xx的梯度,就两次用到了mul节点的梯度。

if __name__ == '__main__':
    # 搭建计算图: y=2x+1
    x1 = Varrible()
    w1 = Varrible()
    mul = Mul(x1, w1)
    b = Varrible()
    add = Add(mul, b)
    # 给参数赋值
    w1.set_value(2)
    b.set_value(1)
    # 使用计算图计算
    x1.set_value(int(input("请输入x:")))
    y = add.forward()
    print(f"y: {y}")
    # 反向传播求梯度
    w_grad = w1.get_grad()
    x_grad = x1.get_grad()
    print(f"w_grad: {w_grad}, x_grad: {x_grad}")
    '''
    请输入x:3
    y: 7
    w_grad: 3, x_grad: 2
    '''

然后我们再次搭建了函数y = 2 x + 1 y=2x+1y=2x+1的计算图,首先进行了前向传播,目的是检查计算图是否正确实现了函数y=2x+1的功能。然后进行反向传播求得了w ww和x xx的梯度。


这一节代码的结构已经逐渐变得复杂,一些细节的设计可能需要反复揣摩才能明白,同时代码也还存在一些有待改进的地方,例如Node类的__init__()方法中,每个节点只会有两个父节点的设定其实是不太合适的。


相关文章
|
存储 关系型数据库 MySQL
Linux Centos9 Stream 安装mysql8
Linux Centos9 Stream 安装mysql8
2391 1
|
10月前
|
Python
[oeasy]python074_ai辅助编程_水果程序_fruits_apple_banana_加法_python之禅
本文回顾了从模块导入变量和函数的方法,并通过一个求和程序实例,讲解了Python中输入处理、类型转换及异常处理的应用。重点分析了“明了胜于晦涩”(Explicit is better than implicit)的Python之禅理念,强调代码应清晰明确。最后总结了加法运算程序的实现过程,并预告后续内容将深入探讨变量类型的隐式与显式问题。附有相关资源链接供进一步学习。
207 4
|
Android开发
Android ConstraintLayout按比例缩放View
Android ConstraintLayout按比例缩放View 关键点有两个,第一,使用Android ConstraintLayout的layout_constraintDimensionRatio属性,设置宽高比缩放比例,宽:高。
3005 0
|
编译器 C语言
详解:strerror函数:将错误码转化为错误信息
详解:strerror函数:将错误码转化为错误信息
331 0
详解:strerror函数:将错误码转化为错误信息
|
设计模式 Java API
【Java代理】【静态代理】【动态代理】【动态代理的2种方式】
【Java代理】【静态代理】【动态代理】【动态代理的2种方式】
【Java代理】【静态代理】【动态代理】【动态代理的2种方式】
|
存储 关系型数据库 Linux
PolarDB-CEPH 部署 最新版 | 学习笔记
快速学习 PolarDB-CEPH 部署 最新版,介绍了 PolarDB-CEPH 部署 最新版系统机制, 以及在实际应用过程中如何使用。
PolarDB-CEPH 部署 最新版 | 学习笔记
|
设计模式
GOF设计模式之创建型模式小结
GOF设计模式之创建型模式小结
|
Linux 编译器 编解码
Sleep函数
在VC中Sleep中的第一个英文字符为大写的"S"   在标准C中是sleep, 不要大写.. 下面使用大写的来说明,, 具体用什么看你用什么编译器. 简单的说VC用Sleep, 别的一律使用sleep.   Sleep函数的一般形式:   Sleep(unsigned long);   其中,Sleep()里面的单位,是以毫秒为单位,所以如果想让函数滞留1秒的话,应该是Sleep(1000);   在Linux下,sleep中的“s”不大写   sleep()里面的单位是秒,而不是毫秒。
1146 0
|
17小时前
|
存储 JavaScript 前端开发
JavaScript基础
本节讲解JavaScript基础核心知识:涵盖值类型与引用类型区别、typeof检测类型及局限性、===与==差异及应用场景、内置函数与对象、原型链五规则、属性查找机制、instanceof原理,以及this指向和箭头函数中this的绑定时机。重点突出类型判断、原型继承与this机制,助力深入理解JS面向对象机制。(238字)

热门文章

最新文章