JAX 中文文档(二)(2)

简介: JAX 中文文档(二)

JAX 中文文档(二)(1)https://developer.aliyun.com/article/1559668


自动向量化

原文:jax.readthedocs.io/en/latest/automatic-vectorization.html

在前一节中,我们讨论了通过jax.jit()函数进行的 JIT 编译。本文档还讨论了 JAX 的另一个转换:通过jax.vmap()进行向量化。

手动向量化

考虑以下简单代码,计算两个一维向量的卷积:

import jax
import jax.numpy as jnp
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)
convolve(x, w) 
Array([11., 20., 29.], dtype=float32) 

假设我们希望将此函数应用于一批权重w到一批向量x

xs = jnp.stack([x, x])
ws = jnp.stack([w, w]) 

最简单的选择是在 Python 中简单地循环遍历批处理:

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)
manually_batched_convolve(xs, ws) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

这会产生正确的结果,但效率不高。

为了有效地批处理计算,通常需要手动重写函数,以确保它以向量化形式完成。这并不难实现,但涉及更改函数处理索引、轴和输入其他部分的方式。

例如,我们可以手动重写convolve(),以支持跨批处理维度的向量化计算,如下所示:

def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1] -1):
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  return jnp.stack(output, axis=1)
manually_vectorized_convolve(xs, ws) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

随着函数复杂性的增加,这种重新实现可能会变得混乱且容易出错;幸运的是,JAX 提供了另一种方法。

自动向量化

在 JAX 中,jax.vmap()转换旨在自动生成这样的函数的向量化实现:

auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

它通过类似于jax.jit()的追踪函数来实现这一点,并在每个输入的开头自动添加批处理轴。

如果批处理维度不是第一维,则可以使用in_axesout_axes参数来指定输入和输出中批处理维度的位置。如果所有输入和输出的批处理轴相同,则可以使用整数,否则可以使用列表。

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)
xst = jnp.transpose(xs)
wst = jnp.transpose(ws)
auto_batch_convolve_v2(xst, wst) 
Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32) 

jax.vmap()还支持只有一个参数被批处理的情况:例如,如果您希望将一组单一的权重w与一批向量x进行卷积;在这种情况下,in_axes参数可以设置为None

batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_v3(xs, w) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

结合转换

与所有 JAX 转换一样,jax.jit()jax.vmap()都设计为可组合的,这意味着您可以用jit包装一个 vmapped 函数,或用vmap包装一个 jitted 函数,一切都会正常工作:

jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

自动微分

原文:jax.readthedocs.io/en/latest/automatic-differentiation.html

在本节中,您将学习 JAX 中自动微分(autodiff)的基本应用。JAX 具有一个非常通用的自动微分系统。计算梯度是现代机器学习方法的关键部分,本教程将引导您了解一些自动微分的入门主题,例如:

  • 1. 使用 jax.grad 计算梯度
  • 2. 在线性逻辑回归中计算梯度
  • 3. 对嵌套列表、元组和字典进行微分
  • 4. 使用 jax.value_and_grad 评估函数及其梯度
  • 5. 检查数值差异

还要确保查看高级自动微分教程,了解更多高级主题。

虽然理解自动微分的“内部工作原理”对于在大多数情况下使用 JAX 并不关键,但建议您观看这个非常易懂的视频,以深入了解发生的事情。

1. 使用jax.grad()计算梯度

在 JAX 中,您可以使用jax.grad()变换微分一个标量值函数:

import jax
import jax.numpy as jnp
from jax import grad
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0)) 
0.070650816 

jax.grad()接受一个函数并返回一个函数。如果你有一个 Python 函数f,它计算数学函数( f ),那么jax.grad(f)是一个 Python 函数,它计算数学函数( \nabla f )。这意味着grad(f)(x)表示值( \nabla f(x) )。

由于jax.grad()操作函数,您可以将其应用于其自身的输出,以任意次数进行微分:

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0)) 
-0.13621868
0.25265405 

JAX 的自动微分使得计算高阶导数变得容易,因为计算导数的函数本身是可微的。因此,高阶导数就像堆叠转换一样容易。这可以在单变量情况下说明:

函数( f(x) = x³ + 2x² - 3x + 1 )的导数可以计算如下:

f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f) 

函数( f )的高阶导数为:

[]

在 JAX 中计算任何这些都像链接jax.grad()函数一样简单:

d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx) 

在( x=1 )处评估上述内容将给出:

[]

使用 JAX:

print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.)) 
4.0
10.0
6.0
0.0 
```## 2\. 在线性逻辑回归中计算梯度
下一个示例展示了如何在线性逻辑回归模型中使用`jax.grad()`计算梯度。首先,设置:
```py
key = jax.random.key(0)
def sigmoid(x):
  return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
  return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
  preds = predict(W, b, inputs)
  label_probs = preds * targets + (1 - preds) * (1 - targets)
  return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ()) 

使用jax.grad()函数及其argnums参数对位置参数进行函数微分。

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')
# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}') 
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32) 

jax.grad() API 直接对应于斯皮瓦克经典著作《流形上的微积分》(1965  年)中的优秀符号表示法,也用于苏斯曼和威斯登的《古典力学的结构与解释》(2015 年)及其《函数微分几何》(2013  年)。这两本书都是开放获取的。特别是,《函数微分几何》的“前言”部分为此符号的使用进行了辩护。

实际上,当使用argnums参数时,如果f是用于评估数学函数(f)的 Python 函数,则 Python 表达式jax.grad(f, i)评估为一个用于评估(\partial_i f)的 Python 函数。 ## 3. 对嵌套列表、元组和字典进行微分

由于 JAX 的 PyTree 抽象(详见处理 pytrees),关于标准 Python 容器的微分工作都能正常进行,因此你可以随意使用元组、列表和字典(及任意嵌套结构)。

继续前面的示例:

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b})) 
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)} 

你可以创建自定义的 pytree 节点,以便与不仅仅是jax.grad(),还有其他 JAX 转换(jax.jit()jax.vmap()等)一起使用。 ## 4. 使用jax.value_and_grad评估函数及其梯度

另一个方便的函数是jax.value_and_grad(),可以在一次计算中高效地同时计算函数值和其梯度值。

继续前面的示例:

loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b)) 
loss value 3.0519385
loss value 3.0519385 
```## 5\. 对数值差异进行检查
关于导数的一大好处是,它们对有限差异的检查非常直观。
继续前面的示例:
```py
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec)) 
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117 

JAX 提供了一个简单的便利函数,实际上做了相同的事情,但可以检查任意阶数的微分:

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives 

下一步

高级自动微分教程提供了关于如何在 JAX 后端实现本文档涵盖的思想的更高级和详细的解释。某些功能,如用于 JAX 可转换 Python 函数的自定义导数规则,依赖于对高级自动微分的理解,因此如果您感兴趣,请查看高级自动微分教程中的相关部分。


JAX 中文文档(二)(3)https://developer.aliyun.com/article/1559670

相关文章
|
10天前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
13 1
|
10天前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
11 1
|
10天前
|
存储 并行计算 数据可视化
JAX 中文文档(六)(3)
JAX 中文文档(六)
12 0
|
9天前
|
存储 编译器 芯片
JAX 中文文档(五)(5)
JAX 中文文档(五)
8 0
|
9天前
|
存储 缓存 索引
JAX 中文文档(五)(3)
JAX 中文文档(五)
9 0
|
10天前
|
存储 机器学习/深度学习 TensorFlow
JAX 中文文档(七)(5)
JAX 中文文档(七)
7 0
|
9天前
|
机器学习/深度学习 异构计算 Python
JAX 中文文档(四)(3)
JAX 中文文档(四)
10 0
|
10天前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
15 0
|
10天前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
7 0
|
10天前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
6 0