JAX 中文文档(十四)(3)https://developer.aliyun.com/article/1559758
jax.sharding 模块
类
class jax.sharding.Sharding
描述了jax.Array
如何跨设备布局。
property addressable_devices: set[Device]
Sharding
中由当前进程可寻址的设备集合。
addressable_devices_indices_map(global_shape
从可寻址设备到它们包含的数组数据切片的映射。
addressable_devices_indices_map
包含适用于可寻址设备的device_indices_map
部分。
参数:
global_shape (tuple[int, …**])
返回类型:
Mapping[Device, tuple[slice, …] | None]
property device_set: set[Device]
这个Sharding
跨越的设备集合。
在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。
devices_indices_map(global_shape)
返回从设备到它们包含的数组切片的映射。
映射包括所有全局设备,即包括来自其他进程的不可寻址设备。
参数:
global_shape (tuple[int, …**])
返回类型:
Mapping[Device, tuple[slice, …]]
is_equivalent_to(other, ndim)
如果两个分片等效,则返回True
。
如果它们在相同设备上放置了相同的逻辑数组分片,则两个分片是等效的。
例如,如果NamedSharding
和PositionalSharding
都将数组的相同分片放置在相同的设备上,则它们可能是等效的。
参数:
- self (Sharding)
- other (Sharding)
- ndim (int)
返回类型:
property is_fully_addressable: bool
此分片是否是完全可寻址的?
如果当前进程能够寻址Sharding
中列出的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable
等效于 “is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则分片是完全复制的。
property memory_kind: str | None
返回分片的内存类型。
shard_shape(global_shape)
返回每个设备上数据的形状。
此函数返回的分片形状是从global_shape
和分片属性计算得出的。
参数:
global_shape (tuple[int, …**])
返回类型:
with_memory_kind(kind)
返回具有指定内存类型的新分片实例。
参数:
kind (str)
返回类型:
分片
class jax.sharding.SingleDeviceSharding
基类:分片
一个将其数据放置在单个设备上的分片
。
参数:
device – 单个设备
。
示例
>>> single_device_sharding = jax.sharding.SingleDeviceSharding( ... jax.devices()[0])
property device_set: set[Device]
此分片
跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的非可寻址设备。
devices_indices_map(global_shape)
返回从设备到每个包含的数组片段的映射。
映射包括所有全局设备,即包括来自其他进程的非可寻址设备。
参数:
global_shape (tuple[int, …**])
返回类型:
property is_fully_addressable: bool
此分片是否完全可寻址?
如果当前进程可以寻址分片
中命名的所有设备,则称分片完全可寻址。is_fully_addressable
在多进程 JAX 中等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则分片完全复制。
property memory_kind: str | None
返回分片的内存类型。
with_memory_kind(kind)
返回具有指定内存类型的新分片实例。
参数:
kind (str)
返回类型:
单设备分片
class jax.sharding.NamedSharding
基类:分片
一个NamedSharding
使用命名轴来表示分片。
一个NamedSharding
是设备Mesh
和描述如何跨该网格对数组进行分片的PartitionSpec
的组合。
一个Mesh
是 JAX 设备的多维 NumPy 数组,其中网格的每个轴都有一个名称,例如 'x'
或 'y'
。
一个PartitionSpec
是一个元组,其元素可以是None
、一个网格轴或一组网格轴的元组。每个元素描述如何在零个或多个网格维度上对输入维度进行分区。例如,PartitionSpec('x', 'y')
表示数据的第一维在网格的 x
轴上进行分片,第二维在网格的 y
轴上进行分片。
分布式数组和自动并行化(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names
)教程详细讲解了如何使用Mesh
和PartitionSpec
,包括更多细节和图示。
参数:
- mesh – 一个
jax.sharding.Mesh
对象。 - spec – 一个
jax.sharding.PartitionSpec
对象。
示例
>>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> spec = P('x', 'y') >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
property addressable_devices: set[Device]
当前进程可以访问的Sharding
中的设备集。
property device_set: set[Device]
该Sharding
跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
property is_fully_addressable: bool
此分片是否完全可寻址?
一个分片如果当前进程可以访问Sharding
中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable
等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则称分片为完全复制。
property memory_kind: str | None
返回分片的内存类型。
property mesh
(self) -> object
property spec
(self) -> object
with_memory_kind(kind)
返回具有指定内存类型的新Sharding
实例。
参数:
kind (str)
返回类型:
NamedSharding
class jax.sharding.PositionalSharding(devices, *, memory_kind=None)
基类:Sharding
参数:
- devices (Sequence*[xc.Device]* | np.ndarray)
- memory_kind (str | None)
property device_set: set[Device]
该Sharding
跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
property is_fully_addressable: bool
此分片是否完全可寻址?
一个分片如果当前进程可以访问Sharding
中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable
等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则称分片为完全复制。
property memory_kind: str | None
返回分片的内存类型。
with_memory_kind(kind)
返回具有指定内存类型的新Sharding
实例。
参数:
kind (str)
返回类型:
PositionalSharding
class jax.sharding.PmapSharding
基类:Sharding
描述了jax.pmap()
使用的分片。
classmethod default(shape, sharded_dim=0, devices=None)
创建一个PmapSharding
,与jax.pmap()
使用的默认放置方式匹配。
参数:
- shape (tuple[int, …**]) – 输入数组的形状。
- sharded_dim (int") – 输入数组进行分片的维度。默认为 0。
- devices(Sequence[Device] | None) – 可选的设备序列。如果省略,隐含的
- used(pmap 使用的设备顺序是) –
jax.local_devices()
。 - of(这是顺序) –
jax.local_devices()
。
返回类型:
PmapSharding
property device_set: set[Device]
这个Sharding
跨越的设备集合。
在多控制器 JAX 中,设备集合是全局的,即包括其他进程的非可寻址设备。
property devices
(self)-> ndarray
devices_indices_map(global_shape)
返回设备到每个包含的数组切片的映射。
映射包括所有全局设备,即包括其他进程的非可寻址设备。
参数:
返回类型:
is_equivalent_to(other, ndim)
如果两个分片等效,则返回True
。
如果它们将相同的逻辑数组分片放置在相同的设备上,则两个分片是等效的。
例如,如果NamedSharding
和PositionalSharding
将数组的相同分片放置在相同的设备上,则它们可能是等效的。
参数:
- self(PmapSharding)
- other(PmapSharding)
- ndim(int)
返回类型:
布尔(“in Python v3.12”)
property is_fully_addressable: bool
这个分片是否完全可寻址?
如果当前进程能够处理Sharding
中命名的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable
相当于“is_local”。
property is_fully_replicated: bool
这个分片是否完全复制?
如果每个设备都有完整数据的副本,则分片是完全复制的。
property memory_kind: str | None
返回分片的内存类型。
shard_shape(global_shape)
返回每个设备上数据的形状。
此函数返回的分片形状是从global_shape
和分片属性计算而来的。
参数:
返回类型:
property sharding_spec
(self)-> jax::ShardingSpec
with_memory_kind(kind)
返回具有指定内存类型的新 Sharding 实例。
参数:
kind(str)
class jax.sharding.GSPMDSharding
基类:Sharding
property device_set: set[Device]
这个Sharding
跨越的设备集合。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
property is_fully_addressable: bool
此分片是否完全可寻址?
如果当前进程可以访问Sharding
中命名的所有设备,则分片是完全可寻址的。is_fully_addressable
相当于多进程 JAX 中的“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
一个分片是完全复制的,如果每个设备都有整个数据的完整副本。
property memory_kind: str | None
返回分片的内存类型。
with_memory_kind(kind)
返回具有指定内存类型的新 Sharding 实例。
参数:
kind(str)
返回类型:
GSPMDSharding
class jax.sharding.PartitionSpec(*partitions)
元组描述如何在设备网格上对数组进行分区。
每个元素都可以是None
、字符串或字符串元组。有关更多详细信息,请参阅jax.sharding.NamedSharding
的文档。
此类存在,以便 JAX 的 pytree 实用程序可以区分分区规范和应视为 pytrees 的元组。
class jax.sharding.Mesh(devices, axis_names)
声明在此管理器范围内可用的硬件资源。
特别是,所有axis_names
在管理块内都变成有效的资源名称,并且可以在jax.experimental.pjit.pjit()
的in_axis_resources
参数中使用,还请参阅 JAX 的多进程编程模型(jax.readthedocs.io/en/latest/multi_process.html
)和分布式数组与自动并行化教程(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
)
如果您在多线程中编译,请确保with Mesh
上下文管理器位于线程将执行的函数内部。
参数:
- devices(ndarray) - 包含 JAX 设备对象(例如从
jax.devices()
获得的对象)的 NumPy ndarray 对象。 - axis_names(tuple[Any, …**]) - 资源轴名称序列,用于分配给
devices
参数的维度。其长度应与devices
的秩匹配。
示例
>>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... >>> inp = np.arange(16).reshape((8, 2)) >>> devices = np.array(jax.devices()).reshape(4, 2) ... >>> # Declare a 2D mesh with axes `x` and `y`. >>> global_mesh = Mesh(devices, ('x', 'y')) >>> # Use the mesh object directly as a context manager. >>> with global_mesh: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager. >>> with Mesh(devices, ('x', 'y')) as global_mesh: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`. >>> global_mesh = Mesh(devices, ('x', 'y')) >>> with global_mesh as m: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`. >>> with Mesh(devices, ('x', 'y')): ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
JAX 中文文档(十四)(5)https://developer.aliyun.com/article/1559760