JAX 中文文档(一)(5)

本文涉及的产品
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
简介: JAX 中文文档(一)

JAX 中文文档(一)(4)https://developer.aliyun.com/article/1559834


jit改变了输出的精确数值

有时候,用户会对用jit()包装一个函数后,函数的输出发生变化感到惊讶。例如:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365 
>>> print(jit(f)(x))
0.5723649 

这种输出的细微差异来自于 XLA 编译器中的优化:在编译过程中,XLA 有时会重新排列或省略某些操作,以使整体计算更加高效。

在这种情况下,XLA 利用对数的性质将log(sqrt(x))替换为0.5 * log(x),这是一个数学上相同的表达式,可以比原始表达式更有效地计算。输出的差异来自于浮点数运算只是对真实数学的近似,因此计算相同表达式的不同方式可能会有细微的差异。

其他时候,XLA 的优化可能导致更加显著的差异。考虑以下例子:

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf 
>>> print(jit(f)(x))
100.0 

在非 JIT 编译的逐操作模式下,结果为inf,因为jnp.exp(x)溢出并返回inf。然而,在 JIT 模式下,XLA 认识到logexp的反函数,并从编译函数中移除这些操作,简单地返回输入。在这种情况下,JIT 编译产生了对真实结果更准确的浮点数近似。

遗憾的是,XLA 的代数简化的完整列表文档不是很好,但如果你熟悉 C++ 并且对 XLA 编译器进行的优化类型感兴趣,你可以在源代码中查看它们:algebraic_simplifier.cc。## jit修饰函数编译速度非常慢

如果你的jit修饰函数在第一次调用时需要数十秒(甚至更长时间!)来运行,但在后续调用时执行速度很快,那么 JAX 正在花费很长时间来追踪或编译你的代码。

这通常表明调用你的函数生成了大量 JAX 内部表示的代码,通常是因为它大量使用了 Python 控制流,比如for循环。对于少量循环迭代,Python 是可以接受的,但如果你需要许多循环迭代,你应该重写你的代码,利用 JAX 的结构化控制流原语(如lax.scan())或避免用jit包装循环(你仍然可以在循环内部使用jit装饰的函数)。

如果你不确定问题出在哪里,你可以尝试在你的函数上运行jax.make_jaxpr()。如果输出很长,可能会导致编译速度慢。

有时候不明显如何重写你的代码以避免 Python 循环,因为你的代码使用了多个形状不同的数组。在这种情况下推荐的解决方案是利用像jax.numpy.where()这样的函数,在具有固定形状的填充数组上进行计算。

如果你的函数由于其他原因编译速度很慢,请在 GitHub 上提一个问题。## 如何在方法中使用 jit

大多数jax.jit()的示例涉及装饰独立的 Python 函数,但在类内部装饰方法会引入一些复杂性。例如,请考虑以下简单的类,我们在方法上使用了标准的jit()注解:

>>> import jax.numpy as jnp
>>> from jax import jit
>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y 

然而,当你尝试调用此方法时,这种方法将导致错误:

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type. 

问题在于函数的第一个参数是self,其类型为CustomClass,而 JAX 不知道如何处理这种类型。在这种情况下,我们可能会使用三种基本策略,并在下面讨论它们。

策略 1: JIT 编译的辅助函数

最直接的方法是在类外部创建一个辅助函数,可以像平常一样进行 JIT 装饰。例如:

>>> from functools import partial
>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)
>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y 

结果将按预期工作:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6 

这种方法的好处是简单、明确,避免了教 JAX 如何处理CustomClass类型对象的需要。但是,你可能希望将所有方法逻辑保留在同一个地方。

策略 2: 将self标记为静态

另一种常见模式是使用static_argnumsself参数标记为静态。但是必须小心,以避免意外的结果。你可能会简单地这样做:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y 

如果你调用该方法,它将不再引发错误:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6 

然而,有一个问题:如果在第一次方法调用后修改对象,则后续方法调用可能会返回不正确的结果:

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6 

为什么会这样?当你将对象标记为静态时,它将有效地被用作 JIT 内部编译缓存中的字典键,这意味着其哈希值(即 hash(obj) )、相等性(即 obj1 == obj2 )和对象身份(即 obj1 is obj2 )的行为应保持一致。自定义对象的默认 __hash__ 是其对象 ID,因此 JAX 无法知道突变对象应触发重新编译。

你可以通过为对象定义适当的 __hash____eq__ 方法来部分解决这个问题;例如:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul)) 

(参见object.__hash__() 的文档,进一步讨论在覆盖 __hash__ 时的要求)。

只要你不修改对象,这种方法与 JIT 和其他转换一起工作正常。将对象用作哈希键的突变会导致几个微妙的问题,这就是为什么例如可变 Python 容器(如dictlist)不定义 __hash__,而它们的不可变对应物(如tuple)会。

如果你的类依赖于原地突变(比如在其方法中设置 self.attr = ...),那么你的对象并非真正“静态”,将其标记为静态可能会导致问题。幸运的是,对于这种情况还有另一种选择。

策略 3:将 CustomClass 设为 PyTree

JIT 编译类方法的最灵活方法是将类型注册为自定义的 PyTree 对象;请参阅扩展 pytrees。这样可以明确指定类的哪些组件应视为静态,哪些应视为动态。以下是具体操作:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)
>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten) 

这当然更加复杂,但解决了上述简单方法所带来的所有问题:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3
>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6 

只要你的 tree_flattentree_unflatten 函数能正确处理类中所有相关属性,你应该能直接将这种类型的对象用作 JIT 编译函数的参数,而不需要任何特殊的注释。 ## 控制数据和计算在设备上的放置

让我们先来看看 JAX 中数据和计算放置的原则。

在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1)数据所在的设备;2)数据是否已提交到设备(有时称为数据对设备的粘性)。

默认情况下,JAX 数组被放置在默认设备上未提交状态 (jax.devices()[0]),这通常是第一个 GPU 或 TPU。如果没有 GPU 或 TPU 存在,jax.devices()[0] 是 CPU。可以通过 jax.default_device() 上下文管理器临时覆盖默认设备,或者通过设置环境变量 JAX_PLATFORMS 或 absl 标志 --jax_platforms 来设置整个进程的默认设备为 “cpu”、“gpu” 或 “tpu”(JAX_PLATFORMS 也可以是一个平台列表,指定优先顺序中可用的平台)。

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)} 

对涉及未提交数据的计算将在默认设备上执行,并且结果也会在默认设备上保持未提交状态。

数据也可以使用带有 device 参数的 jax.device_put() 明确地放置到设备上,在这种情况下,数据将会 提交 到设备上:

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)} 

包含一些已提交输入的计算将在已提交的设备上执行,并且结果将在同一设备上提交。在已提交到多个设备上的参数上调用操作将会引发错误。

也可以在没有 device 参数的情况下使用 jax.device_put()。如果数据已经在设备上(无论是已提交还是未提交状态),则保持不变。如果数据不在任何设备上,即它是常规的 Python 或 NumPy 值,则将其放置在默认设备上未提交状态。

经过 JIT 编译的函数行为与任何其他基本操作相同——它们会跟随数据,并且如果在提交到多个设备上的数据上调用时将会报错。

(在 2021 年 3 月之前的 PR #6002 中,创建数组常量时存在一些懒惰,因此 jax.device_put(jnp.zeros(...), jax.devices()[1]) 或类似的操作实际上会在 jax.devices()[1] 上创建零数组,而不是在默认设备上创建数组然后移动。但为了简化实现,这种优化被移除了。)

(截至 2020 年 4 月,jax.jit() 函数有一个影响设备放置的 device 参数。该参数是实验性的,可能会被移除或更改,并且不建议使用。)

对于一个详细的例子,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data

你刚刚将一个复杂的函数从 NumPy/SciPy 移植到 JAX。那真的加快了速度吗?

当使用 JAX 测量代码速度时,请记住与 NumPy 的这些重要差异:

  1. JAX 代码是即时编译(JIT)的。大多数使用 JAX 编写的代码可以以支持 JIT 编译的方式编写,这可以使其运行 更快(参见 To JIT or not to JIT)。为了从 JAX 中获得最大的性能,应在最外层的函数调用上应用 jax.jit()
    请记住,第一次运行 JAX 代码时,它会更慢,因为它正在被编译。即使在您自己的代码中不使用 jit,因为 JAX 的内置函数也是 JIT 编译的,这也是真实的。
  2. JAX 具有异步分派。 这意味着您需要调用 .block_until_ready() 来确保计算实际发生了(参见异步分派)。
  3. JAX 默认只使用 32 位数据类型。 您可能希望在 NumPy 中明确使用 32 位数据类型,或者在 JAX 中启用 64 位数据类型(参见Double (64 bit) precision)以进行公平比较。
  4. 在 CPU 和加速器之间传输数据需要时间。 如果您只想测量评估函数所需的时间,您可能希望先将数据传输到要运行的设备上(参见控制数据和计算放置在设备上)。

下面是一个将所有这些技巧放在一起进行微基准测试以比较 JAX 和 NumPy 的示例,利用 IPython 的便捷的 %time 和 %timeit 魔法命令

import numpy as np
import jax.numpy as jnp
import jax
def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))
x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime
%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime 

当在 Colab 上使用 GPU 运行时,我们看到:

  • NumPy 在 CPU 上每次评估需要 16.2 毫秒。
  • JAX 将 NumPy 数组复制到 GPU 上花费了 1.26 毫秒。
  • JAX 编译该函数需要 193 毫秒。
  • JAX 在 GPU 上每次评估需要 485 微秒。

在这种情况下,我们看到一旦数据传输完毕并且函数被编译,JAX 在 GPU 上重复评估时大约快了 30 倍。

这个比较公平吗?也许是。最终重要的性能是运行完整应用程序时的性能,这些应用程序不可避免地包含了一些数据传输和编译。此外,我们小心地选择了足够大的数组(1000x1000)和足够密集的计算(@ 操作符执行矩阵乘法),以摊销 JAX/加速器相对于 NumPy/CPU 增加的开销。例如,如果我们将这个例子切换到使用 10x10 的输入,JAX/GPU 的运行速度比 NumPy/CPU 慢 10 倍(100 µs vs 10 µs)。

JAX 是否比 NumPy 更快?

用户经常试图通过这样的基准测试来回答一个问题,即 JAX 是否比 NumPy 更快;由于这两个软件包的差异,这并不是一个简单的问题。

广义上讲:

  • NumPy 操作是急切地、同步地执行,只在 CPU 上执行。
  • JAX 操作可以被急切地执行或者在编译之后执行(如果在 jit() 内部);它们被异步地分派(参见异步分派);并且它们可以在 CPU、GPU 或 TPU 上执行,每种设备都有非常不同且不断演变的性能特征。

这些架构差异使得直接比较 NumPy 和 JAX 的基准测试变得困难。

另外,这些差异还导致了软件包在工程上的不同关注点:例如,NumPy 大力减少了单个数组操作的每次调用分派开销,因为在 NumPy  的计算模型中,这种开销是无法避免的。另一方面,JAX 有几种方法可以避免分派开销(例如,JIT  编译、异步分派、批处理转换等),因此减少每次调用的开销并不是一个首要任务。

综上所述,在总结时:如果您在 CPU 上进行单个数组操作的微基准测试,通常可以预期 NumPy 由于其较低的每次操作分派开销而胜过  JAX。如果您在 GPU 或 TPU 上运行代码,或者在 CPU 上进行更复杂的 JIT 编译操作序列的基准测试,通常可以预期 JAX 胜过  NumPy。##不同类型的 JAX 值

在转换函数过程中,JAX 会用特殊的追踪器值替换一些函数参数。

如果您使用print语句,您可以看到这一点:

def func(x):
  print(x)
  return jnp.cos(x)
res = jax.jit(func)(0.) 

上述代码确实返回了正确的值1.,但它还打印出了Traced作为x的值。通常情况下,JAX 在内部以透明的方式处理这些追踪器值,例如,在用于实现jax.numpy函数的数值 JAX 原语中。这就是为什么在上面的例子中jnp.cos能够正常工作的原因。

更确切地说,追踪器值用于 JAX 变换函数的参数,除了由jax.jit()的特殊参数(如static_argnums)或jax.pmap()static_broadcasted_argnums标识的参数。通常,涉及至少一个追踪器值的计算将产生一个追踪器值。除了追踪器值之外,还有常规Python 值:在 JAX 变换之外计算的值,或者来自上述特定 JAX 变换的静态参数,或者仅仅是来自其他常规 Python 值的计算。在缺少 JAX 变换的情况下,这些值在任何地方都可以使用。

一个追踪器值携带一个抽象值,例如,ShapedArray包含有关数组形状和 dtype 的信息。我们将这些追踪器称为抽象追踪器。一些追踪器,例如,为自动微分变换的参数引入的那些,携带包含实际数组数据的ConcreteArray抽象值,并且用于解析条件。我们将这些追踪器称为具体追踪器。从这些具体追踪器计算出的追踪器值,也许与常规值结合,会产生具体追踪器。具体值是指常规值或具体追踪器。

大多数情况下,从追踪值计算得到的值本身也是追踪值。只有极少数例外情况,当一个计算可以完全使用追踪器携带的抽象值时,其结果可以是常规值。例如,使用 ShapedArray 抽象值获取追踪器的形状。另一个例子是显式地将具体的追踪器值转换为常规类型,例如 int(x)x.astype(float)。另一种情况是对 bool(x) 的处理,在具体性允许的情况下会产生 Python 布尔值。由于这种情况在控制流中经常出现,所以这种情况尤为显著。

下面是这些转换如何引入抽象或具体追踪器的说明:

  • jax.jit():除了由 static_argnums 指定的位置参数之外,为所有位置参数引入抽象追踪器,这些参数保持为常规值。
  • jax.pmap():除了由 static_broadcasted_argnums 指定的位置参数之外,为所有位置参数引入抽象追踪器
  • jax.vmap()jax.make_jaxpr()xla_computation():为所有位置参数引入抽象追踪器
  • jax.jvp()jax.grad() 为所有位置参数引入具体追踪器。唯一的例外是当这些转换在外部转换内部进行时,实际参数本身就是抽象追踪器时,由自动微分转换引入的追踪器也是抽象追踪器。
  • 所有高阶控制流原语(lax.cond()lax.while_loop()lax.fori_loop()lax.scan())在处理函数时引入抽象追踪器,无论是否存在 JAX 转换。

当您的代码仅能操作常规的 Python 值时,例如基于数据的条件控制流的代码时,这些都是相关的:

def divide(x, y):
  return x / y if y >= 1. else 0. 

如果我们想要应用 jax.jit(),我们必须确保指定 static_argnums=1 以确保 y 保持为常规值。这是由于布尔表达式 y >= 1.,它需要具体的值(常规或追踪器)。如果我们显式地编写 bool(y >= 1.)int(y)float(y),也会发生同样的情况。

有趣的是,jax.grad(divide)(3., 2.) 是有效的,因为 jax.grad() 使用具体追踪器,并使用 y 的具体值解析条件。 ## 缓冲捐赠

当 JAX 执行计算时,它使用设备上的缓冲区来处理所有输入和输出。如果您知道某个输入在计算后不再需要,并且它与某个输出的形状和元素类型匹配,您可以指定要捐赠相应输入的缓冲区来保存输出。这将减少执行所需的内存,减少捐赠缓冲区的大小。

如果您有类似以下模式的情况,可以使用缓冲捐赠:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state) 

您可以将此视为一种在不可变 JAX 数组上进行内存高效的函数更新的方法。在计算的 XLA 边界内,XLA 可以为您进行此优化,但在 jit/pmap 边界处,您需要向 XLA 保证在调用捐赠函数后不会再使用捐赠的输入缓冲区。

您可以通过在函数jax.jit()jax.pjit()jax.pmap()中使用donate_argnums参数来实现这一点。此参数是位置参数列表(从 0 开始)的索引序列:

def add(x, y):
  return x + y
x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y) 

注意,如果使用关键字参数调用函数,则此方法目前不起作用!以下代码不会捐赠任何缓冲区:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state) 

如果一个参数的缓冲区被捐赠,且其为 pytree,则其所有组件的缓冲区都会被捐赠:

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]
xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs) 

不允许捐赠随后在计算中使用的缓冲区,因此在 y 的缓冲区捐赠后,JAX 会报错因为该缓冲区已失效:

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer 

如果捐赠的缓冲区未被使用,则会收到警告,例如因为捐赠的缓冲区多于输出所需:

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0} 

如果没有输出的形状与捐赠匹配,则捐赠可能也不会被使用:

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0} 

使用where时,梯度包含 NaN

如果定义一个使用where来避免未定义值的函数,如果不小心可能会得到一个反向微分的NaN

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)
my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN 

简而言之,在grad计算期间,对于未定义的jnp.log(x)的伴随是NaN,并且会累积到jnp.where的伴随中。正确的编写这类函数的方法是确保在部分定义的函数内部有一个jnp.where,以确保伴随始终是有限的:

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))
safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok 

除原始jnp.where外可能还需要内部的jnp.where,例如:

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y) 

进一步阅读:

基于排序顺序的函数为何梯度为零?

如果定义一个处理输入的函数,并使用依赖于输入相对顺序的操作(例如maxgreaterargsort等),那么可能会惊讶地发现梯度在所有位置都为零。以下是一个例子,我们定义 f(x)为一个阶跃函数,在 x 为负时返回 0,在 x 为正时返回 1:

import jax
import numpy as np
import jax.numpy as jnp
def f(x):
  return (x > 0).astype(float)
df = jax.vmap(jax.grad(f))
x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(f"f(x)  = {f(x)}")
# f(x)  = [0\. 0\. 0\. 1\. 1.]
print(f"df(x) = {df(x)}")
# df(x) = [0\. 0\. 0\. 0\. 0.] 

虽然输出对输入有响应,但梯度在所有位置为零可能会令人困惑:毕竟,输出确实随输入而变化,那么梯度怎么可能是零呢?然而,在这种情况下,零确实是正确的结果。

这是为什么?请记住,微分测量的是给定 xf 的变化。对于 x=1.0f 返回 1.0。如果我们微扰 x 使其稍大或稍小,这并不会改变输出,因此根据定义,grad(f)(1.0) 应该为零。对于所有大于零的 f 值,此逻辑同样成立:微扰输入不会改变输出,因此梯度为零。同样,对于所有小于零的 x 值,输出为零。微扰 x 不会改变这个输出,因此梯度为零。这让我们面对 x=0 的棘手情况。当然,如果你向上微扰 x,它会改变输出,但这是有问题的:x  的微小变化会产生函数值的有限变化,这意味着梯度是未定义的。幸运的是,在这种情况下我们还有另一种方法来测量梯度:我们向下微扰函数,此时输出不变,因此梯度为零。JAX   和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但其中一个被定义,另一个未定义,我们使用被定义的那个。根据梯度的这一定义,从数学和数值上来说,此函数的梯度在任何地方都是零。

问题在于我们的函数在 x = 0 处有不连续点。我们的 f 本质上是一个 Heaviside Step Function,我们可以使用 Sigmoid Function 作为平滑替代。当 x 远离零时,Sigmoid 函数近似等于 Heaviside 函数,但在 x = 0 处用一个平滑的、可微的曲线替换不连续性。通过使用 jax.nn.sigmoid(),我们得到一个具有良定义梯度的类似计算:

def g(x):
  return jax.nn.sigmoid(x)
dg = jax.vmap(jax.grad(g))
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0\.   0.27 0.5  0.73 1\.  ]
  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0\.   0.2  0.25 0.2  0\.  ] 

jax.nn 子模块还有其他常见基于排名的函数的平滑版本,例如 jax.nn.softmax() 可以替换 jax.numpy.argmax() 的使用,jax.nn.soft_sign() 可以替换 jax.numpy.sign() 的使用,jax.nn.softplus()jax.nn.squareplus() 可以替换 jax.nn.relu() 的使用,等等。

我如何将 JAX 追踪器转换为 NumPy 数组?

在运行时检查转换后的 JAX 函数时,您会发现数组值被 Tracer 对象替换:

@jax.jit
def f(x):
  print(type(x))
  return x
f(jnp.arange(5)) 

这将打印如下内容:

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> 

一个常见的问题是如何将这样的追踪器转换回正常的 NumPy 数组。简而言之,无法将追踪器转换为 NumPy 数组,因为追踪器是具有给定形状和数据类型的每一个可能值的抽象表示,而 NumPy 数组是该抽象类的具体成员。有关在 JAX 转换环境中追踪器工作的更多讨论,请参阅 JIT mechanics

将跟踪器转换回数组的问题通常出现在与运行时访问计算中的中间值相关的另一个目标的背景下。例如:

  • 如果您希望出于调试目的在运行时打印跟踪值,您可以考虑使用jax.debug.print()
  • 如果您希望在转换后的 JAX 函数中调用非 JAX 代码,您可以考虑使用jax.pure_callback(),其示例可在纯回调示例中找到。
  • 如果您希望在运行时输入或输出数组缓冲区(例如,从文件加载数据或将数组内容记录到磁盘),您可以考虑使用jax.experimental.io_callback(),其示例可在IO 回调示例中找到。

关于运行时回调的更多信息和它们的使用示例,请参阅JAX 中的外部回调

为什么会有些 CUDA 库加载/初始化失败?

在解析动态库时,JAX 使用通常的动态链接器搜索模式。JAX 将RPATH设置为指向通过 pip 安装的 NVIDIA CUDA 软件包的 JAX 相对位置,如果安装了这些软件包,则优先使用它们。如果ld.so在其通常的搜索路径中找不到您的 CUDA 运行时库,则必须在LD_LIBRARY_PATH中显式包含这些库的路径。确保您的 CUDA 文件可被发现的最简单方法是简单地安装标准的jax[cuda_12]安装选项中包含的nvidia-*-cu12 pip 软件包。

偶尔,即使您确保您的运行时库可被发现,仍可能存在加载或初始化的问题。这类问题的常见原因是运行时 CUDA 库初始化时内存不足。这有时是因为 JAX 将预分配当前可用设备内存的太大一部分以提高执行速度,偶尔会导致没有足够的内存留给运行时 CUDA 库初始化。

在运行多个 JAX 实例、与执行自己的预分配的 TensorFlow 并行运行 JAX,或者在 GPU 被其他进程大量使用的系统上运行 JAX 时,特别容易发生这种情况。如果有疑问,请尝试使用减少预分配来重新运行程序,可以通过减少XLA_PYTHON_CLIENT_MEM_FRACTION(默认为.75)或设置XLA_PYTHON_CLIENT_PREALLOCATE=false来实现。有关更多详细信息,请参阅JAX GPU 内存分配页面。

. ]

print(f"dg(x) = {dg(x)}")

dg(x) = [0. 0.2 0.25 0.2 0. ]

`jax.nn` 子模块还有其他常见基于排名的函数的平滑版本,例如 `jax.nn.softmax()` 可以替换 `jax.numpy.argmax()` 的使用,`jax.nn.soft_sign()` 可以替换 `jax.numpy.sign()` 的使用,`jax.nn.softplus()` 或 `jax.nn.squareplus()` 可以替换 `jax.nn.relu()` 的使用,等等。
## 我如何将 JAX 追踪器转换为 NumPy 数组?
在运行时检查转换后的 JAX 函数时,您会发现数组值被 `Tracer` 对象替换:
```py
@jax.jit
def f(x):
  print(type(x))
  return x
f(jnp.arange(5)) 

这将打印如下内容:

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> 

一个常见的问题是如何将这样的追踪器转换回正常的 NumPy 数组。简而言之,无法将追踪器转换为 NumPy 数组,因为追踪器是具有给定形状和数据类型的每一个可能值的抽象表示,而 NumPy 数组是该抽象类的具体成员。有关在 JAX 转换环境中追踪器工作的更多讨论,请参阅 JIT mechanics

将跟踪器转换回数组的问题通常出现在与运行时访问计算中的中间值相关的另一个目标的背景下。例如:

  • 如果您希望出于调试目的在运行时打印跟踪值,您可以考虑使用jax.debug.print()
  • 如果您希望在转换后的 JAX 函数中调用非 JAX 代码,您可以考虑使用jax.pure_callback(),其示例可在纯回调示例中找到。
  • 如果您希望在运行时输入或输出数组缓冲区(例如,从文件加载数据或将数组内容记录到磁盘),您可以考虑使用jax.experimental.io_callback(),其示例可在IO 回调示例中找到。

关于运行时回调的更多信息和它们的使用示例,请参阅JAX 中的外部回调

为什么会有些 CUDA 库加载/初始化失败?

在解析动态库时,JAX 使用通常的动态链接器搜索模式。JAX 将RPATH设置为指向通过 pip 安装的 NVIDIA CUDA 软件包的 JAX 相对位置,如果安装了这些软件包,则优先使用它们。如果ld.so在其通常的搜索路径中找不到您的 CUDA 运行时库,则必须在LD_LIBRARY_PATH中显式包含这些库的路径。确保您的 CUDA 文件可被发现的最简单方法是简单地安装标准的jax[cuda_12]安装选项中包含的nvidia-*-cu12 pip 软件包。

偶尔,即使您确保您的运行时库可被发现,仍可能存在加载或初始化的问题。这类问题的常见原因是运行时 CUDA 库初始化时内存不足。这有时是因为 JAX 将预分配当前可用设备内存的太大一部分以提高执行速度,偶尔会导致没有足够的内存留给运行时 CUDA 库初始化。

在运行多个 JAX 实例、与执行自己的预分配的 TensorFlow 并行运行 JAX,或者在 GPU 被其他进程大量使用的系统上运行 JAX 时,特别容易发生这种情况。如果有疑问,请尝试使用减少预分配来重新运行程序,可以通过减少XLA_PYTHON_CLIENT_MEM_FRACTION(默认为.75)或设置XLA_PYTHON_CLIENT_PREALLOCATE=false来实现。有关更多详细信息,请参阅JAX GPU 内存分配页面。

相关文章
|
3月前
|
机器学习/深度学习
JAX 中文文档(六)(5)
JAX 中文文档(六)
24 0
|
3月前
|
机器学习/深度学习 异构计算 Python
JAX 中文文档(四)(3)
JAX 中文文档(四)
22 0
|
3月前
|
存储 移动开发 Python
JAX 中文文档(八)(2)
JAX 中文文档(八)
23 0
|
3月前
|
机器学习/深度学习 存储 并行计算
JAX 中文文档(七)(3)
JAX 中文文档(七)
29 0
|
3月前
|
存储 并行计算 数据可视化
JAX 中文文档(六)(3)
JAX 中文文档(六)
25 0
|
3月前
|
存储 PyTorch 测试技术
JAX 中文文档(八)(5)
JAX 中文文档(八)
26 0
|
3月前
|
机器学习/深度学习 缓存 API
JAX 中文文档(一)(4)
JAX 中文文档(一)
40 0
|
3月前
|
编译器 测试技术 API
JAX 中文文档(四)(4)
JAX 中文文档(四)
27 0
|
3月前
|
存储 并行计算 开发工具
JAX 中文文档(十)(1)
JAX 中文文档(十)
32 0
|
3月前
|
安全 编译器 TensorFlow
JAX 中文文档(四)(5)
JAX 中文文档(四)
22 0