JAX Chinese Documentation (Part 6)(1)

Introduction: JAX Chinese Documentation (Part 6)


Original: jax.readthedocs.io/en/latest/

Advanced Tutorials

Original: jax.readthedocs.io/en/latest/advanced_guide.html

This section contains examples and tutorials on more advanced topics, such as multi-core computation, custom operations, and more in-depth applications.

Examples

  • Training a simple neural network, with tensorflow/datasets data loading
  • Training a simple neural network, with PyTorch data loading
  • Autobatching for Bayesian inference

Parallel computation

  • Using JAX in multi-host and multi-process environments
  • Distributed arrays and automatic parallelization
  • SPMD multi-device parallelism with shard_map
  • API specification
  • Collectives tutorial
  • Toy examples
  • Distributed data loading in a multi-host/multi-process environment
  • Named axes and easy-to-revise parallelism with xmap

Automatic differentiation

  • The Autodiff Cookbook
  • Custom derivative rules for JAX-transformable Python functions
  • Control autodiff's saved values with jax.checkpoint (aka jax.remat)

JAX internals

  • How JAX primitives work
  • Writing custom Jaxpr interpreters in JAX
  • Custom operations for GPUs with C++ and CUDA
  • Checking correctness

Deeper dives

  • Generalized convolutions in JAX

Training a simple neural network, with tensorflow/datasets data loading

Original: jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html

Derived from neural_network_and_data_loading.ipynb

Let's combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use the tensorflow/datasets data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library 😛).

Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random 

Hyperparameters

Let's get a few bookkeeping items out of the way.

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0)) 
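
For reference (an added check, not part of the original notebook): init_network_params returns params as a list of (weights, biases) tuples, one per layer, and this pytree structure is what the rest of the tutorial relies on.

# Inspect the parameter pytree (sanity check added here)
# Expected: [((512, 784), (512,)), ((512, 512), (512,)), ((10, 512), (10,))]
print([(w.shape, b.shape) for w, b in params])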

Auto-batching predictions

Let us first define our prediction function. Note that we're defining this for a single image example. We're going to use JAX's vmap function to automatically handle mini-batches, with no performance penalty.

from jax.scipy.special import logsumexp
def relu(x):
  return jnp.maximum(0, x)
def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits) 

Let's check that our prediction function only works on single images.

# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape) 
(10,) 
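
Since predict subtracts logsumexp(logits), it returns log-probabilities. A quick check (added here, not in the original notebook) is that exponentiating the output sums to roughly one:

# preds are log-probabilities, so their exponentials should sum to ~1
print(jnp.exp(preds).sum())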
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!') 
Invalid shapes! 
# Let's upgrade it to handle batches using `vmap`
# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape) 
(10, 10) 
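
Here in_axes=(None, 0) tells vmap to leave params untouched and to map predict over the leading axis of the image batch, so the result matches an explicit per-example loop. A small equivalence check (an addition, not in the original notebook):

# vmap-ed prediction agrees with stacking per-example calls
looped_preds = jnp.stack([predict(params, img) for img in random_flattened_images])
print(jnp.allclose(batched_preds, looped_preds))  # expect True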

At this point we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Finally, we should be able to use jit to speed up everything.

Utility and loss functions

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)
def loss(params, images, targets):
  # Cross-entropy: predict returns log-probabilities and targets are one-hot,
  # so -mean(preds * targets) is a (scaled) negative log-likelihood.
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)
@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)] 
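
As a quick sanity check (an addition, not part of the original notebook): grad(loss) returns gradients with exactly the same list-of-(w, b)-tuples structure as params, so the zip inside update lines up gradients with parameters layer by layer. One step on the random batch from above with made-up labels:

# One update step on dummy data; params itself is left unchanged
dummy_targets = one_hot(jnp.arange(10) % n_targets, n_targets)  # shape (10, 10)
new_params = update(params, random_flattened_images, dummy_targets)
print(len(new_params), new_params[0][0].shape)  # 3 layers, first weight is (512, 784)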

Data loading with tensorflow/datasets

JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the tensorflow/datasets data loader.

import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')
import tensorflow_datasets as tfds
data_dir = '/tmp/tfds'
# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c
# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)
# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels) 
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape) 
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10) 

Training loop

import time
def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)
for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time
  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc)) 
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341 

We've now used most of the JAX API: grad for derivatives, jit for speedups and vmap for auto-vectorization. We used NumPy to specify all of our computation, borrowed the great data loaders from tensorflow/datasets, and ran the whole thing on the GPU.

Training a simple neural network, with PyTorch data loading

Original: jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

Copyright 2018 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Let's combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).

Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random 

Hyperparameters

Let's get a few bookkeeping items out of the way.

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0)) 

Auto-batching predictions

Let us first define our prediction function. Note that we're defining this for a single image example. We're going to use JAX's vmap function to automatically handle mini-batches, with no performance penalty.

from jax.scipy.special import logsumexp
def relu(x):
  return jnp.maximum(0, x)
def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits) 

Let's check that our prediction function only works on single images.

# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape) 
(10,) 
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!') 
Invalid shapes! 
# Let's upgrade it to handle batches using `vmap`
# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape) 
(10, 10) 

At this point we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Finally, we should be able to use jit to speed up everything.

Utility and loss functions

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)
def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)
@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)] 

Data loading with PyTorch

JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays.

!pip install torch torchvision 
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)
Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)
Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)
Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0) 
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST
def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))
class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)
class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32)) 
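
To see what the collate shim produces, here is a tiny standalone check (a hypothetical example, not in the original notebook): numpy_collate stacks a list of (image, label) pairs into NumPy arrays rather than torch Tensors.

# The shim collates a toy batch of two flattened "images" into NumPy arrays
fake_batch = [(np.zeros(784, dtype=np.float32), 3), (np.ones(784, dtype=np.float32), 7)]
xs, ys = numpy_collate(fake_batch)
print(type(xs), xs.shape, ys)  # <class 'numpy.ndarray'> (2, 784) [3 7]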
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0) 
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Processing...
Done! 
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)
# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets) 
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data
  warnings.warn("train_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets
  warnings.warn("train_labels has been renamed targets")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data
  warnings.warn("test_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets
  warnings.warn("test_labels has been renamed targets") 

Training loop

import time
for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time
  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc)) 
Epoch 0 in 55.15 sec
Training set accuracy 0.9157500267028809
Test set accuracy 0.9195000529289246
Epoch 1 in 42.26 sec
Training set accuracy 0.9372166991233826
Test set accuracy 0.9384000301361084
Epoch 2 in 44.37 sec
Training set accuracy 0.9491666555404663
Test set accuracy 0.9469000697135925
Epoch 3 in 41.75 sec
Training set accuracy 0.9568166732788086
Test set accuracy 0.9534000158309937
Epoch 4 in 41.16 sec
Training set accuracy 0.9631333351135254
Test set accuracy 0.9577000737190247
Epoch 5 in 38.89 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9616000652313232
Epoch 6 in 40.68 sec
Training set accuracy 0.9708333611488342
Test set accuracy 0.9650000333786011
Epoch 7 in 41.50 sec
Training set accuracy 0.973716676235199
Test set accuracy 0.9672000408172607 

We are now fully using the JAX API: grad for derivatives, jit for speedups and vmap for auto-vectorization. We used NumPy to specify all of our computation, borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU.


Continued in JAX Chinese Documentation (Part 6)(2): https://developer.aliyun.com/article/1559682
