JAX 中文文档(十四)(4)

简介: JAX 中文文档(十四)

JAX 中文文档(十四)(3)https://developer.aliyun.com/article/1559758

jax.sharding 模块

原文:jax.readthedocs.io/en/latest/jax.sharding.html

class jax.sharding.Sharding

描述了jax.Array如何跨设备布局。

property addressable_devices: set[Device]

Sharding中由当前进程可寻址的设备集合。

addressable_devices_indices_map(global_shape

从可寻址设备到它们包含的数组数据切片的映射。

addressable_devices_indices_map 包含适用于可寻址设备的device_indices_map部分。

参数:

global_shape (tuple[int, …**])

返回类型:

Mapping[Device, tuple[slice, …] | None]

property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

devices_indices_map(global_shape)

返回从设备到它们包含的数组切片的映射。

映射包括所有全局设备,即包括来自其他进程的不可寻址设备。

参数:

global_shape (tuple[int, …**])

返回类型:

Mapping[Device, tuple[slice, …]]

is_equivalent_to(other, ndim)

如果两个分片等效,则返回True

如果它们在相同设备上放置了相同的逻辑数组分片,则两个分片是等效的。

例如,如果NamedShardingPositionalSharding都将数组的相同分片放置在相同的设备上,则它们可能是等效的。

参数:

  • self (Sharding)
  • other (Sharding)
  • ndim (int)

返回类型:

bool

property is_fully_addressable: bool

此分片是否是完全可寻址的?

如果当前进程能够寻址Sharding中列出的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable 等效于 “is_local”。

property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则分片是完全复制的。

property memory_kind: str | None

返回分片的内存类型。

shard_shape(global_shape)

返回每个设备上数据的形状。

此函数返回的分片形状是从global_shape和分片属性计算得出的。

参数:

global_shape (tuple[int, …**])

返回类型:

tuple[int, …]

with_memory_kind(kind)

返回具有指定内存类型的新分片实例。

参数:

kind (str)

返回类型:

分片

class jax.sharding.SingleDeviceSharding

基类:分片

一个将其数据放置在单个设备上的分片

参数:

device – 单个设备

示例

>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0]) 
property device_set: set[Device]

分片跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的非可寻址设备。

devices_indices_map(global_shape)

返回从设备到每个包含的数组片段的映射。

映射包括所有全局设备,即包括来自其他进程的非可寻址设备。

参数:

global_shape (tuple[int, …**])

返回类型:

映射[设备, tuple[slice, …]]

property is_fully_addressable: bool

此分片是否完全可寻址?

如果当前进程可以寻址分片中命名的所有设备,则称分片完全可寻址。is_fully_addressable在多进程 JAX 中等同于“is_local”。

property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则分片完全复制。

property memory_kind: str | None

返回分片的内存类型。

with_memory_kind(kind)

返回具有指定内存类型的新分片实例。

参数:

kind (str)

返回类型:

单设备分片

class jax.sharding.NamedSharding

基类:分片

一个NamedSharding使用命名轴来表示分片。

一个NamedSharding是设备Mesh和描述如何跨该网格对数组进行分片的PartitionSpec的组合。

一个Mesh是 JAX 设备的多维 NumPy 数组,其中网格的每个轴都有一个名称,例如 'x''y'

一个PartitionSpec是一个元组,其元素可以是None、一个网格轴或一组网格轴的元组。每个元素描述如何在零个或多个网格维度上对输入维度进行分区。例如,PartitionSpec('x', 'y')表示数据的第一维在网格的 x 轴上进行分片,第二维在网格的 y 轴上进行分片。

分布式数组和自动并行化(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)教程详细讲解了如何使用MeshPartitionSpec,包括更多细节和图示。

参数:

  • mesh – 一个jax.sharding.Mesh对象。
  • spec – 一个 jax.sharding.PartitionSpec 对象。

示例

>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec) 
property addressable_devices: set[Device]

当前进程可以访问的Sharding中的设备集。

property device_set: set[Device]

Sharding跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

property is_fully_addressable: bool

此分片是否完全可寻址?

一个分片如果当前进程可以访问Sharding中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable等同于“is_local”。

property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则称分片为完全复制。

property memory_kind: str | None

返回分片的内存类型。

property mesh

(self) -> object

property spec

(self) -> object

with_memory_kind(kind)

返回具有指定内存类型的新Sharding实例。

参数:

kind (str)

返回类型:

NamedSharding

class jax.sharding.PositionalSharding(devices, *, memory_kind=None)

基类:Sharding

参数:

  • devices (Sequence*[xc.Device]* | np.ndarray)
  • memory_kind (str | None)
property device_set: set[Device]

Sharding跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

property is_fully_addressable: bool

此分片是否完全可寻址?

一个分片如果当前进程可以访问Sharding中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable等同于“is_local”。

property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则称分片为完全复制。

property memory_kind: str | None

返回分片的内存类型。

with_memory_kind(kind)

返回具有指定内存类型的新Sharding实例。

参数:

kind (str)

返回类型:

PositionalSharding

class jax.sharding.PmapSharding

基类:Sharding

描述了jax.pmap()使用的分片。

classmethod default(shape, sharded_dim=0, devices=None)

创建一个PmapSharding,与jax.pmap()使用的默认放置方式匹配。

参数:

  • shape (tuple[int, …**]) – 输入数组的形状。
  • sharded_dim (int") – 输入数组进行分片的维度。默认为 0。
  • devicesSequence[Device] | None) – 可选的设备序列。如果省略,隐含的
  • usedpmap 使用的设备顺序是) – jax.local_devices()
  • of这是顺序) – jax.local_devices()

返回类型:

PmapSharding

property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括其他进程的非可寻址设备。

property devices

(self)-> ndarray

devices_indices_map(global_shape)

返回设备到每个包含的数组切片的映射。

映射包括所有全局设备,即包括其他进程的非可寻址设备。

参数:

global_shape元组[int,…**]

返回类型:

Mapping[Device元组[切片,…]]

is_equivalent_to(other, ndim)

如果两个分片等效,则返回True

如果它们将相同的逻辑数组分片放置在相同的设备上,则两个分片是等效的。

例如,如果NamedShardingPositionalSharding将数组的相同分片放置在相同的设备上,则它们可能是等效的。

参数:

  • selfPmapSharding
  • otherPmapSharding
  • ndimint

返回类型:

布尔(“in Python v3.12”)

property is_fully_addressable: bool

这个分片是否完全可寻址?

如果当前进程能够处理Sharding中命名的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable相当于“is_local”。

property is_fully_replicated: bool

这个分片是否完全复制?

如果每个设备都有完整数据的副本,则分片是完全复制的。

property memory_kind: str | None

返回分片的内存类型。

shard_shape(global_shape)

返回每个设备上数据的形状。

此函数返回的分片形状是从global_shape和分片属性计算而来的。

参数:

global_shape元组[int,…**]

返回类型:

元组[int,…]

property sharding_spec

(self)-> jax::ShardingSpec

with_memory_kind(kind)

返回具有指定内存类型的新 Sharding 实例。

参数:

kindstr

class jax.sharding.GSPMDSharding

基类:Sharding

property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

property is_fully_addressable: bool

此分片是否完全可寻址?

如果当前进程可以访问Sharding中命名的所有设备,则分片是完全可寻址的。is_fully_addressable相当于多进程 JAX 中的“is_local”。

property is_fully_replicated: bool

此分片是否完全复制?

一个分片是完全复制的,如果每个设备都有整个数据的完整副本。

property memory_kind: str | None

返回分片的内存类型。

with_memory_kind(kind)

返回具有指定内存类型的新 Sharding 实例。

参数:

kindstr

返回类型:

GSPMDSharding

class jax.sharding.PartitionSpec(*partitions)

元组描述如何在设备网格上对数组进行分区。

每个元素都可以是None、字符串或字符串元组。有关更多详细信息,请参阅jax.sharding.NamedSharding的文档。

此类存在,以便 JAX 的 pytree 实用程序可以区分分区规范和应视为 pytrees 的元组。

class jax.sharding.Mesh(devices, axis_names)

声明在此管理器范围内可用的硬件资源。

特别是,所有axis_names在管理块内都变成有效的资源名称,并且可以在jax.experimental.pjit.pjit()in_axis_resources参数中使用,还请参阅 JAX 的多进程编程模型(jax.readthedocs.io/en/latest/multi_process.html)和分布式数组与自动并行化教程(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

如果您在多线程中编译,请确保with Mesh上下文管理器位于线程将执行的函数内部。

参数:

  • devicesndarray) - 包含 JAX 设备对象(例如从jax.devices()获得的对象)的 NumPy ndarray 对象。
  • axis_namestuple[Any, …**]) - 资源轴名称序列,用于分配给devices参数的维度。其长度应与devices的秩匹配。

示例

>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 


>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 


JAX 中文文档(十四)(5)https://developer.aliyun.com/article/1559760

相关文章
|
3天前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
14 3
|
3天前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
13 3
|
3天前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
9 2
|
3天前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
7 1
|
3天前
|
并行计算 异构计算 索引
JAX 中文文档(十六)(4)
JAX 中文文档(十六)
12 2
|
3天前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
8 1
|
3天前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
6 1
|
3天前
|
关系型数据库
JAX 中文文档(十四)(1)
JAX 中文文档(十四)
10 0
|
3天前
|
资源调度 算法 安全
JAX 中文文档(十四)(3)
JAX 中文文档(十四)
8 0
|
3天前
|
API 异构计算 索引
JAX 中文文档(十四)(2)
JAX 中文文档(十四)
8 0