JAX 中文文档(六)(2)

简介: JAX 中文文档(六)

JAX 中文文档(六)(1)https://developer.aliyun.com/article/1559681


贝叶斯推断的自动批处理

原文:jax.readthedocs.io/en/latest/notebooks/vmapped_log_probs.html

[外链图片转存中…(img-Fd5vUVOI-1718950514656)]

本笔记演示了一个简单的贝叶斯推断示例,其中自动批处理使用户代码更易于编写、更易于阅读,减少了错误的可能性。

灵感来自@davmre 的一个笔记本。

import functools
import itertools
import re
import sys
import time
from matplotlib.pyplot import *
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
import numpy as np
import scipy as sp 

生成一个虚拟的二分类数据集

np.random.seed(10009)
num_features = 10
num_points = 100
true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32) 
y 
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32) 

编写模型的对数联合函数

我们将编写一个非批处理版本、一个手动批处理版本和一个自动批处理版本。

非批量化

def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result 
log_joint(np.random.randn(num_features)) 
Array(-213.2356, dtype=float32) 
# This doesn't work, because we didn't write `log_prob()` to handle batching.
try:
  batch_size = 10
  batched_test_beta = np.random.randn(batch_size, num_features)
  log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
  print("Caught expected exception " + str(e)) 
Caught expected exception Incompatible shapes for broadcasting: shapes=[(100,), (100, 10)] 

手动批处理

def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                           axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),
                           axis=-1)
    return result 
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)
batched_log_joint(batched_test_beta) 
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291  ,
       -143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ],      dtype=float32) 

使用 vmap 进行自动批处理

它只是有效地工作。

vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta) 
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291  ,
       -143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ],      dtype=float32) 

自包含的变分推断示例

从上面复制了一小段代码。

设置(批量化的)对数联合函数

@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result
batched_log_joint = jax.jit(jax.vmap(log_joint)) 

定义 ELBO 及其梯度

def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1))) 

使用 SGD 优化 ELBO

def normal_sample(key, shape):
  """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)
normal_sample = jax.jit(normal_sample, static_argnums=(1,))
key = random.key(10003)
beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)
step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val)) 
0 -180.8538818359375
10  -113.06045532226562
20  -102.73727416992188
30  -99.787353515625
40  -98.90898132324219
50  -98.29745483398438
60  -98.18632507324219
70  -97.57972717285156
80  -97.28599548339844
90  -97.46996307373047
100 -97.4771728515625
110 -97.5806655883789
120 -97.4943618774414
130 -97.50271606445312
140 -96.86396026611328
150 -97.44197845458984
160 -97.06941223144531
170 -96.84028625488281
180 -97.21336364746094
190 -97.56503295898438
200 -97.26397705078125
210 -97.11979675292969
220 -97.39595031738281
230 -97.16831970214844
240 -97.118408203125
250 -97.24345397949219
260 -97.29788970947266
270 -96.69286346435547
280 -96.96438598632812
290 -97.30055236816406
300 -96.63591766357422
310 -97.0351791381836
320 -97.52909088134766
330 -97.28811645507812
340 -97.07321166992188
350 -97.15619659423828
360 -97.25881958007812
370 -97.19515228271484
380 -97.13092041015625
390 -97.11726379394531
400 -96.938720703125
410 -97.26676940917969
420 -97.35322570800781
430 -97.21007537841797
440 -97.28434753417969
450 -97.1630859375
460 -97.2612533569336
470 -97.21343994140625
480 -97.23997497558594
490 -97.14913940429688
500 -97.23527526855469
510 -96.93419647216797
520 -97.21209716796875
530 -96.82575988769531
540 -97.01284790039062
550 -96.94175720214844
560 -97.16520690917969
570 -97.29165649414062
580 -97.42941284179688
590 -97.24370574951172
600 -97.15222930908203
610 -97.49844360351562
620 -96.9906997680664
630 -96.88956451416016
640 -96.89968872070312
650 -97.13793182373047
660 -97.43705749511719
670 -96.99235534667969
680 -97.15623474121094
690 -97.1869125366211
700 -97.11160278320312
710 -97.78105163574219
720 -97.23226165771484
730 -97.16206359863281
740 -96.99581909179688
750 -96.6672134399414
760 -97.16795349121094
770 -97.51435089111328
780 -97.28900146484375
790 -96.91226196289062
800 -97.17100524902344
810 -97.29047393798828
820 -97.16242980957031
830 -97.19107055664062
840 -97.56382751464844
850 -97.00194549560547
860 -96.86555480957031
870 -96.76338195800781
880 -96.83660888671875
890 -97.12178039550781
900 -97.09554290771484
910 -97.0682373046875
920 -97.11947631835938
930 -96.87930297851562
940 -97.45624542236328
950 -96.69279479980469
960 -97.29376220703125
970 -97.3353042602539
980 -97.34962463378906
990 -97.09675598144531 

显示结果

虽然覆盖率不及理想,但也不错,而且没有人说变分推断是精确的。

figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best') 
<matplotlib.legend.Legend at 0x7f6a2c3c86a0> 

在多主机和多进程环境中使用 JAX

原文:jax.readthedocs.io/en/latest/multi_process.html

介绍

本指南解释了如何在 GPU 集群和Cloud TPU pod 等环境中使用 JAX,在这些环境中,加速器分布在多个 CPU 主机或 JAX 进程上。我们将这些称为“多进程”环境。

本指南专门介绍了如何在多进程设置中使用集体通信操作(例如 jax.lax.psum() ),尽管根据您的用例,其他通信方法也可能有用(例如 RPC,mpi4jax)。如果您尚未熟悉 JAX 的集体操作,建议从分片计算部分开始。在 JAX 的多进程环境中,重要的要求是加速器之间的直接通信链路,例如 Cloud TPU 的高速互连或NCCL 用于 GPU。这些链路允许集体操作在多个进程的加速器上高性能运行。

多进程编程模型

关键概念:

  • 您必须在每个主机上至少运行一个 JAX 进程。
  • 您应该使用 jax.distributed.initialize() 初始化集群。
  • 每个进程都有一组独特的本地设备可以访问。全局设备是所有进程的所有设备集合。
  • 使用标准的 JAX 并行 API,如 jit()(参见分片计算入门教程)和 shard_map()。jax.jit 仅接受全局形状的数组。shard_map 允许您按设备形状进行降级。
  • 确保所有进程按照相同顺序运行相同的并行计算。
  • 确保所有进程具有相同数量的本地设备。
  • 确保所有设备相同(例如,全部为 V100 或全部为 H100)。

启动 JAX 进程

与其他分布式系统不同,其中单个控制节点管理多个工作节点,JAX 使用“多控制器”编程模型,其中每个 JAX Python  进程独立运行,有时称为单程序多数据(SPMD)模型。通常,在每个进程中运行相同的 JAX Python  程序,每个进程的执行之间只有轻微差异(例如,不同的进程将加载不同的输入数据)。此外,您必须手动在每个主机上运行您的 JAX 程序! JAX 不会从单个程序调用自动启动多个进程。

(对于多个进程的要求,这就是为什么本指南不作为笔记本提供的原因——我们目前没有好的方法来从单个笔记本管理多个 Python 进程。)

初始化集群

要初始化集群,您应该在每个进程的开始调用 jax.distributed.initialize()jax.distributed.initialize() 必须在程序中的任何 JAX 计算执行之前早些时候调用。

API jax.distributed.initialize() 接受几个参数,即:

  • coordinator_address:集群中进程 0 的 IP 地址,以及该进程上可用的一个端口。进程 0 将启动一个通过该 IP 地址和端口暴露的 JAX 服务,集群中的其他进程将连接到该服务。
  • coordinator_bind_address:集群中进程 0 上的 JAX 服务将绑定到的 IP 地址和端口。默认情况下,它将使用与 coordinator_address 相同端口的所有可用接口进行绑定。
  • num_processes:集群中的进程数
  • process_id:本进程的 ID 号码,范围为[0 .. num_processes)
  • local_device_ids:将当前进程的可见设备限制为 local_device_ids

例如,在 GPU 上,典型用法如下:

import jax
jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                           num_processes=2,
                           process_id=0) 

在 Cloud TPU、Slurm 和 Open MPI 环境中,你可以简单地调用 jax.distributed.initialize() 而无需参数。参数的默认值将自动选择。在使用 Slurm 和 Open MPI 运行 GPU 时,假定每个 GPU 启动一个进程,即每个进程只分配一个可见本地设备。否则假定每个主机启动一个进程,即每个进程将分配所有本地设备。只有当通过 mpirun/mpiexec 启动 JAX 进程时才会使用 Open MPI 自动初始化。

import jax
jax.distributed.initialize() 

在当前 TPU 上,调用 jax.distributed.initialize() 目前是可选的,但建议使用,因为它启用了额外的检查点和健康检查功能。

本地与全局设备

在开始从您的程序中运行多进程计算之前,了解本地全局设备之间的区别是很重要的。

进程的本地设备是它可以直接寻址和启动计算的设备。 例如,在 GPU 集群上,每个主机只能在直接连接的 GPU 上启动计算。在 Cloud TPU pod 上,每个主机只能在直接连接到该主机的 8 个 TPU 核心上启动计算(有关更多详情,请参阅Cloud TPU 系统架构文档)。你可以通过 jax.local_devices() 查看进程的本地设备。

全局设备是跨所有进程的设备。 一个计算可以跨进程的设备并通过设备之间的直接通信链路执行集体操作,只要每个进程在其本地设备上启动计算即可。你可以通过 jax.devices() 查看所有可用的全局设备。一个进程的本地设备总是全局设备的一个子集。

运行多进程计算

那么,你到底如何运行涉及跨进程通信的计算呢? 使用与单进程中相同的并行评估 API!

例如,shard_map() 可以用于在多个进程间并行计算。(如果您还不熟悉如何使用 shard_map 在单个进程内的多个设备上运行,请参阅分片计算介绍教程。)从概念上讲,这可以被视为在跨主机分片的单个数组上运行 pmap,其中每个主机只“看到”其本地分片的输入和输出。

下面是多进程 pmap 的实际示例:

# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
>>> jax.device_count()  # total number of accelerator devices in the cluster
32
>>> jax.local_device_count()  # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32) 

非常重要的是,所有进程以相同的跨进程计算顺序运行。 在每个进程中运行相同的 JAX Python 程序通常就足够了。尽管运行相同程序,但仍需注意可能导致不同顺序计算的一些常见陷阱:

  • 将不同形状的输入传递给同一并行函数的进程可能导致挂起或不正确的返回值。只要它们在进程间产生相同形状的每设备数据分片,不同形状的输入是安全的;例如,传递不同的前导批次大小以在不同的本地设备数上运行是可以的,但是每个进程根据不同的最大示例长度填充其批次是不行的。
  • “最后一批”问题发生在并行函数在(训练)循环中调用时,其中一个或多个进程比其余进程更早退出循环。这将导致其余进程挂起,等待已经完成的进程开始计算。
  • 基于集合的非确定性顺序的条件可能导致代码进程挂起。例如,在当前 Python 版本上遍历 set 或者 Python 3.7 之前的 dict 可能会导致不同进程的顺序不同,即使插入顺序相同也是如此


JAX 中文文档(六)(3)https://developer.aliyun.com/article/1559683

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
3月前
|
存储 安全 API
JAX 中文文档(十)(2)
JAX 中文文档(十)
33 0
|
3月前
|
机器学习/深度学习 测试技术 索引
JAX 中文文档(二)(4)
JAX 中文文档(二)
38 0
|
3月前
|
编译器 API C++
JAX 中文文档(三)(3)
JAX 中文文档(三)
22 0
|
3月前
|
存储 缓存 API
JAX 中文文档(五)(1)
JAX 中文文档(五)
26 0
|
3月前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
26 0
|
3月前
|
编译器 异构计算 索引
JAX 中文文档(五)(4)
JAX 中文文档(五)
53 0
|
3月前
|
编译器 测试技术 API
JAX 中文文档(四)(4)
JAX 中文文档(四)
27 0
|
3月前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
27 0
|
3月前
|
机器学习/深度学习 并行计算 安全
JAX 中文文档(七)(1)
JAX 中文文档(七)
34 0
|
3月前
|
并行计算 编译器
JAX 中文文档(六)(4)
JAX 中文文档(六)
19 0