Julia:Zygote 上自定义后向传播

简介: Zygote 是 Julia 上一个实现自动微分、自动求导的包,其中 `@adjoint` 宏是 Zygote 接口的一个重要组成部分。使用 `@adjoint` 可以自定义函数的后向传播。

Zygote 是 Julia 上一个实现自动微分、自动求导的包,其中 @adjoint 宏是 Zygote 接口的一个重要组成部分。使用 @adjoint 可以自定义函数的后向传播。

Pullbacks

要理解 @adjoint 首先要先理解更为底层的函数 pullbackgradient 实际上就是 pullback 的语法糖(syntactic sugar)。

julia> y, back = Zygote.pullback(sin, 0.5)
(0.479425538604203, Zygote.var"#41#42"{Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}}(Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}(ChainRules.var"#sin_pullback#1430"{Float64}(0.8775825618903728))))

julia> y
0.479425538604203

pullback 输入两个参数 sin0.5 分别代表要求导的函数和要求导的值,会得到两个输出:给定函数的结果 sin(0.5) 以及一个 pullback,也就是上面代码中的 back 变量。back 对函数 sin 进行梯度计算,接受的是一个派生,并且产生新的一个变量。。从数学上讲,就是 vector-Jacobian 积的实现。其中 $y=f(x)$ 和梯度 $\frac{\partial{l}}{\partial{x}}$ 写为 $\bar{x}$,pullback $\mathcal{B}_y$ 如下计算:

$$ \bar{x}=\frac{\partial l}{\partial x}=\frac{\partial l}{\partial y} \frac{\partial y}{\partial x}=\mathcal{B}_{y}(\bar{y}) $$

更为具体的讲,以上面的代码为例子,函数 $y=\sin(x)$. $\frac{\partial y}{\partial x}=\cos (x)$,所以 pullback 就为 $\bar{y}\cos(x)$,其中 $\bar{y}=\frac{\partial l}{\partial y}$。换句话说,pullback(sin, x)dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),)) 等价。

gradient 中函数 $l=f(x)$ 并且假设 $\bar{l}=\frac{\partial l}{\partial l}=1$,并且将其输入到 pullback 中。在 sin 的例子中,

julia> dsin(x) = (sin, ȳ -> (ȳ * cos(x),))
dsin (generic function with 1 method)

julia> function gradsin(x)
           _, back = dsin(x)
           back(1)
       end
gradsin (generic function with 1 method)

julia> gradsin(0.5)
(0.8775825618903728,)

julia> cos(0.5)
0.8775825618903728
                
julia> back(1)
(0.8775825618903728,)

个人理解,为什么前面要加一项 $\frac{\partial l}{\partial y}$,这是为了实现链式法则。比如假设最终的损失是 $l$,函数 $y(x)$,要得到损失函数 $l$ 对参数 $x$ 的微分 $\frac{\partial l}{\partial x}$,根据链式法则就是损失函数对函数 $y$ 的微分乘以函数对参数 $x$ 的微分,即 $\frac{\partial l}{\partial y} \frac{\partial y}{\partial x}$。函数 $y$ 的 pullback 就是损失函数对函数 $y$ 的微分(用 $\bar{y}$ 表示)乘以函数对 $x$ 的微分。

对于上面的例子,pullback 函数返回的第一个结果为:假设函数 $y=\sin(x)$ 就是损失函数 $l$ 时,$x=0.5$ 时的结果,即 $\cos(0.5)$,并且返回的 back 就是一个关于 $\frac{\partial l}{\partial y}$ 的函数,可以看成是 $\mathcal{B}(\frac{\partial l}{\partial y})=\frac{\partial l}{\partial y}\cos(0.5)$。

假如 $l=0.5y=0.5\sin(x)$,我们可以得到 $\frac{\partial l}{\partial y}=0.5$,那么 $\frac{\partial l}{\partial x}=\mathcal{B}(\frac{\partial l}{\partial y})=\mathcal{B}(0.5)$。


参考:

[1] Custom Adjoints • Zygote

目录
相关文章
allegro如何看元器件的高度
allegro如何看元器件的高度
907 0
|
人工智能 算法 自动驾驶
使用OpenCV实现Halcon算法(2)形状匹配开源项目,shape_based_matching
使用OpenCV实现Halcon算法(2)形状匹配开源项目,shape_based_matching
5166 1
使用OpenCV实现Halcon算法(2)形状匹配开源项目,shape_based_matching
|
前端开发
自定义 Hook 编写指南
【10月更文挑战第15天】本文介绍了 React 中的 Hooks 和自定义 Hook 的基本概念、编写方法及常见问题。通过具体代码示例,详细讲解了如何在函数组件中使用状态和其他 React 特性,并分享了避免常见错误的技巧。自定义 Hook 可以帮助你将组件中的逻辑提取出来,使其更加可重用和可维护。
761 68
|
Java 关系型数据库 MySQL
Maven——创建 Spring Boot项目
Maven 是一个项目管理工具,通过配置 `pom.xml` 文件自动获取所需的 jar 包,简化了项目的构建和管理过程。其核心功能包括项目构建和依赖管理,支持创建、编译、测试、打包和发布项目。Maven 仓库分为本地仓库和远程仓库,远程仓库包括中央仓库、私服和其他公共库。此外,文档还介绍了如何创建第一个 SpringBoot 项目并实现简单的 HTTP 请求响应。
982 1
Maven——创建 Spring Boot项目
GEE——Google dynamic world中在影像导出过程中无法完全导出较大面积影像的解决方案(投影的转换)EPSG:32630和EPSG:4326的区别
GEE——Google dynamic world中在影像导出过程中无法完全导出较大面积影像的解决方案(投影的转换)EPSG:32630和EPSG:4326的区别
329 0
|
存储 运维 安全
Spring运维之boot项目多环境(yaml 多文件 proerties)及分组管理与开发控制
通过以上措施,可以保证Spring Boot项目的配置管理在专业水准上,并且易于维护和管理,符合搜索引擎收录标准。
714 2
|
机器学习/深度学习 人工智能 自然语言处理
评测:AI 大模型助力客户对话分析
该评测报告详细介绍了Al大模型在客户对话分析中的应用,涵盖了实践原理、实施方法、部署体验、示例代码及业务适应性。报告指出,该方案利用NLP和机器学习技术,深度解析对话内容,精准识别用户意图,显著提升服务质量与客户体验。实施方法清晰明了,文档详尽,部署体验顺畅,提供了丰富的引导和支持。示例代码实用性强,但在依赖库安装和资源限制方面需注意调整。整体上,该方案能够满足基本对话分析需求,但在特定行业场景中还需进一步定制化开发。
|
网络安全 语音技术
语音情感基座模型emotion4vec 问题之计算emotion2vec模型中的总损失L,如何操作
语音情感基座模型emotion4vec 问题之计算emotion2vec模型中的总损失L,如何操作
177 1
|
Ubuntu 安全 Linux
Linux必备|如何重置忘记的 Root 密码
Linux必备|如何重置忘记的 Root 密码
2280 7
|
JavaScript
Vue中 引入使用 vue-splitpane 实现窗格的拆分、调节
Vue中 引入使用 vue-splitpane 实现窗格的拆分、调节
2576 0
Vue中 引入使用 vue-splitpane 实现窗格的拆分、调节