JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学

简介: JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学

JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解。

长话短说:

  • 使用 import jax.numpy 访问 NumPy 函数,使用 import jax.scipy 访问 SciPy 函数。
  • 通过使用 @jax.jit 进行装饰,可以加快即时编译速度。
  • 使用 jax.grad 求导。
  • 使用 jax.vmap 进行矢量化,并使用 jax.pmap 进行跨设备并行化。

函数式编程

JAX 遵循函数式编程哲学。这意味着您的函数必须是独立的或纯粹的:不允许有副作用。本质上,纯函数看起来像数学函数(图 1)。有输入进来,有东西出来,但与外界没有沟通。

例子#1

以下代码片段是一个非功能纯的示例。

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias
   return total

注意 impure_example 之外的偏差。在编译期间(见下文),偏差可能会被缓存,因此不再反映偏差的变化。

例子#2

这是一个pure的例子。

def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation

在这里,pure_example 是独立的:所有参数都作为参数传递。

确定性采样器

在计算机中,不存在真正的随机性。相反,NumPy 和 TensorFlow 等库会跟踪伪随机数状态来生成“随机”样本。

函数式编程的直接后果是随机函数的工作方式不同。由于不再允许全局状态,因此每次采样随机数时都需要显式传入伪随机数生成器 (PRNG) 密钥

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)

此外,您有责任为任何后续调用推进“随机状态”。

key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

..

jit

您可以通过即时编译 JAX 指令来加快代码速度。例如,要编译缩放指数线性单位 (SELU) 函数,请使用 jax.numpy 中的 NumPy 函数并将 jax.jit 装饰器添加到该函数,如下所示:

from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

JAX 会跟踪您的指令并将其转换为 jaxpr。这使得加速线性代数 (XLA) 编译器能够为您的加速器生成非常高效的优化代码。

gard

JAX 最强大的功能之一是您可以轻松获取 gard。使用 jax.grad,您可以定义一个新函数,即符号导数。

from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))

正如您在示例中看到的,您不仅限于一阶导数。您可以通过简单地按顺序链接 grad 函数 n 次来获取 n 阶导数。

vmap 和 pmap

矩阵乘法使所有批次尺寸正确需要非常细心。 JAX 的矢量化映射函数 vmap 通过对函数进行矢量化来减轻这种负担。基本上,每个按元素应用函数 f 的代码块都是由 vmap 替换的候选者。让我们看一个例子。

计算线性函数:

def linear(x):
 return weights @ x

在一批示例 [x₁, x2,..] 中,我们可以天真地(没有 vmap)实现它,如下所示:

def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])

相反,通过使用 vmap 对线性进行向量化,我们可以一次性计算整个批次:

def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)
相关文章
|
3天前
|
机器学习/深度学习 PyTorch API
pytorch与深度学习
【5月更文挑战第3天】PyTorch,Facebook开源的深度学习框架,以其动态计算图和灵活API深受青睐。本文深入浅出地介绍PyTorch基础,包括动态计算图、张量和自动微分,通过代码示例演示简单线性回归和卷积神经网络的实现。此外,探讨了模型架构、自定义层、数据加载及预处理等进阶概念,并分享了实战技巧、问题解决方案和学习资源,助力读者快速掌握PyTorch。
15 5
|
6天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【Python机器学习专栏】PyTorch在深度学习中的应用
【4月更文挑战第30天】PyTorch是流行的开源深度学习框架,基于动态计算图,易于使用且灵活。它支持张量操作、自动求导、优化器和神经网络模块,适合快速实验和模型训练。PyTorch的优势在于易用性、灵活性、社区支持和高性能(利用GPU加速)。通过Python示例展示了如何构建和训练神经网络。作为一个强大且不断发展的工具,PyTorch适用于各种深度学习任务。
|
6天前
|
机器学习/深度学习 自然语言处理 算法
PyTorch与NLP:自然语言处理的深度学习实战
随着人工智能技术的快速发展,自然语言处理(NLP)作为其中的重要分支,日益受到人们的关注。PyTorch作为一款强大的深度学习框架,为NLP研究者提供了强大的工具。本文将介绍如何使用PyTorch进行自然语言处理的深度学习实践,包括基础概念、模型搭建、数据处理和实际应用等方面。
|
10天前
|
机器学习/深度学习 PyTorch TensorFlow
Python数据科学之旅从基础到深度学习
【4月更文挑战第10天】在这系列文章中,我们探讨了数据科学中重要的Python库,如NumPy和Pandas,以及深度学习框架TensorFlow和PyTorch。NumPy提供高性能的多维数组操作,Pandas则提供了灵活的数据处理和分析。通过Matplotlib和Seaborn进行数据可视化
15 2
|
18天前
|
机器学习/深度学习 并行计算 PyTorch
PyTorch与CUDA:加速深度学习训练
【4月更文挑战第18天】本文介绍了如何使用PyTorch与CUDA加速深度学习训练。CUDA是NVIDIA的并行计算平台,常用于加速深度学习中的矩阵运算。PyTorch与CUDA集成,允许开发者将模型和数据迁移到GPU,利用`.to(device)`方法加速计算。通过批处理、并行化策略及优化技巧,如混合精度训练,可进一步提升训练效率。监控GPU内存和使用调试工具确保训练稳定性。PyTorch与CUDA的结合对深度学习训练的加速作用显著。
|
PyTorch 算法框架/工具 Android开发
PyTorch 深度学习(GPT 重译)(六)(4)
PyTorch 深度学习(GPT 重译)(六)
38 2
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 深度学习(GPT 重译)(六)(3)
PyTorch 深度学习(GPT 重译)(六)
29 2
|
18天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 深度学习(GPT 重译)(六)(2)
PyTorch 深度学习(GPT 重译)(六)
41 1
|
18天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 深度学习(GPT 重译)(六)(1)
PyTorch 深度学习(GPT 重译)(六)
37 1
|
18天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 深度学习(GPT 重译)(五)(4)
PyTorch 深度学习(GPT 重译)(五)
35 5