JAX 中文文档(一)(2)

简介: JAX 中文文档(一)

JAX 中文文档(一)(1)https://developer.aliyun.com/article/1559829


🔪 JAX - 锋利的部分 🔪

原文:jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

levskaya@ mattjj@

在意大利乡间漫步时,人们会毫不犹豫地告诉您,JAX 具有 “una anima di pura programmazione funzionale”

JAX 是一种用于表达和组合数值程序转换的语言。JAX 还能够为 CPU 或加速器(GPU/TPU)编译数值程序。对于许多数值和科学程序,JAX 表现出色,但前提是它们必须按照我们下面描述的某些约束条件编写。

import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp 

🔪 纯函数

JAX 的转换和编译设计仅适用于函数式纯的 Python 函数:所有输入数据通过函数参数传递,所有结果通过函数结果输出。纯函数如果以相同的输入调用,将始终返回相同的结果。

下面是一些函数示例,这些函数不是函数式纯的,因此 JAX 的行为与 Python 解释器不同。请注意,这些行为并不由 JAX 系统保证;正确使用 JAX 的方法是仅在函数式纯 Python 函数上使用它。

def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.]))) 
Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.] 
g = 0.
def impure_uses_globals(x):
  return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.]))) 
First call:  4.0
Second call:  5.0
Third call, different type:  [14.] 
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x
# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value 
First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> 

即使一个 Python 函数在内部实际上使用了有状态的对象,只要它不读取或写入外部状态,它就可以是函数式纯的:

def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.)) 
50.0 

不建议在希望jit的任何 JAX 函数中使用迭代器或任何控制流原语。原因是迭代器是一个引入状态以检索下一个元素的  Python 对象。因此,它与 JAX 的函数式编程模型不兼容。在下面的代码中,有一些尝试在 JAX  中使用迭代器的错误示例。其中大多数会返回错误,但有些会给出意外的结果。

import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error 
45
0 

🔪 原地更新

在 Numpy 中,您习惯于执行以下操作:

numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array) 
original array:
[[0\. 0\. 0.]
 [0\. 0\. 0.]
 [0\. 0\. 0.]]
updated array:
[[0\. 0\. 0.]
 [1\. 1\. 1.]
 [0\. 0\. 0.]] 

然而,如果我们尝试在 JAX 设备数组上就地更新,我们会收到错误!(☉_☉)

%xmode Minimal 
Exception reporting mode: Minimal 
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0 
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html 

允许变量在原地变异会使程序分析和转换变得困难。JAX 要求程序是纯函数。

相反,JAX 提供了对 JAX 数组上的 .at 属性进行函数式数组更新

️⚠️ 在 jit 的代码中和 lax.while_looplax.fori_loop 中,切片的大小不能是参数 的函数,而只能是参数 形状 的函数 — 切片的起始索引没有此类限制。有关此限制的更多信息,请参阅下面的 控制流 部分。

数组更新:x.at[idx].set(y)

例如,上述更新可以写成:

updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array) 
updated array:
 [[0\. 0\. 0.]
 [1\. 1\. 1.]
 [0\. 0\. 0.]] 

JAX 的数组更新函数与其 NumPy 版本不同,是在原地外执行的。也就是说,更新后的数组作为新数组返回,原始数组不会被更新修改。

print("original array unchanged:\n", jax_array) 
original array unchanged:
 [[0\. 0\. 0.]
 [0\. 0\. 0.]
 [0\. 0\. 0.]] 

然而,在jit编译的代码内部,如果x.at[idx].set(y)输入值 x 没有被重用,编译器会优化数组更新以进行原地操作。

使用其他操作的数组更新

索引数组更新不仅限于覆盖值。例如,我们可以进行索引加法如下:

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array) 
original array:
[[1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 1\. 1\. 1.]]
new array post-addition:
[[1\. 1\. 1\. 8\. 8\. 8.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 8\. 8\. 8.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 8\. 8\. 8.]] 

有关索引数组更新的更多详细信息,请参阅.at属性的文档

🔪 超出边界索引

在 NumPy 中,当您索引数组超出其边界时,通常会抛出错误,例如:

np.arange(10)[11] 
IndexError: index 11 is out of bounds for axis 0 with size 10 

然而,在加速器上运行的代码中引发错误可能会很困难或不可能。因此,JAX 必须为超出边界的索引选择一些非错误行为(类似于无效的浮点算术结果为NaN的情况)。当索引操作是数组索引更新时(例如index_add或类似的原语),将跳过超出边界的索引;当操作是数组索引检索时(例如 NumPy 索引或类似的原语),索引将夹紧到数组的边界,因为必须返回某些内容。例如,数组的最后一个值将从此索引操作中返回:

jnp.arange(10)[11] 
Array(9, dtype=int32) 

如果您希望对超出边界索引的行为有更精细的控制,可以使用ndarray.at的可选参数;例如:

jnp.arange(10.0).at[11].get() 
Array(9., dtype=float32) 
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan) 
Array(nan, dtype=float32) 

注意由于这种索引检索行为,像jnp.nanargminjnp.nanargmax这样的函数在由 NaN 组成的切片中返回-1,而 NumPy 会抛出错误。

还请注意,由于上述两种行为不是互为反操作,反向模式自动微分(将索引更新转换为索引检索及其反之)将不会保留超出边界索引的语义。因此,将 JAX 中的超出边界索引视为未定义行为可能是个好主意。

🔪 非数组输入:NumPy vs. JAX

NumPy 通常可以接受 Python 列表或元组作为其 API 函数的输入:

np.sum([1, 2, 3]) 
np.int64(6) 

JAX 在这方面有所不同,通常会返回有用的错误:

jnp.sum([1, 2, 3]) 
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0. 

这是一个有意的设计选择,因为向追踪函数传递列表或元组可能导致性能下降,而这种性能下降可能很难检测到。

例如,请考虑允许列表输入的jnp.sum的以下宽松版本:

def permissive_sum(x):
  return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x) 
Array(45, dtype=int32) 

输出与预期相符,但这隐藏了底层的潜在性能问题。在 JAX 的追踪和 JIT 编译模型中,Python 列表或元组中的每个元素都被视为单独的 JAX 变量,并分别处理和推送到设备。这可以在上面的permissive_sum函数的 jaxpr 中看到:

make_jaxpr(permissive_sum)(x) 
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
    o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
    q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
    r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
    s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
    t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
    u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
    v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
    w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m
    x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
    y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o
    z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p
    ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
    bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r
    bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
    bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t
    be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
    bf:i32[] = reduce_sum[axes=(0,)] be
  in (bf,) } 

列表的每个条目都作为单独的输入处理,导致追踪和编译开销随列表大小线性增长。为了避免这样的意外,JAX 避免将列表和元组隐式转换为数组。

如果您希望将元组或列表传递给 JAX 函数,可以首先显式地将其转换为数组:

jnp.sum(jnp.array(x)) 
Array(45, dtype=int32) 


JAX 中文文档(一)(3)https://developer.aliyun.com/article/1559832

相关文章
|
3月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
26 0
JAX 中文文档(六)(1)
|
3月前
JAX 中文文档(九)(3)
JAX 中文文档(九)
28 0
|
3月前
|
机器学习/深度学习 并行计算 安全
JAX 中文文档(七)(1)
JAX 中文文档(七)
34 0
|
3月前
|
编译器 API C++
JAX 中文文档(三)(3)
JAX 中文文档(三)
22 0
|
3月前
|
编译器 测试技术 API
JAX 中文文档(四)(4)
JAX 中文文档(四)
27 0
|
3月前
|
机器学习/深度学习 API 索引
JAX 中文文档(二)(2)
JAX 中文文档(二)
25 0
|
3月前
|
API Python
JAX 中文文档(八)(3)
JAX 中文文档(八)
26 0
|
3月前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
18 0
|
3月前
|
存储 缓存 索引
JAX 中文文档(五)(3)
JAX 中文文档(五)
44 0
|
3月前
|
安全 编译器 TensorFlow
JAX 中文文档(四)(5)
JAX 中文文档(四)
22 0