JAX 中文文档(三)(1)

简介: JAX 中文文档(三)


原文:jax.readthedocs.io/en/latest/

有状态计算

原文:jax.readthedocs.io/en/latest/stateful-computations.html

JAX 的转换(如jit()vmap()grad())要求它们包装的函数是纯粹的:即,函数的输出仅依赖于输入,并且没有副作用,比如更新全局状态。您可以在JAX sharp bits: Pure functions中找到关于这一点的讨论。

在机器学习的背景下,这种约束可能会带来一些挑战,因为状态可以以多种形式存在。例如:

  • 模型参数,
  • 优化器状态,以及
  • BatchNorm这样的有状态层。

本节提供了如何在 JAX 程序中正确处理状态的一些建议。

一个简单的例子:计数器

让我们首先看一个简单的有状态程序:一个计数器。

import jax
import jax.numpy as jnp
class Counter:
  """A simple counter."""
  def __init__(self):
    self.n = 0
  def count(self) -> int:
  """Increments the counter and returns the new value."""
    self.n += 1
    return self.n
  def reset(self):
  """Resets the counter to zero."""
    self.n = 0
counter = Counter()
for _ in range(3):
  print(counter.count()) 
1
2
3 

计数器的n属性在连续调用count时维护计数器的状态。调用count的副作用是修改它。

假设我们想要快速计数,所以我们即时编译count方法。(在这个例子中,这实际上不会以任何方式加快速度,由于很多原因,但把它看作是模型参数更新的玩具模型,jit()确实产生了巨大的影响)。

counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
  print(fast_count()) 
1
1
1 

哦不!我们的计数器不能工作了。这是因为

self.n += 1 

count中涉及副作用:它直接修改了输入的计数器,因此此函数不受jit支持。这样的副作用仅在首次跟踪函数时执行一次,后续调用将不会重复该副作用。那么,我们该如何修复它呢?

解决方案:显式状态

问题的一部分在于我们的计数器返回值不依赖于参数,这意味着编译输出中包含了一个常数。但它不应该是一个常数 - 它应该依赖于状态。那么,为什么我们不将状态作为一个参数呢?

CounterState = int
class CounterV2:
  def count(self, n: CounterState) -> tuple[int, CounterState]:
    # You could just return n+1, but here we separate its role as 
    # the output and as the counter state for didactic purposes.
    return n+1, n+1
  def reset(self) -> CounterState:
    return 0
counter = CounterV2()
state = counter.reset()
for _ in range(3):
  value, state = counter.count(state)
  print(value) 
1
2
3 

在这个Counter的新版本中,我们将n移动到count的参数中,并添加了另一个返回值,表示新的、更新的状态。现在,为了使用这个计数器,我们需要显式地跟踪状态。但作为回报,我们现在可以安全地使用jax.jit这个计数器:

state = counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
  value, state = fast_count(state)
  print(value) 
1
2
3 

一个一般的策略

我们可以将同样的过程应用到任何有状态方法中,将其转换为无状态方法。我们拿一个形式如下的类

class StatefulClass
  state: State
  def stateful_method(*args, **kwargs) -> Output: 

并将其转换为以下形式的类

class StatelessClass
  def stateless_method(state: State, *args, **kwargs) -> (Output, State): 

这是一个常见的函数式编程模式,本质上就是处理所有 JAX 程序中状态的方式。

注意,一旦我们按照这种方式重写它,类的必要性就不那么明显了。我们可以只保留stateless_method,因为类不再执行任何工作。这是因为,像我们刚刚应用的策略一样,面向对象编程(OOP)是帮助程序员理解程序状态的一种方式。

在我们的情况下,CounterV2 类只是一个名称空间,将所有使用 CounterState 的函数集中在一个位置。读者可以思考:将其保留为类是否有意义?

顺便说一句,你已经在 JAX 伪随机性 API 中看到了这种策略的示例,即 jax.random,在 :ref:pseudorandom-numbers 部分展示。与 Numpy 不同,后者使用隐式更新的有状态类管理随机状态,而 JAX 要求程序员直接使用随机生成器状态——PRNG 密钥。

简单的工作示例:线性回归

现在让我们将这种策略应用到一个简单的机器学习模型上:通过梯度下降进行线性回归。

这里,我们只处理一种状态:模型参数。但通常情况下,你会看到许多种状态在 JAX 函数中交替出现,比如优化器状态、批归一化的层统计数据等。

需要仔细查看的函数是 update

from typing import NamedTuple
class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray
def init(rng) -> Params:
  """Returns the initial model params."""
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, ())
  bias = jax.random.normal(bias_key, ())
  return Params(weight, bias)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Computes the least squares error of the model's predictions on x against y."""
  pred = params.weight * x + params.bias
  return jnp.mean((pred - y) ** 2)
LEARNING_RATE = 0.005
@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
  """Performs one SGD update step on params using the given data."""
  grad = jax.grad(loss)(params, x, y)
  # If we were using Adam or another stateful optimizer,
  # we would also do something like
  #
  #   updates, new_optimizer_state = optimizer(grad, optimizer_state)
  # 
  # and then use `updates` instead of `grad` to actually update the params.
  # (And we'd include `new_optimizer_state` in the output, naturally.)
  new_params = jax.tree_map(
      lambda param, g: param - g * LEARNING_RATE, params, grad)
  return new_params 

注意,我们手动地将参数输入和输出到更新函数中。

import matplotlib.pyplot as plt
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise
# Fit regression
params = init(rng)
for _ in range(1000):
  params = update(params, xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend(); 
/tmp/ipykernel_2992/721844192.py:37: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  new_params = jax.tree_map( 


进一步探讨

上述描述的策略是任何使用 jitvmapgrad 等转换的 JAX 程序必须处理状态的方式。

如果只涉及两个参数,手动处理参数似乎还可以接受,但如果是有数十层的神经网络呢?你可能已经开始担心两件事情:

  1. 我们是否应该手动初始化它们,基本上是在前向传播定义中已经编写过的内容?
  2. 我们是否应该手动处理所有这些事情?

处理这些细节可能有些棘手,但有一些库的示例可以为您解决这些问题。请参阅JAX 神经网络库获取一些示例。

进一步资源

用户指南

原文:jax.readthedocs.io/en/latest/user_guides.html

用户指南是对 JAX 内特定主题的深入探讨,随着您的 JAX 项目发展成为更大或部署代码库,这些主题变得更为相关。

调试和性能

  • 如何在 JAX 中思考
  • 对 JAX 程序进行性能分析
  • 设备内存分析
  • JAX 中的运行时值调试
  • GPU 性能技巧
  • 持久化编译缓存

开发

  • 理解 Jaxprs
  • JAX 中的外部回调
  • 类型提升语义
  • Pytrees

运行时间

  • 提前降低和编译
  • 导出和序列化
  • JAX 错误
  • 转移保护

自定义操作

  • Pallas:一种 JAX 内核语言

如何在 JAX 中思考

原文:jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html

[外链图片转存中…(img-rDfp2FNH-1718950659301)] [外链图片转存中…(img-axjiDX87-1718950659301)]

JAX 提供了一个简单而强大的 API 用于编写加速数值代码,但在 JAX 中有效工作有时需要额外考虑。本文档旨在帮助建立对 JAX 如何运行的基础理解,以便您更有效地使用它。

JAX vs. NumPy

关键概念:

  • JAX 提供了一个方便的类似于 NumPy 的接口。
  • 通过鸭子类型,JAX 数组通常可以直接替换 NumPy 数组。
  • 不像 NumPy 数组,JAX 数组总是不可变的。

NumPy 提供了一个众所周知且功能强大的 API 用于处理数值数据。为方便起见,JAX 提供了 jax.numpy,它紧密反映了 NumPy API,并为进入 JAX 提供了便捷的入口。几乎可以用 jax.numpy 完成 numpy 可以完成的任何事情:

import matplotlib.pyplot as plt
import numpy as np
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np); 


import jax.numpy as jnp
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp); 


代码块除了用 jnp 替换 np 外,其余完全相同。正如我们所见,JAX 数组通常可以直接替换 NumPy 数组,用于诸如绘图等任务。

这些数组本身是作为不同的 Python 类型实现的:

type(x_np) 
numpy.ndarray 
type(x_jnp) 
jaxlib.xla_extension.ArrayImpl 

Python 的 鸭子类型 允许在许多地方可互换使用 JAX 数组和 NumPy 数组。

然而,JAX 和 NumPy 数组之间有一个重要的区别:JAX 数组是不可变的,一旦创建,其内容无法更改。

这里有一个在 NumPy 中突变数组的例子:

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x) 
[10  1  2  3  4  5  6  7  8  9] 

在 JAX 中,等效操作会导致错误,因为 JAX 数组是不可变的:

%xmode minimal 
Exception reporting mode: Minimal 
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10 
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html 

对于更新单个元素,JAX 提供了一个 索引更新语法,返回一个更新后的副本:

y = x.at[0].set(10)
print(x)
print(y) 
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9] 

NumPy、lax 和 XLA:JAX API 层次结构

关键概念:

  • jax.numpy 是一个提供熟悉接口的高级包装器。
  • jax.lax 是一个更严格且通常更强大的低级 API。
  • 所有 JAX 操作都是基于 XLA – 加速线性代数编译器中的操作实现的。

如果您查看 jax.numpy 的源代码,您会看到所有操作最终都是以 jax.lax 中定义的函数形式表达的。您可以将 jax.lax 视为更严格但通常更强大的 API,用于处理多维数组。

例如,虽然jax.numpy将隐式促进参数以允许不同数据类型之间的操作,但jax.lax不会:

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types. 
Array(2., dtype=float32, weak_type=True) 
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion. 
MLIRError: Verification failed:
error: "jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2814/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))): op requires the same element type for all operands and results
The above exception was the direct cause of the following exception:
ValueError: Cannot lower jaxpr with verifier errors:
  op requires the same element type for all operands and results
    at loc("jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2814/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))))
Define JAX_DUMP_IR_TO to dump the module. 

如果直接使用jax.lax,在这种情况下你将需要显式地进行类型提升:

lax.add(jnp.float32(1), 1.0) 
Array(2., dtype=float32) 

除了这种严格性外,jax.lax还提供了一些比 NumPy 支持的更一般操作更高效的 API。

例如,考虑一个 1D 卷积,在 NumPy 中可以这样表达:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y) 
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32) 

在幕后,这个 NumPy 操作被转换为由lax.conv_general_dilated实现的更通用的卷积:

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0] 
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32) 

这是一种批处理卷积操作,专为深度神经网络中经常使用的卷积类型设计,需要更多的样板代码,但比 NumPy 提供的卷积更灵活和可扩展(有关 JAX 卷积的更多细节,请参见Convolutions in JAX)。

从本质上讲,所有jax.lax操作都是 XLA 中操作的 Python 包装器;例如,在这里,卷积实现由XLA:ConvWithGeneralPadding提供。每个 JAX 操作最终都是基于这些基本 XLA 操作表达的,这就是使得即时(JIT)编译成为可能的原因。

要 JIT 或不要 JIT

关键概念:

  • 默认情况下,JAX 按顺序逐个执行操作。
  • 使用即时(JIT)编译装饰器,可以优化操作序列并一次运行:
  • 并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态且已知的。

所有 JAX 操作都是基于 XLA 表达的事实,使得 JAX 能够使用 XLA 编译器非常高效地执行代码块。

例如,考虑此函数,它对二维矩阵的行进行标准化,表达为jax.numpy操作:

import jax.numpy as jnp
def norm(X):
  X = X - X.mean(0)
  return X / X.std(0) 

可以使用jax.jit变换创建函数的即时编译版本:

from jax import jit
norm_compiled = jit(norm) 

此函数返回与原始函数相同的结果,达到标准浮点精度:

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6) 
True 

但由于编译(其中包括操作的融合、避免分配临时数组以及其他许多技巧),在 JIT 编译的情况下,执行时间可以比非常数级别快得多(请注意使用block_until_ready()以考虑 JAX 的异步调度):

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready() 
319 μs ± 1.98 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
272 μs ± 849 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each) 

话虽如此,jax.jit确实存在一些限制:特别是,它要求所有数组具有静态形状。这意味着一些 JAX 操作与 JIT 编译不兼容。

例如,此操作可以在逐操作模式下执行:

def get_negatives(x):
  return x[x < 0]
x = jnp.array(np.random.randn(10))
get_negatives(x) 
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

但如果您尝试在 jit 模式下执行它,则会返回错误:

jit(get_negatives)(x) 
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError 

这是因为该函数生成的数组形状在编译时未知:输出的大小取决于输入数组的值,因此与 JIT 不兼容。

JIT 机制:跟踪和静态变量

关键概念:

  • JIT 和其他 JAX 转换通过跟踪函数来确定其对特定形状和类型输入的影响。
  • 不希望被追踪的变量可以标记为静态

要有效使用 jax.jit,理解其工作原理是很有用的。让我们在一个 JIT 编译的函数中放几个 print() 语句,然后调用该函数:

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result
x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y) 
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)> 
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32) 

注意,打印语句执行,但打印的不是我们传递给函数的数据,而是打印追踪器对象,这些对象代替它们。

这些追踪器对象是 jax.jit 用来提取函数指定的操作序列的基本替代物,编码数组的形状dtype,但对值是不可知的。然后可以有效地将这个记录的计算序列应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。

当我们在匹配的输入上再次调用编译函数时,无需重新编译,也不打印任何内容,因为结果在编译的 XLA 中计算,而不是在 Python 中:

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2) 
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32) 

提取的操作序列编码在 JAX 表达式中,简称为 jaxpr。您可以使用 jax.make_jaxpr 转换查看 jaxpr:

from jax import make_jaxpr
def f(x, y):
  return jnp.dot(x + 1, y + 1)
make_jaxpr(f)(x, y) 
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) } 

注意这一后果:因为 JIT 编译是在没有数组内容信息的情况下完成的,所以函数中的控制流语句不能依赖于追踪的值。例如,这将失败:

@jit
def f(x, neg):
  return -x if neg else x
f(1, True) 
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_2814/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError 

如果有不希望被追踪的变量,可以将它们标记为静态以供 JIT 编译使用:

from functools import partial
@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x
f(1, True) 
Array(-1, dtype=int32, weak_type=True) 

请注意,使用不同的静态参数调用 JIT 编译函数会导致重新编译,所以函数仍然如预期般工作:

f(1, False) 
Array(1, dtype=int32, weak_type=True) 

理解哪些值和操作将是静态的,哪些将被追踪,是有效使用 jax.jit 的关键部分。

静态与追踪操作

关键概念:

  • 就像值可以是静态的或者被追踪的一样,操作也可以是静态的或者被追踪的。
  • 静态操作在 Python 中在编译时评估;跟踪操作在 XLA 中在运行时编译并评估。
  • 使用 numpy 进行您希望静态的操作;使用 jax.numpy 进行您希望被追踪的操作。

静态和追踪值的区别使得重要的是考虑如何保持静态值的静态。考虑这个函数:

import jax.numpy as jnp
from jax import jit
@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())
x = jnp.ones((2, 3))
f(x) 
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2814/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:
  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /tmp/ipykernel_2814/1983583872.py:6 (f) 

这会因为找到追踪器而不是整数类型的具体值的 1D 序列而失败。让我们向函数中添加一些打印语句,以了解其原因:

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())
f(x) 
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)> 

注意尽管x被追踪,x.shape是一个静态值。然而,当我们在这个静态值上使用jnp.arrayjnp.prod时,它变成了一个被追踪的值,在这种情况下,它不能用于像reshape()这样需要静态输入的函数(回想:数组形状必须是静态的)。

一个有用的模式是使用numpy进行应该是静态的操作(即在编译时完成),并使用jax.numpy进行应该被追踪的操作(即在运行时编译和执行)。对于这个函数,可能会像这样:

from jax import jit
import jax.numpy as jnp
import numpy as np
@jit
def f(x):
  return x.reshape((np.prod(x.shape),))
f(x) 
Array([1., 1., 1., 1., 1., 1.], dtype=float32) 

因此,在 JAX 程序中的一个标准约定是import numpy as npimport jax.numpy as jnp,这样两个接口都可以用来更精细地控制操作是以静态方式(使用numpy,一次在编译时)还是以追踪方式(使用jax.numpy,在运行时优化)执行。


JAX 中文文档(三)(2)https://developer.aliyun.com/article/1559703

相关文章
|
3天前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
9 1
|
3天前
|
编译器 API C++
JAX 中文文档(三)(3)
JAX 中文文档(三)
8 0
|
3天前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
10 0
|
3天前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
7 0
|
3天前
|
存储 机器学习/深度学习 并行计算
JAX 中文文档(二)(5)
JAX 中文文档(二)
8 0
|
3天前
|
存储 缓存 测试技术
JAX 中文文档(三)(5)
JAX 中文文档(三)
7 0
|
3天前
|
并行计算 Linux 异构计算
JAX 中文文档(一)(1)
JAX 中文文档(一)
8 0
|
3天前
|
存储 移动开发 Python
JAX 中文文档(八)(2)
JAX 中文文档(八)
8 0
|
3天前
|
并行计算 编译器
JAX 中文文档(六)(4)
JAX 中文文档(六)
6 0
|
3天前
|
数据可视化 TensorFlow 算法框架/工具
JAX 中文文档(三)(2)
JAX 中文文档(三)
8 0