JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学

简介: JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学

JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解。

长话短说:

  • 使用 import jax.numpy 访问 NumPy 函数,使用 import jax.scipy 访问 SciPy 函数。
  • 通过使用 @jax.jit 进行装饰,可以加快即时编译速度。
  • 使用 jax.grad 求导。
  • 使用 jax.vmap 进行矢量化,并使用 jax.pmap 进行跨设备并行化。

函数式编程

JAX 遵循函数式编程哲学。这意味着您的函数必须是独立的或纯粹的:不允许有副作用。本质上,纯函数看起来像数学函数(图 1)。有输入进来,有东西出来,但与外界没有沟通。

例子#1

以下代码片段是一个非功能纯的示例。

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias
   return total

注意 impure_example 之外的偏差。在编译期间(见下文),偏差可能会被缓存,因此不再反映偏差的变化。

例子#2

这是一个pure的例子。

def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation

在这里,pure_example 是独立的:所有参数都作为参数传递。

确定性采样器

在计算机中,不存在真正的随机性。相反,NumPy 和 TensorFlow 等库会跟踪伪随机数状态来生成“随机”样本。

函数式编程的直接后果是随机函数的工作方式不同。由于不再允许全局状态,因此每次采样随机数时都需要显式传入伪随机数生成器 (PRNG) 密钥

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)

此外,您有责任为任何后续调用推进“随机状态”。

key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

..

jit

您可以通过即时编译 JAX 指令来加快代码速度。例如,要编译缩放指数线性单位 (SELU) 函数,请使用 jax.numpy 中的 NumPy 函数并将 jax.jit 装饰器添加到该函数,如下所示:

from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

JAX 会跟踪您的指令并将其转换为 jaxpr。这使得加速线性代数 (XLA) 编译器能够为您的加速器生成非常高效的优化代码。

gard

JAX 最强大的功能之一是您可以轻松获取 gard。使用 jax.grad,您可以定义一个新函数,即符号导数。

from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))

正如您在示例中看到的,您不仅限于一阶导数。您可以通过简单地按顺序链接 grad 函数 n 次来获取 n 阶导数。

vmap 和 pmap

矩阵乘法使所有批次尺寸正确需要非常细心。 JAX 的矢量化映射函数 vmap 通过对函数进行矢量化来减轻这种负担。基本上,每个按元素应用函数 f 的代码块都是由 vmap 替换的候选者。让我们看一个例子。

计算线性函数:

def linear(x):
 return weights @ x

在一批示例 [x₁, x2,..] 中,我们可以天真地(没有 vmap)实现它,如下所示:

def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])

相反,通过使用 vmap 对线性进行向量化,我们可以一次性计算整个批次:

def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)
相关文章
|
3月前
|
数据采集 数据处理 Python
探索数据科学前沿:Pandas与NumPy库的高级特性与应用实例
探索数据科学前沿:Pandas与NumPy库的高级特性与应用实例
60 0
|
2月前
|
数据处理 Python
在数据科学领域,Pandas和NumPy是每位数据科学家和分析师的必备工具
在数据科学领域,Pandas和NumPy是每位数据科学家和分析师的必备工具。本文通过问题解答形式,深入探讨Pandas与NumPy的高级操作技巧,如复杂数据筛选、分组聚合、数组优化及协同工作,结合实战演练,助你提升数据处理能力和工作效率。
48 5
|
2月前
|
机器学习/深度学习 监控 PyTorch
深度学习工程实践:PyTorch Lightning与Ignite框架的技术特性对比分析
在深度学习框架的选择上,PyTorch Lightning和Ignite代表了两种不同的技术路线。本文将从技术实现的角度,深入分析这两个框架在实际应用中的差异,为开发者提供客观的技术参考。
56 7
|
2月前
|
存储 数据采集 数据处理
效率与精准并重:掌握Pandas与NumPy高级特性,赋能数据科学项目
在数据科学领域,Pandas和NumPy是Python生态中处理数据的核心库。Pandas以其强大的DataFrame和Series结构,提供灵活的数据操作能力,特别适合数据的标签化和结构化处理。NumPy则以其高效的ndarray结构,支持快速的数值计算和线性代数运算。掌握两者的高级特性,如Pandas的groupby()和pivot_table(),以及NumPy的广播和向量化运算,能够显著提升数据处理速度和分析精度,为项目成功奠定基础。
42 2
|
3月前
|
机器学习/深度学习 算法 PyTorch
深度学习笔记(十三):IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU、WIOU损失函数分析及Pytorch实现
这篇文章详细介绍了多种用于目标检测任务中的边界框回归损失函数,包括IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU和WIOU,并提供了它们的Pytorch实现代码。
483 1
深度学习笔记(十三):IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU、WIOU损失函数分析及Pytorch实现
|
4月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
279 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
3月前
|
机器学习/深度学习 算法 数据可视化
如果你的PyTorch优化器效果欠佳,试试这4种深度学习中的高级优化技术吧
在深度学习领域,优化器的选择对模型性能至关重要。尽管PyTorch中的标准优化器如SGD、Adam和AdamW被广泛应用,但在某些复杂优化问题中,这些方法未必是最优选择。本文介绍了四种高级优化技术:序列最小二乘规划(SLSQP)、粒子群优化(PSO)、协方差矩阵自适应进化策略(CMA-ES)和模拟退火(SA)。这些方法具备无梯度优化、仅需前向传播及全局优化能力等优点,尤其适合非可微操作和参数数量较少的情况。通过实验对比发现,对于特定问题,非传统优化方法可能比标准梯度下降算法表现更好。文章详细描述了这些优化技术的实现过程及结果分析,并提出了未来的研究方向。
46 1
|
3月前
|
PyTorch 算法框架/工具 Python
Pytorch学习笔记(十):Torch对张量的计算、Numpy对数组的计算、它们之间的转换
这篇文章是关于PyTorch张量和Numpy数组的计算方法及其相互转换的详细学习笔记。
54 0
|
4月前
|
机器学习/深度学习 数据挖掘 PyTorch
🎓PyTorch深度学习入门课:编程小白也能玩转的高级数据分析术
踏入深度学习领域,即使是编程新手也能借助PyTorch这一强大工具,轻松解锁高级数据分析。PyTorch以简洁的API、动态计算图及灵活性著称,成为众多学者与工程师的首选。本文将带你从零开始,通过环境搭建、构建基础神经网络到进阶数据分析应用,逐步掌握PyTorch的核心技能。从安装配置到编写简单张量运算,再到实现神经网络模型,最后应用于图像分类等复杂任务,每个环节都配有示例代码,助你快速上手。实践出真知,不断尝试和调试将使你更深入地理解这些概念,开启深度学习之旅。
61 1
|
3月前
|
机器学习/深度学习 数据采集 自然语言处理
【NLP自然语言处理】基于PyTorch深度学习框架构建RNN经典案例:构建人名分类器
【NLP自然语言处理】基于PyTorch深度学习框架构建RNN经典案例:构建人名分类器