JAX 中文文档(二)(3)https://developer.aliyun.com/article/1559670
处理 pytrees
JAX 内置支持类似字典(dicts)的数组对象,或者列表的列表的字典,或其他嵌套结构 — 在 JAX 中称为 pytrees。本节将解释如何使用它们,提供有用的代码示例,并指出常见的“坑”和模式。
什么是 pytree?
一个 pytree 是由类似容器的 Python 对象构建的容器结构 — “叶子” pytrees 和/或更多的 pytrees。一个 pytree 可以包括列表、元组和字典。一个叶子是任何不是 pytree 的东西,比如一个数组,但一个单独的叶子也是一个 pytree。
在机器学习(ML)的上下文中,一个 pytree 可能包含:
- 模型参数
- 数据集条目
- 强化学习代理观察
当处理数据集时,你经常会遇到 pytrees(比如列表的列表的字典)。
下面是一个简单 pytree 的示例。在 JAX 中,你可以使用 jax.tree.leaves()
,从树中提取扁平化的叶子,如此处所示:
import jax import jax.numpy as jnp example_trees = [ [1, 'a', object()], (1, (2, 3), ()), [1, {'k1': 2, 'k2': (3, 4)}, 5], {'a': 2, 'b': (2, 3)}, jnp.array([1, 2, 3]), ] # Print how many leaves the pytrees have. for pytree in example_trees: # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees. leaves = jax.tree.leaves(pytree) print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7f3d0048f950>] has 3 leaves: [1, 'a', <object object at 0x7f3d0048f950>] (1, (2, 3), ()) has 3 leaves: [1, 2, 3] [1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5] {'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3] Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]
在 JAX 中,任何由类似容器的 Python 对象构建的树状结构都可以被视为 pytree。如果它们在 pytree 注册表中,则类被视为容器类,默认情况下包括列表、元组和字典。任何类型不在 pytree 容器注册表中的对象都将被视为树中的叶子节点。
可以通过注册类并使用指定如何扁平化树的函数来扩展 pytree 注册表以包括用户定义的容器类;请参见下面的自定义 pytree 节点。
JAX 提供了许多实用程序来操作 pytrees。这些可以在 jax.tree_util
子包中找到;为了方便起见,其中许多在 jax.tree
模块中有别名。
常见功能:jax.tree.map
最常用的 pytree 函数是 jax.tree.map()
。它的工作方式类似于 Python 的原生 map
,但透明地操作整个 pytree。
这里有一个例子:
list_of_lists = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ] jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
jax.tree.map()
也允许在多个参数上映射一个N-ary函数。例如:
another_list_of_lists = list_of_lists jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
当使用多个参数与 jax.tree.map()
时,输入的结构必须完全匹配。也就是说,列表必须有相同数量的元素,字典必须有相同的键,等等。
用 jax.tree.map
示例解释 ML 模型参数
此示例演示了在训练简单多层感知器(MLP)时,pytree 操作如何有用。
从定义初始模型参数开始:
import numpy as np def init_mlp_params(layer_widths): params = [] for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]): params.append( dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in), biases=np.ones(shape=(n_out,)) ) ) return params params = init_mlp_params([1, 128, 128, 1])
使用 jax.tree.map()
检查初始参数的形状:
jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)}, {'biases': (128,), 'weights': (128, 128)}, {'biases': (1,), 'weights': (128, 1)}]
接下来,定义训练 MLP 模型的函数:
# Define the forward pass. def forward(params, x): *hidden, last = params for layer in hidden: x = jax.nn.relu(x @ layer['weights'] + layer['biases']) return x @ last['weights'] + last['biases'] # Define the loss function. def loss_fn(params, x, y): return jnp.mean((forward(params, x) - y) ** 2) # Set the learning rate. LEARNING_RATE = 0.0001 # Using the stochastic gradient descent, define the parameter update function. # Apply `@jax.jit` for JIT compilation (speed). @jax.jit def update(params, x, y): # Calculate the gradients with `jax.grad`. grads = jax.grad(loss_fn)(params, x, y) # Note that `grads` is a pytree with the same structure as `params`. # `jax.grad` is one of many JAX functions that has # built-in support for pytrees. # This is useful - you can apply the SGD update using JAX pytree utilities. return jax.tree.map( lambda p, g: p - LEARNING_RATE * g, params, grads ) ```## 自定义 pytree 节点 本节解释了在 JAX 中如何通过使用 `jax.tree_util.register_pytree_node()` 和 `jax.tree.map()` 扩展将被视为 pytree 内部节点(pytree 节点)的 Python 类型集合。 你为什么需要这个?在前面的示例中,pytrees 被展示为列表、元组和字典,其他所有内容都作为 pytree 叶子。这是因为如果你定义了自己的容器类,它会被视为 pytree 叶子,除非你*注册*它到 JAX。即使你的容器类内部包含树形结构,这个情况也是一样的。例如: ```py class Special(object): def __init__(self, x, y): self.x = x self.y = y jax.tree.leaves([ Special(0, 1), Special(2, 4), ])
[<__main__.Special at 0x7f3d005a23e0>, <__main__.Special at 0x7f3d005a1960>]
因此,如果你尝试使用 jax.tree.map()
来期望容器内的元素作为叶子,你会得到一个错误:
jax.tree.map(lambda x: x + 1, [ Special(0, 1), Special(2, 4) ])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'
作为解决方案,JAX 允许通过全局类型注册表扩展被视为内部 pytree 节点的类型集合。此外,已注册类型的值被递归地遍历。
首先,使用 jax.tree_util.register_pytree_node()
注册一个新类型:
from jax.tree_util import register_pytree_node class RegisteredSpecial(Special): def __repr__(self): return "RegisteredSpecial(x={}, y={})".format(self.x, self.y) def special_flatten(v): """Specifies a flattening recipe. Params: v: The value of the registered type to flatten. Returns: A pair of an iterable with the children to be flattened recursively, and some opaque auxiliary data to pass back to the unflattening recipe. The auxiliary data is stored in the treedef for use during unflattening. The auxiliary data could be used, for example, for dictionary keys. """ children = (v.x, v.y) aux_data = None return (children, aux_data) def special_unflatten(aux_data, children): """Specifies an unflattening recipe. Params: aux_data: The opaque data that was specified during flattening of the current tree definition. children: The unflattened children Returns: A reconstructed object of the registered type, using the specified children and auxiliary data. """ return RegisteredSpecial(*children) # Global registration register_pytree_node( RegisteredSpecial, special_flatten, # Instruct JAX what are the children nodes. special_unflatten # Instruct JAX how to pack back into a `RegisteredSpecial`. )
现在你可以遍历特殊容器结构:
jax.tree.map(lambda x: x + 1, [ RegisteredSpecial(0, 1), RegisteredSpecial(2, 4), ])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]
现代 Python 配备了有助于更轻松定义容器的有用工具。一些工具可以直接与 JAX 兼容,但其他的则需要更多关注。
例如,Python 中的 NamedTuple
子类不需要注册即可被视为 pytree 节点类型:
from typing import NamedTuple, Any class MyOtherContainer(NamedTuple): name: str a: Any b: Any c: Any # NamedTuple subclasses are handled as pytree nodes, so # this will work out-of-the-box. jax.tree.leaves([ MyOtherContainer('Alice', 1, 2, 3), MyOtherContainer('Bob', 4, 5, 6) ])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]
注意,现在 name
字段出现为一个叶子,因为所有元组元素都是子元素。这就是当你不必费力注册类时会发生的情况。 ## Pytrees 和 JAX 变换
许多 JAX 函数,比如 jax.lax.scan()
,操作的是数组的 pytrees。此外,所有 JAX 函数变换都可以应用于接受作为输入和输出的数组 pytrees 的函数。
一些 JAX 函数变换接受可选参数,指定如何处理某些输入或输出值(例如 in_axes
和 out_axes
参数给 jax.vmap()
)。这些参数也可以是 pytrees,它们的结构必须对应于相应参数的 pytree 结构。特别是为了能够将这些参数 pytree 中的叶子与参数 pytree 中的值匹配起来,“匹配”参数 pytrees 的叶子与参数 pytrees 的值,这些参数 pytrees 通常受到一定限制。
例如,如果你将以下输入传递给 jax.vmap()
(请注意,函数的输入参数被视为一个元组):
vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))
然后,你可以使用以下 in_axes
pytree 来指定仅映射 k2
参数(axis=0
),其余不进行映射(axis=None
):
vmap(f, in_axes=(None, {"k1": None, "k2": 0}))
可选参数 pytree 结构必须匹配主输入 pytree 的结构。但是,可选参数可以选择作为“前缀” pytree 指定,这意味着一个单独的叶值可以应用于整个子 pytree。
例如,如果你有与上述相同的 jax.vmap()
输入,但希望仅对字典参数进行映射,你可以使用:
vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0})
或者,如果希望每个参数都被映射,可以编写一个应用于整个参数元组 pytree 的单个叶值:
vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0})
这恰好是jax.vmap()
的默认in_axes
值。
对于转换函数的特定输入或输出值的其他可选参数,例如jax.vmap()
中的out_axes
,相同的逻辑也适用于其他可选参数。 ## 显式键路径
在 pytree 中,每个叶子都有一个键路径。叶的键路径是一个键的list
,列表的长度等于叶在 pytree 中的深度。每个键是一个hashable 对象,表示对应的 pytree 节点类型中的索引。键的类型取决于 pytree 节点类型;例如,对于dict
,键的类型与tuple
的键的类型不同。
对于任何 pytree 节点实例的内置 pytree 节点类型,其键集是唯一的。对于具有此属性的节点组成的 pytree,每个叶的键路径都是唯一的。
JAX 提供了以下用于处理键路径的jax.tree_util.*
方法:
jax.tree_util.tree_flatten_with_path()
: 类似于jax.tree.flatten()
,但返回键路径。jax.tree_util.tree_map_with_path()
: 类似于jax.tree.map()
,但函数还接受键路径作为参数。jax.tree_util.keystr()
: 给定一个通用键路径,返回一个友好的读取器字符串表达式。
例如,一个用例是打印与某个叶值相关的调试信息:
import collections ATuple = collections.namedtuple("ATuple", ('name')) tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')] flattened, _ = jax.tree_util.tree_flatten_with_path(tree) for key_path, value in flattened: print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1 Value of tree[1]['k1']: 2 Value of tree[1]['k2'][0]: 3 Value of tree[1]['k2'][1]: 4 Value of tree[2].name: foo
要表达键路径,JAX 提供了几种内置 pytree 节点类型的默认键类型,即:
SequenceKey(idx: int)
: 适用于列表和元组。DictKey(key: Hashable)
: 用于字典。GetAttrKey(name: str)
: 适用于namedtuple
和最好是自定义的 pytree 节点(更多见下一节)
您可以自由地为自定义节点定义自己的键类型。只要它们的__str__()
方法也被覆盖为友好的表达式,它们将与jax.tree_util.keystr()
一起使用。
for key_path, _ in flattened: print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),) Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1')) Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0)) Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1)) Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name')) ```## 常见的 pytree 陷阱 本节介绍了在使用 JAX pytrees 时遇到的一些常见问题(“陷阱”)。 ### 将 pytree 节点误认为叶子 一个常见的需要注意的问题是意外引入*树节点*而不是*叶子节点*: ```py a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))] # Try to make another pytree with ones instead of zeros. shapes = jax.tree.map(lambda x: x.shape, a_tree) jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)), (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
这里发生的是数组的shape
是一个元组,它是一个 pytree 节点,其元素是叶子节点。因此,在映射中,不是在例如(2, 3)
上调用jnp.ones
,而是在2
和3
上调用。
解决方案将取决于具体情况,但有两种广泛适用的选项:
- 重写代码以避免中间
jax.tree.map()
。 - 将元组转换为 NumPy 数组(
np.array
)或 JAX NumPy 数组(jnp.array
),这样整个序列就成为一个叶子。
jax.tree_util
对None
的处理
jax.tree_util
函数将None
视为不存在的 pytree 节点,而不是叶子:
jax.tree.leaves([None, None, None])
[]
要将None
视为叶子,可以使用is_leaf
参数:
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]
自定义 pytrees 和使用意外值进行初始化
另一个与用户定义的 pytree 对象常见的陷阱是,JAX 变换偶尔会使用意外值来初始化它们,因此在初始化时执行的任何输入验证可能会失败。例如:
class MyTree: def __init__(self, a): self.a = jnp.asarray(a) register_pytree_node(MyTree, lambda tree: ((tree.a,), None), lambda _, args: MyTree(*args)) tree = MyTree(jnp.arange(5.0)) jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`.
TypeError: Cannot interpret '<object object at 0x7f3cce5742a0>' as a data type The above exception was the direct cause of the following exception: TypeError: Cannot determine dtype of <object object at 0x7f3cce5742a0> During handling of the above exception, another exception occurred: TypeError: Value '<object object at 0x7f3cce5742a0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`.
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:3289: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error. return array(a, dtype=dtype, copy=bool(copy), order=order)
TypeError: Cannot interpret '<object object at 0x7f3cce574780>' as a data type The above exception was the direct cause of the following exception: TypeError: Cannot determine dtype of <object object at 0x7f3cce574780> During handling of the above exception, another exception occurred: TypeError: Value '<object object at 0x7f3cce574780>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
- 在第一种情况下,使用
jax.vmap(...)(tree)
,JAX 的内部使用object()
值的数组来推断树的结构。 - 在第二种情况下,使用
jax.jacobian(...)(tree)
,将一个将树映射到树的函数的雅可比矩阵定义为树的树。
潜在解决方案 1:
- 自定义 pytree 类的
__init__
和__new__
方法通常应避免执行任何数组转换或其他输入验证,或者预期并处理这些特殊情况。例如:
class MyTree: def __init__(self, a): if not (type(a) is object or a is None or isinstance(a, MyTree)): a = jnp.asarray(a) self.a = a
潜在解决方案 2:
- 结构化您的自定义
tree_unflatten
函数,以避免调用__init__
。如果选择这条路线,请确保您的tree_unflatten
函数在代码更新时与__init__
保持同步。例如:
def tree_unflatten(aux_data, children): del aux_data # Unused in this class. obj = object.__new__(MyTree) obj.a = a return obj ```## 常见 pytree 模式 本节涵盖了 JAX pytrees 中一些最常见的模式。 ### 使用 `jax.tree.map` 和 `jax.tree.transpose` 对 pytree 进行转置 要对 pytree 进行转置(将树的列表转换为列表的树),JAX 提供了两个函数:{func} `jax.tree.map`(更基础)和 `jax.tree.transpose()`(更灵活、复杂且冗长)。 **选项 1:** 使用 `jax.tree.map()`。这里是一个例子: ```py def tree_transpose(list_of_trees): """ Converts a list of trees of identical structure into a single tree of lists. """ return jax.tree.map(lambda *xs: list(xs), *list_of_trees) # Convert a dataset from row-major to column-major. episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)] tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}
选项 2: 对于更复杂的转置,使用 jax.tree.transpose()
,它更冗长,但允许您指定更灵活的内部和外部 pytree 结构。例如:
jax.tree.transpose( outer_treedef = jax.tree.structure([0 for e in episode_steps]), inner_treedef = jax.tree.structure(episode_steps[0]), pytree_to_transpose = episode_steps )
{'obs': [3, 4], 't': [1, 2]}
JAX 中文文档(二)(5)https://developer.aliyun.com/article/1559672