JAX 中文文档(八)(3)

简介: JAX 中文文档(八)

JAX 中文文档(八)(2)https://developer.aliyun.com/article/1559663


JAX 可转换的 Python 函数的自定义导数规则

原文:jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

[外链图片转存中…(img-Tun1SV4O-1718950304061)]

mattjj@ Mar 19 2020, last updated Oct 14 2020

JAX 中定义微分规则的两种方式:

  1. 使用 jax.custom_jvpjax.custom_vjp 来为已经可转换为 JAX 的 Python 函数定义自定义微分规则;以及
  2. 定义新的 core.Primitive 实例及其所有转换规则,例如调用来自其他系统(如求解器、模拟器或一般数值计算系统)的函数。

本笔记本讨论的是 #1. 要了解关于 #2 的信息,请参阅关于添加原语的笔记本

关于 JAX 自动微分 API 的介绍,请参阅自动微分手册。本笔记本假定读者已对jax.jvpjax.grad,以及 JVPs 和 VJPs 的数学含义有一定了解。

TL;DR

使用 jax.custom_jvp 进行自定义 JVPs

import jax.numpy as jnp
from jax import custom_jvp
@custom_jvp
def f(x, y):
  return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out 
from jax import jvp, grad
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.)) 
2.7278922
2.7278922
-1.2484405
-1.2484405 
# Equivalent alternative using the defjvps convenience wrapper
@custom_jvp
def f(x, y):
  return jnp.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) 
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.)) 
2.7278922
2.7278922
-1.2484405
-1.2484405 

使用 jax.custom_vjp 进行自定义 VJPs

from jax import custom_vjp
@custom_vjp
def f(x, y):
  return jnp.sin(x) * y
def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
  cos_x, sin_x, y = res # Gets residuals computed in f_fwd
  return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd) 
print(grad(f)(2., 3.)) 
-1.2484405 

示例问题

要了解 jax.custom_jvpjax.custom_vjp 所解决的问题,我们可以看几个例子。有关 jax.custom_jvpjax.custom_vjp API 的更详细介绍在下一节中。

数值稳定性

jax.custom_jvp 的一个应用是提高微分的数值稳定性。

假设我们想编写一个名为 log1pexp 的函数,用于计算 (x \mapsto \log ( 1 + e^x ))。我们可以使用 jax.numpy 来写:

import jax.numpy as jnp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))
log1pexp(3.) 
Array(3.0485873, dtype=float32, weak_type=True) 

因为它是用 jax.numpy 编写的,所以它是 JAX 可转换的:

from jax import jit, grad, vmap
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

但这里存在一个数值稳定性问题:

print(grad(log1pexp)(100.)) 
nan 

那似乎不对!毕竟,(x \mapsto \log (1 + e^x)) 的导数是 (x \mapsto \frac{e^x}{1 + e^x}),因此对于大的 (x) 值,我们期望值约为 1。

通过查看梯度计算的 jaxpr,我们可以更深入地了解发生了什么:

from jax import make_jaxpr
make_jaxpr(grad(log1pexp))(100.) 
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0 b
    _:f32[] = log c
    d:f32[] = div 1.0 c
    e:f32[] = mul d b
  in (e,) } 

通过分析 jaxpr 如何评估,我们可以看到最后一行涉及的值相乘会导致浮点数计算四舍五入为 0 和 (\infty),这从未是一个好主意。也就是说,我们实际上在评估大数值的情况下,计算的是 lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x),这实际上会变成 0. * jnp.inf

而不是生成这样大和小的值,希望浮点数能够提供的取消,我们宁愿将导数函数表达为一个更稳定的数值程序。特别地,我们可以编写一个程序,更接近地评估相等的数学表达式 (1 - \frac{1}{1 + e^x}),看不到取消。

这个问题很有趣,因为即使我们的log1pexp的定义已经可以进行 JAX 微分(并且可以使用jitvmap等转换),我们对应用标准自动微分规则到组成log1pexp并组合结果的结果并不满意。相反,我们想要指定整个函数log1pexp如何作为一个单位进行微分,从而更好地安排这些指数。

这是关于 Python 函数的自定义导数规则的一个应用,这些函数已经可以使用 JAX 进行转换:指定如何对复合函数进行微分,同时仍然使用其原始的 Python 定义进行其他转换(如jitvmap等)。

这里是使用jax.custom_jvp的解决方案:

from jax import custom_jvp
@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
  return ans, ans_dot 
print(grad(log1pexp)(100.)) 
1.0 
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

这里是一个defjvps方便包装,来表达同样的事情:

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))
log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t) 
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
1.0
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

强制执行微分约定

一个相关的应用是强制执行微分约定,也许在边界处。

考虑函数 (f : \mathbb{R}+ \to \mathbb{R}+),其中 (f(x) = \frac{x}{1 + \sqrt{x}}),其中我们取 (\mathbb{R}_+ = [0, \infty))。我们可以像这样实现 (f) 的程序:

def f(x):
  return x / (1 + jnp.sqrt(x)) 

作为在(\mathbb{R})上的数学函数(完整的实数线),(f) 在零点是不可微的(因为从左侧定义导数的极限不存在)。相应地,自动微分产生一个nan值:

print(grad(f)(0.)) 
nan 

但是数学上,如果我们将 (f) 视为 (\mathbb{R}_+) 上的函数,则它在 0 处是可微的 [Rudin 的《数学分析原理》定义  5.1,或 Tao 的《分析 I》第 3 版定义 10.1.1 和例子  10.1.6]。或者,我们可能会说,作为一个惯例,我们希望考虑从右边的方向导数。因此,对于 Python 函数grad(f)0.0处返回 1.0 是有意义的值。默认情况下,JAX 对微分的机制假设所有函数在(\mathbb{R})上定义,因此这里并不会产生1.0

我们可以使用自定义的 JVP 规则!特别地,我们可以定义 JVP 规则,关于导数函数 (x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)²}) 在 (\mathbb{R}_+) 上,

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))
@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
  return ans, ans_dot 
print(grad(f)(0.)) 
1.0 

这里是方便包装版本:

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))
f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t) 
print(grad(f)(0.)) 
1.0 

梯度剪裁

虽然在某些情况下,我们想要表达一个数学微分计算,在其他情况下,我们甚至可能想要远离数学,来调整自动微分的计算。一个典型的例子是反向模式梯度剪裁。

对于梯度剪裁,我们可以使用jnp.clip和一个jax.custom_vjp仅逆模式规则:

from functools import partial
from jax import custom_vjp
@custom_vjp
def clip_gradient(lo, hi, x):
  return x  # identity function
def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save bounds as residuals
def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # use None to indicate zero cotangents for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) 
import matplotlib.pyplot as plt
from jax import vmap
t = jnp.linspace(0, 10, 1000)
plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t)) 
[<matplotlib.lines.Line2D at 0x7f43dfc210f0>] 

def clip_sin(x):
  x = clip_gradient(-0.75, 0.75, x)
  return jnp.sin(x)
plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t)) 
[<matplotlib.lines.Line2D at 0x7f43ddb15fc0>] 

Python 调试

另一个应用,是受开发工作流程而非数值驱动的动机,是在反向模式自动微分的后向传递中设置pdb调试器跟踪。

在尝试追踪nan运行时错误的来源,或者仅仔细检查传播的余切(梯度)值时,可以在反向传递中的特定点插入调试器非常有用。您可以使用jax.custom_vjp来实现这一点。

我们将在下一节中推迟一个示例。


JAX 中文文档(八)(4)https://developer.aliyun.com/article/1559665

相关文章
|
编解码 索引
pcl 无序点云数据空间变化检测
pcl 无序点云数据空间变化检测
pcl 无序点云数据空间变化检测
|
8月前
|
人工智能 运维 Cloud Native
全面开测 - 零门槛,即刻拥有DeepSeek-R1满血版,百万token免费用
DeepSeek是当前热门的推理模型,尤其擅长数学、代码和自然语言等复杂任务。2024年尾,面对裁员危机,技术进步的学习虽减少,但DeepSeek大模型的兴起成为新的学习焦点。满血版DeepSeek(671B参数)与普通版相比,在性能、推理能力和资源需求上有显著差异。满血版支持实时联网数据更新和多轮深度对话,适用于科研、教育和企业级应用等复杂场景。 阿里云提供的满血版DeepSeek部署方案对普通用户特别友好,涵盖云端调用API及各尺寸模型的部署方式,最快5分钟、最低0元即可实现。
1091 68
|
存储 移动开发 Python
JAX 中文文档(八)(2)
JAX 中文文档(八)
89 0
|
8月前
|
存储 人工智能 算法
Magic 1-For-1:北大联合英伟达推出的高质量视频生成量化模型,支持在消费级GPU上快速生成
北京大学、Hedra Inc. 和 Nvidia 联合推出的 Magic 1-For-1 模型,优化内存消耗和推理延迟,快速生成高质量视频片段。
395 3
Magic 1-For-1:北大联合英伟达推出的高质量视频生成量化模型,支持在消费级GPU上快速生成
|
前端开发
彻底搞懂css盒子模型
【10月更文挑战第1天】
299 9
|
机器学习/深度学习 算法 计算机视觉
通过MATLAB分别对比二进制编码遗传优化算法和实数编码遗传优化算法
摘要: 使用MATLAB2022a对比了二进制编码与实数编码的遗传优化算法,关注最优适应度、平均适应度及运算效率。二进制编码适用于离散问题,解表示为二进制串;实数编码适用于连续问题,直接搜索连续空间。两种编码在初始化、适应度评估、选择、交叉和变异步骤类似,但实数编码可能需更复杂策略避免局部最优。选择编码方式取决于问题特性。
|
SQL 监控 安全
网络攻击的阶段详解
【8月更文挑战第31天】
767 0
|
并行计算 异构计算 索引
JAX 中文文档(十六)(4)
JAX 中文文档(十六)
238 2
|
并行计算 API 异构计算
JAX 中文文档(十六)(3)
JAX 中文文档(十六)
313 0
|
编译器 API C++
JAX 中文文档(三)(3)
JAX 中文文档(三)
217 0