Advanced Tutorials
This section contains examples and tutorials on more advanced topics, such as multi-core computation, custom operations, and deeper dives into applications.
Examples
- Training a simple neural network, with tensorflow/datasets data loading
- Training a simple neural network, with PyTorch data loading
- Autobatching for Bayesian inference
Parallel computation
- Using JAX in multi-host and multi-process environments
- Distributed arrays and automatic parallelization
- SPMD multi-device parallelism with shard_map
- API specification
- Collectives tutorial
- Toy examples
- Distributed data loading in a multi-host/multi-process environment
- Named axes and easy-to-revise parallelism with xmap
Automatic differentiation
- The Autodiff Cookbook
- Custom derivative rules for JAX-transformable Python functions
- Control autodiff's saved values with jax.checkpoint (aka jax.remat)
JAX internals
- How JAX primitives work
- Writing custom Jaxpr interpreters in JAX
- Custom operations for GPUs with C++ and CUDA
- Checking correctness
Deeper dives
- Generalized convolutions in JAX
Training a simple neural network, with tensorflow/datasets data loading
Original: jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
Derived from neural_network_and_data_loading.ipynb
Let's combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use the tensorflow/datasets data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library 😛).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
```
Hyperparameters
Let's get a few bookkeeping items out of the way.
```python
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
```
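As a quick sanity check (this snippet is not part of the original notebook), you can print the parameter shapes; each entry of `params` should be a `(weights, biases)` pair whose shapes follow `layer_sizes`:

```python
# Added sanity check, assuming the definitions above: each layer holds an
# (out_dim, in_dim) weight matrix and an (out_dim,) bias vector.
print([(w.shape, b.shape) for w, b in params])
# [((512, 784), (512,)), ((512, 512), (512,)), ((10, 512), (10,))]
```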
Auto-batching predictions
Let us first define our prediction function. Note that we're defining this for a single image example. We're going to use JAX's vmap function to automatically handle mini-batches, with no performance penalty.
```python
from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
```
Let's check that our prediction function only works on single images.
```python
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
```
```
(10,)
```
```python
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')
```
```
Invalid shapes!
```
```python
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
```
```
(10, 10)
```
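The `in_axes=(None, 0)` argument tells `vmap` to broadcast `params` (no batch axis) while mapping over the leading axis of the image batch. Here is a tiny illustration of the same semantics on a simpler function (an added sketch, not from the original notebook):

```python
# vmap with in_axes=(None, 0): the scalar argument is shared, the rows are mapped over.
scale_rows = vmap(lambda s, row: s * row, in_axes=(None, 0))
print(scale_rows(2.0, jnp.arange(6).reshape(3, 2)))
# [[ 0.  2.]
#  [ 4.  6.]
#  [ 8. 10.]]
```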
At this point, we have all the ingredients we need to define and train our neural network. We've built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit to speed up everything.
Utility and loss functions
```python
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
```
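It can help to see why the list comprehension in `update` works: `grad(loss)` returns gradients with exactly the same pytree structure as `params`. The following sketch (added here, assuming the definitions above and a small made-up batch) checks that, and shows an equivalent structure-agnostic update with `jax.tree_util.tree_map`:

```python
from jax import tree_util

# Hypothetical dummy batch, used only to inspect the gradient structure.
dummy_images = random.normal(random.key(2), (4, 28 * 28))
dummy_targets = one_hot(jnp.arange(4), n_targets)

grads = grad(loss)(params, dummy_images, dummy_targets)
print([(dw.shape, db.shape) for dw, db in grads])  # mirrors the shapes in `params`

# Equivalent update written over the whole parameter pytree at once
new_params = tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)
```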
Data loading with tensorflow/datasets
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the tensorflow/datasets data loader.
```python
import tensorflow as tf

# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
```
```python
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
```
```
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
```
Training loop
```python
import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
```
```
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341
```
We've now used most of the JAX API: grad for taking derivatives, jit for speedups and vmap for auto-vectorization. We used NumPy to specify all of our computation, borrowed the great data loaders from tensorflow/datasets, and ran the whole thing on the GPU.
Training a simple neural network, with PyTorch data loading
Original: jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
Copyright 2018 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Let's combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
```
Hyperparameters
Let's get a few bookkeeping items out of the way.
```python
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
```
Auto-batching predictions
Let us first define our prediction function. Note that we're defining this for a single image example. We're going to use JAX's vmap function to automatically handle mini-batches, with no performance penalty.
```python
from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
```
Let's check that our prediction function only works on single images.
```python
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
```
```
(10,)
```
```python
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')
```
```
Invalid shapes!
```
```python
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
```
```
(10, 10)
```
At this point, we have all the ingredients we need to define and train our neural network. We've built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit to speed up everything.
Utility and loss functions
```python
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
```
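If you also want to monitor the loss value during training, a common variant (an added sketch, not used in this notebook; the name `update_with_loss` is made up here) is `jax.value_and_grad`, which returns the loss and its gradients in a single pass:

```python
from jax import value_and_grad

@jit
def update_with_loss(params, x, y):
  # Compute the loss value and its gradients together, then take an SGD step.
  loss_value, grads = value_and_grad(loss)(params, x, y)
  new_params = [(w - step_size * dw, b - step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]
  return new_params, loss_value
```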
Data loading with PyTorch
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays.
```python
!pip install torch torchvision
```
```
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)
Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)
Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)
Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0)
```
```python
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
               shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0,
               pin_memory=False, drop_last=False,
               timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))
```
```python
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
```
```
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Processing...
Done!
```
```python
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)
```
```
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data
  warnings.warn("train_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets
  warnings.warn("train_labels has been renamed targets")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data
  warnings.warn("test_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets
  warnings.warn("test_labels has been renamed targets")
```
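As a quick check (an added sketch, not part of the original notebook), pulling a single batch from the `NumpyLoader` confirms that the shim yields plain NumPy arrays, which the jitted `update` function can consume directly:

```python
# Fetch one batch from the loader and inspect it; with batch_size = 128 and the
# FlattenAndCast transform, we expect float32 arrays of shape (128, 784).
images, labels = next(iter(training_generator))
print(type(images), images.shape, images.dtype)  # <class 'numpy.ndarray'> (128, 784) float32
print(type(labels), labels.shape)                # <class 'numpy.ndarray'> (128,)
```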
Training loop
```python
import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
```
```
Epoch 0 in 55.15 sec
Training set accuracy 0.9157500267028809
Test set accuracy 0.9195000529289246
Epoch 1 in 42.26 sec
Training set accuracy 0.9372166991233826
Test set accuracy 0.9384000301361084
Epoch 2 in 44.37 sec
Training set accuracy 0.9491666555404663
Test set accuracy 0.9469000697135925
Epoch 3 in 41.75 sec
Training set accuracy 0.9568166732788086
Test set accuracy 0.9534000158309937
Epoch 4 in 41.16 sec
Training set accuracy 0.9631333351135254
Test set accuracy 0.9577000737190247
Epoch 5 in 38.89 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9616000652313232
Epoch 6 in 40.68 sec
Training set accuracy 0.9708333611488342
Test set accuracy 0.9650000333786011
Epoch 7 in 41.50 sec
Training set accuracy 0.973716676235199
Test set accuracy 0.9672000408172607
```
We've now used the whole of the JAX API: grad for taking derivatives, jit for speedups and vmap for auto-vectorization. We used NumPy to specify all of our computation, borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU.