【AI系统】动手实现自动微分

简介: 本章介绍如何实现自动微分,重点讲解前向自动微分的原理及Python实现方法。通过操作符重载,将程序分解为基础表达式组合,利用链式法则计算导数。示例代码展示了如何使用自定义类`ADTangent`实现加、减、乘、log、sin等操作的自动微分,验证了与PyTorch和MindSpore等框架的一致性。

在这章内容,会介绍是怎么实现自动微分的,因为代码量非常小,也许你也可以写一个玩玩。前面的文章当中,已经把自动微分的原理深入浅出的讲了一下,也引用了非常多的论文。有兴趣的可以顺着综述 A survey 这篇深扒一下。

前向自动微分原理

了解自动微分的不同实现方式非常有用。在这里呢,我们将介绍主要的前向自动微分,通过 Python 这个高级语言来实现操作符重载。在正反向模式中的这篇的文章中,我们介绍了前向自动微分的基本数学原理。

前向模式(Forward Automatic Differentiation,也叫做 tangent mode AD)或者前向累积梯度(前向模式)

前向自动微分中,从计算图的起点开始,沿着计算图边的方向依次向前计算,最终到达计算图的终点。它根据自变量的值计算出计算图中每个节点的值以及其导数值,并保留中间结果。一直得到整个函数的值和其导数值。整个过程对应于一元复合函数求导时从最内层逐步向外层求导。

image

简单确实简单,可以总结前向自动微分关键步骤为:

  • 分解程序为一系列已知微分规则的基础表达式的组合
  • 根据已知微分规则给出各基础表达式的微分结果
  • 根据基础表达式间的数据依赖关系,使用链式法则将微分结果组合完成程序的微分结果

而通过 Python 高级语言,进行操作符重载后的关键步骤其实也相类似:

  • 分解程序为一系列已知微分规则的基础表达式组合,并使用高级语言的重载操作
  • 在重载运算操作的过程中,根据已知微分规则给出各基础表达式的微分结果
  • 根据基础表达式间的数据依赖关系,使用链式法则将微分结果组合完成程序的微分结果

具体实现

首先呢,我们需要加载通用的 numpy 库,用于实际运算的,如果不用 numpy,在 python 中也可以使用 math 来代替。

import numpy as np

前向自动微分又叫做 tangent mode AD,所以我们准备一个叫做 ADTangent 的类,这类初始化的时候有两个参数,一个是 x,表示输入具体的数值;另外一个是 dx,表示经过对自变量 x 求导后的值。

需要注意的是,操作符重载自动微分不像源码转换可以给出求导的公式,一般而言并不会给出求导公式,而是直接给出最后的求导值,所以就会有 dx 的出现。

class ADTangent:

    # 自变量 x,对自变量进行求导得到的 dx
    def __init__(self, x, dx):
        self.x = x
        self.dx = dx

    # 重载 str 是为了方便打印的时候,看到输入的值和求导后的值
    def __str__(self):
        context = f'value:{self.x:.4f}, grad:{self.dx}'
        return context

下面是核心代码,也就是操作符重载的内容,在 ADTangent 类中通过 Python 私有函数重载加号,首先检查输入的变量 other 是否属于 ADTangent,如果是那么则把两者的自变量 x 进行相加。

其中值得注意的就是 dx 的计算,因为是正向自动微分,因此每一个前向的计算都会有对应的反向求导计算。求导的过程是这个程序的核心,不过不用担心的是这都是最基础的求导法则。最后返回自身的对象 ADTangent(x, dx)。

    def __add__(self, other):
        if isinstance(other, ADTangent):
            x = self.x + other.x
            dx = self.dx + other.dx
        elif isinstance(other, float):
            x = self.x + other
            dx = self.dx
        else:
            return NotImplementedError
        return ADTangent(x, dx)

下面则是对减号、乘法、log、sin 几个操作进行操作符重载,正向的重载的过程比较简单,基本都是按照上面的 add 的代码讨论来实现。

    def __sub__(self, other):
        if isinstance(other, ADTangent):
            x = self.x - other.x
            dx = self.dx - other.dx
        elif isinstance(other, float):
            x = self.x - other
            ex = self.dx
        else:
            return NotImplementedError
        return ADTangent(x, dx)

    def __mul__(self, other):
        if isinstance(other, ADTangent):
            x = self.x * other.x
            dx = self.x * other.dx + self.dx * other.x
        elif isinstance(other, float):
            x = self.x * other
            dx = self.dx * other
        else:
            return NotImplementedError
        return ADTangent(x, dx)

    def log(self):
        x = np.log(self.x)
        dx = 1 / self.x * self.dx
        return ADTangent(x, dx)

    def sin(self):
        x = np.sin(self.x)
        dx = self.dx * np.cos(self.x)
        return ADTangent(x, dx)

以公式为例:

$$ f(x_{1},x_{2})=ln(x_{1})+x_{1}*x_{2}−sin(x_{2}) \tag{1} \\ $$

因为是基于 ADTangent 类进行了操作符重载,因此在初始化自变量 x 和 y 的值需要使用 ADTangent 来初始化,然后通过代码 f = ADTangent.log(x) + x * y - ADTangent.sin(y) 来实现。

由于这里是求 f 关于自变量 x 的导数,因此初始化数据的时候呢,自变量 x 的 dx 设置为 1,而自变量 y 的 dx 设置为 0。

x = ADTangent(x=2., dx=1)
y = ADTangent(x=5., dx=0)
f = ADTangent.log(x) + x * y - ADTangent.sin(y)
print(f)
    value:11.6521, grad:5.5

从输出结果来看,正向计算的输出结果是跟上面图相同,而反向的导数求导结果也与上图相同。下面一个是 Pytroch 的实现结果对比,最后是 MindSpore 的实现结果对比。

可以看到呢,上面的简单实现的自动微分结果和 Pytroch 、MindSpore 是相同的。还是很有意思的。

Pytroch 对公式 1 的自动微分结果:

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([2.]), requires_grad=True)
y = Variable(torch.Tensor([5.]), requires_grad=True)
f = torch.log(x) + x * y - torch.sin(y)
f.backward()

print(f)
print(x.grad)
print(y.grad)

输出结果:

    tensor([11.6521], grad_fn=<SubBackward0>)
    tensor([5.5000])
    tensor([1.7163])

MindSpore 对公式 1 的自动微分结果:

import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor

class Fun(nn.Cell):
    def __init__(self):
        super(Fun, self).__init__()

    def construct(self, x, y):
        f = ops.log(x) + x * y - ops.sin(y)
        return f

x = Tensor(np.array([2.], np.float32))
y = Tensor(np.array([5.], np.float32))
f = Fun()(x, y)

grad_all = ops.GradOperation()
grad = grad_all(Fun())(x, y)

print(f)
print(grad[0])

输出结果:

    [11.65207]
    5.5

如果您想了解更多AI知识,与AI专业人士交流,请立即访问昇腾社区官方网站https://www.hiascend.com/或者深入研读《AI系统:原理与架构》一书,这里汇聚了海量的AI学习资源和实践课程,为您的AI技术成长提供强劲动力。不仅如此,您还有机会投身于全国昇腾AI创新大赛和昇腾AI开发者创享日等盛事,发现AI世界的无限奥秘~

目录
相关文章
|
8天前
|
人工智能 监控 安全
提效40%?揭秘AI驱动的支付方式“一键接入”系统
本项目构建AI驱动的研发提效系统,通过Qwen Coder与MCP工具链协同,实现跨境支付渠道接入的自动化闭环。采用多智能体协作模式,结合结构化Prompt、任务拆解、流程管控与安全约束,显著提升研发效率与交付质量,探索大模型在复杂业务场景下的高采纳率编码实践。
提效40%?揭秘AI驱动的支付方式“一键接入”系统
|
9天前
|
人工智能 自然语言处理 前端开发
最佳实践2:用通义灵码以自然语言交互实现 AI 高考志愿填报系统
本项目旨在通过自然语言交互,结合通义千问AI模型,构建一个智能高考志愿填报系统。利用Vue3与Python,实现信息采集、AI推荐、专业详情展示及数据存储功能,支持响应式设计与Supabase数据库集成,助力考生精准择校选专业。(239字)
78 12
|
5天前
|
存储 人工智能 搜索推荐
LangGraph 记忆系统实战:反馈循环 + 动态 Prompt 让 AI 持续学习
本文介绍基于LangGraph构建的双层记忆系统,通过短期与长期记忆协同,实现AI代理的持续学习。短期记忆管理会话内上下文,长期记忆跨会话存储用户偏好与决策,结合人机协作反馈循环,动态更新提示词,使代理具备个性化响应与行为进化能力。
97 10
LangGraph 记忆系统实战:反馈循环 + 动态 Prompt 让 AI 持续学习
|
8天前
|
人工智能 JSON 安全
Claude Code插件系统:重塑AI辅助编程的工作流
Anthropic为Claude Code推出插件系统与市场,支持斜杠命令、子代理、MCP服务器等功能模块,实现工作流自动化与团队协作标准化。开发者可封装常用工具或知识为插件,一键共享复用,构建个性化AI编程环境,推动AI助手从工具迈向生态化平台。
111 1
|
9天前
|
机器学习/深度学习 人工智能 自然语言处理
拔俗当AI成为你的“心灵哨兵”:多模态心理风险预警系统如何工作?
AI多模态心理预警系统通过融合表情、语调、文字、绘画等多维度数据,结合深度学习与多模态分析,实时评估心理状态。它像“心灵哨兵”,7×24小时动态监测情绪变化,发现抑郁、焦虑等风险及时预警,兼顾隐私保护,助力早期干预,用科技守护心理健康。(238字)
|
9天前
|
存储 人工智能 自然语言处理
拔俗AI产投公司档案管理系统:让数据资产 “活” 起来的智能助手
AI产投档案管理系统通过NLP、知识图谱与加密技术,实现档案智能分类、秒级检索与数据关联分析,破解传统人工管理效率低、数据孤岛难题,助力投资决策提效与数据资产化,推动AI产投数字化转型。
|
9天前
|
人工智能 算法 数据安全/隐私保护
拔俗AI多模态心理风险预警系统:用科技守护心理健康的第一道防线
AI多模态心理风险预警系统通过语音、文本、表情与行为数据,智能识别抑郁、焦虑等心理风险,实现早期干预。融合多源信息,提升准确率,广泛应用于校园、企业,助力心理健康服务从“被动响应”转向“主动预防”,为心灵筑起智能防线。(238字)
|
9天前
|
人工智能 搜索推荐 Cloud Native
拔俗AI助教系统:教师的"超级教学秘书",让每堂课都精准高效
备课到深夜、批改作业如山?阿里云原生AI助教系统,化身“超级教学秘书”,智能备课、实时学情分析、自动批改、精准辅导,为教师减负增效。让课堂从经验驱动转向数据驱动,每位学生都被看见,教育更有温度。
|
9天前
|
机器学习/深度学习 人工智能 监控
拔俗AI智能营运分析助手软件系统:企业决策的"数据军师",让经营从"拍脑袋"变"精准导航"
AI智能营运分析助手打破数据孤岛,实时整合ERP、CRM等系统数据,自动生成报表、智能预警与可视化决策建议,助力企业从“经验驱动”迈向“数据驱动”,提升决策效率,降低运营成本,精准把握市场先机。(238字)
|
9天前
|
存储 人工智能 自然语言处理
拔俗AI自动化评价分析系统:让数据说话,让决策更智能
在用户体验为核心的时代,传统评价分析面临效率低、洞察浅等痛点。本文基于阿里云AI与大数据技术,构建“数据-算法-应用”三层智能分析体系,实现多源数据实时接入、情感与主题精准识别、跨模态融合分析及实时预警,助力企业提升运营效率、加速产品迭代、优化服务质量,并已在头部电商平台成功落地,显著提升用户满意度与商业转化。

热门文章

最新文章