JAX 核心特性详解:纯函数、JIT 编译、自动微分等十大必知概念

简介: JAX是Google与NVIDIA联合开发的高性能数值计算库,依托XLA实现CPU/GPU/TPU加速,支持自动微分、JIT编译、向量化与并行化。生态丰富,含Flax、Optax等工具,适合深度学习与科学计算。

JAX 是 Google 和 NVIDIA 联合开发的高性能数值计算库,这两年 JAX 生态快速发展,周边工具链也日益完善了。如果你用过 NumPy 或 PyTorch,但还没接触过 JAX,这篇文章能帮助你快速上手。

围绕 JAX 已经涌现出一批好用的库:Flax 用来搭神经网络,Optax 处理梯度和优化,Equinox 提供类似 PyTorch 的接口,Haiku 则是简洁的函数式 API,Jraph 用于图神经网络,RLax 是强化学习工具库,Chex 提供测试和调试工具,Orbax 负责模型检查点和持久化。

纯函数是硬需求

JAX 对函数有个基本要求:必须是纯函数。这意味着函数不能有副作用,对同样的输入必须总是返回同样的输出。

这个约束来自函数式编程范式。JAX 内部做各种变换(编译、自动微分等)依赖纯函数的特性,用不纯的函数可能导致错误或静默失败,结果完全不对。

 # 纯函数,没问题
def pure_addition(a, b):  
  return a + b  

# 不纯的函数,JAX 不接受
counter = 0  
def impure_addition(a, b):  
  global counter  
  counter += 1  
   return a + b

JAX NumPy 与原生 NumPy

JAX 提供了类 NumPy 的接口,核心优势在于能自动高效地在 CPU、GPU 甚至 TPU 上运行,支持本地或分布式执行。这套能力来自 XLA(Accelerated Linear Algebra) 编译器,它把 JAX 代码翻译成针对不同硬件优化的机器码。

NumPy 默认只在 CPU 上跑,JAX NumPy 则不同。用法上两者很相似,这也是 JAX 容易上手的原因。


# JAX 也差不多
import jax.numpy as jnp  

print(jnp.sqrt(4))# NumPy 的写法
import numpy as np  

print(np.sqrt(4))
# JAX 也差不多
import jax.numpy as jnp  

 print(jnp.sqrt(4))

常见的操作两者看起来基本一样:

 import numpy as np  
import jax.numpy as jnp  

# 创建数组
np_a = np.array([1.0, 2.0, 3.0])  
jnp_a = jnp.array([1.0, 2.0, 3.0])  

# 元素级操作
print(np_a + 2)  
print(jnp_a + 2)  

# 广播
np_b = np.array([[1, 2, 3]])  
jnp_b = jnp.array([[1, 2, 3]])  
print(np_b + np.arange(3))  
print(jnp_b + jnp.arange(3))  

# 求和
print(np.sum(np_a))  
print(jnp.sum(jnp_a))  

# 平均值
print(np.mean(np_a))   
print(jnp.mean(jnp_a))  

# 点积
print(np.dot(np_a, np_a))   
 print(jnp.dot(jnp_a, jnp_a))

但有个重要差异需要注意:

JAX 数组是不可变的*,对数组的修改操作会返回新数组而不是改变原数组。*

NumPy 数组则可以直接修改:

 import numpy as np  

 x = np.array([1, 2, 3])  
 x[0] = 10  # 直接修改,没问题

JAX 这边就不行了:

 import jax.numpy as jnp  

 x = jnp.array([1, 2, 3])  
 x[0] = 10  # 报错

但是JAX 提供了专门的 API 来处理这种情况,通过返回一个新数组的方式实现"修改":

 z=x.at[idx].set(y)

完整的例子:

 x = jnp.array([1, 2, 3])  
 y = x.at[0].set(10)  

 print(y)  # [10, 2, 3]  
 print(x)  # [1, 2, 3](没变)

JIT 编译加速

即时编译(JIT)是 JAX 里一个核心特性,通过 XLA 把 Python/JAX 代码编译成优化后的机器码。

直接用 Python 解释器跑函数会很慢。加上

@jit

装饰器后,函数会被编译成快速的原生代码:

 from jax import jit  

# 不编译的版本
def square(x):  
  return x * x   

# 编译过的版本
@jit   
def jit_square(x):  
   return x * x
jit_square

快好几个数量级。函数首次调用时,JIT 引擎会:

  1. 追踪函数逻辑,构建计算图
  2. 把图编译成优化的 XLA 代码
  3. 缓存编译结果
  4. 后续调用直接用缓存的版本

自动微分

JAX 的 grad 模块能自动计算函数的导数。

 import jax.numpy as jnp  
from jax import grad  

# 定义函数:f(x) = x² + 2x + 2
def f(x):  
  return x**2 + 2 * x + 2  

# 计算导数
df_dx = grad(f)  

# 在 x = 2.0 处求值
 print(df_dx(2.0))  # 6.0

随机数处理

NumPy 用全局随机状态生成随机数。每次调用

np.random.random()

时,NumPy 会更新隐藏的内部状态:

 import numpy as np  

 np.random.random()  
 # 0.9539264374520571

JAX 的做法完全不同。作为纯函数库,它不能维护全局状态,所以要求显式传入一个伪随机数生成器(PRNG)密钥。每次生成随机数前要先分割密钥:

 from jax import random  

# 初始化密钥
key = random.PRNGKey(0)  

# 每次生成前分割
key, subkey = random.split(key)  

# 从正态分布采样
x = random.normal(subkey, ())  
print(x)  # -2.4424558  

# 从均匀分布采样
key, subkey = random.split(key)  
u = random.uniform(subkey, (), minval=0.0, maxval=1.0)  
 print(u)  # 0.104290366

一个常见的坑:同一个密钥生成的随机数始终相同。

 # 用同一个 subkey,结果重复
 x = random.normal(subkey, ())  
 print(x)  # -2.4424558  

 x = random.normal(subkey, ())  
 print(x)  # -2.4424558(还是这个值)

所以要记住总是用新密钥。

向量化:vmap

vmap 自动把函数转换成能处理批量数据的版本。逻辑上就像循环遍历每个样本,但执行效率远高于 Python 循环。

 import jax.numpy as jnp  
from jax import vmap  

def f(x):  
    return x * x + 1  

arr = jnp.array([1., 2., 3., 4.])  

# Python 循环(慢)
outputs_loop = jnp.array([f(x) for x in arr])  

# vmap 版本(快)
f_vectorized = vmap(f)  
 outputs_vmap = f_vectorized(arr)

并行化:pmap

pmap 不同于 vmap。vmap 在单个设备上做批处理,pmap 把计算分散到多个设备(GPU/TPU 核心),每个设备处理输入的一部分。

VMAP:单设备批处理向量化

PMAP:跨多设备并行执行

 import jax.numpy as jnp  
from jax import pmap  

# 查看可用设备
print(jax.devices())  # [TpuDevice(id=0), TpuDevice(id=1), ..., TpuDevice(id=7)]  

def f(x):  
    return x * x + 1  

arr = jnp.array([1., 2., 3., 4.])  

# pmap 会把数组分配到不同设备
 ys = pmap(f)(arr)

PyTrees

PyTree 在 JAX 里是个常见的概念:任何嵌套的 Python 容器(列表、字典、元组等)加上基本类型的组合。JAX 里用它来组织模型参数、优化器状态、梯度等。

 import jax.numpy as jnp  
from jax import tree_util as tu

# 构建 PyTree
pytree = {  
    "a": jnp.array([1, 2]),  
    "b": [jnp.array([3, 4]), 5]  
}  

# 获取所有叶子节点
leaves = tu.tree_leaves(pytree)  

# 对每个叶子应用函数
 doubled = tu.tree_map(lambda x: x * 2, pytree)

Optax:梯度处理和优化

Optax 是 JAX 生态里的优化库。它包含损失函数、优化器、梯度变换、学习率调度等一套工具。

损失函数:

 logits = jnp.array([[2.0, -1.0]])  
labels_onehot = jnp.array([[1.0, 0.0]])  
labels_int = jnp.array([0])  

# Softmax 交叉熵(独热编码)
loss_softmax_onehot = optax.softmax_cross_entropy(logits, labels_onehot).mean()  

# Softmax 交叉熵(整数标签)
loss_softmax_int = optax.softmax_cross_entropy_with_integer_labels(logits, labels_int).mean()  

# 二元交叉熵
loss_bce = optax.sigmoid_binary_cross_entropy(logits, labels_onehot).mean()  

# L2 损失
loss_l2 = optax.l2_loss(jnp.array([1., 2.]), jnp.array([0., 1.])).mean()  

# Huber 损失
 loss_huber = optax.huber_loss(jnp.array([1.,2.]), jnp.array([0.,1.])).mean()

优化器:

 # SGD
opt_sgd = optax.sgd(learning_rate=1e-2)  

# SGD with momentum
opt_momentum = optax.sgd(learning_rate=1e-2, momentum=0.9)   

# RMSProp
opt_rmsprop = optax.rmsprop(1e-3)  

# Adafactor
opt_adafactor = optax.adafactor(learning_rate=1e-3)  

# Adam
opt_adam = optax.adam(1e-3)  

# AdamW
 opt_adamw = optax.adamw(1e-3, weight_decay=1e-4)

梯度变换:

 # 梯度裁剪
tx_clip = optax.clip(1.0)  

# 全局梯度范数裁剪
tx_clip_global = optax.clip_by_global_norm(1.0)  

# 权重衰减(L2)
tx_weight_decay = optax.add_decayed_weights(1e-4)  

# 添加梯度噪声
 tx_noise = optax.add_noise(0.01)

学习率调度:

 # 指数衰减
 lr_exp = optax.exponential_decay(init_value=1e-3, transition_steps=1000, decay_rate=0.99)  

 # 余弦衰减
 lr_cos = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)  

 # 线性预热
 lr_linear = optax.linear_schedule(init_value=0.0, end_value=1e-3, transition_steps=500)

更新步骤:

 # 计算梯度
 loss, grads = jax.value_and_grad(loss_fn)(params)  

 # 生成优化器更新
 updates, opt_state = optimizer.update(grads, opt_state)  

 # 应用更新
 params = optax.apply_updates(params, updates)

链式组合:

 # 把多个操作链起来
 optimizer = optax.chain(  
     optax.clip_by_global_norm(1.0),  # 梯度裁剪
     optax.add_decayed_weights(1e-4),  # 权重衰减
     optax.adam(1e-3)  # Adam 优化
 )

Flax 与神经网络

JAX 本身只是数值计算库,Flax 在其基础上提供了神经网络定义和训练的高级 API。Flax 代码风格接近 PyTorch,如果你用过 PyTorch 会很快上手。

Flax 提供了丰富的层和操作。基础层 包括全连接层

Dense

、卷积

Conv

、嵌入

Embed

、多头注意力

MultiHeadDotProductAttention

等:

 flax.linen.Dense(features=128)  
 flax.linen.Conv(features=64, kernel_size=(3, 3))  
 flax.linen.Embed(num_embeddings=10000, features=256)  
 flax.linen.MultiHeadDotProductAttention(num_heads=8)  
 flax.linen.SelfAttention(num_heads=8)

归一化 支持多种方式:

 flax.linen.BatchNorm()  
 flax.linen.LayerNorm()  
 flax.linen.GroupNorm(num_groups=32)  
 flax.linen.RMSNorm()

激活和 Dropout:

 flax.linen.relu(x)  
 flax.linen.gelu(x)  
 flax.linen.sigmoid(x)  
 flax.linen.tanh(x)  
 flax.linen.Dropout(rate=0.1)

池化:

 flax.linen.avg_pool(x, window_shape=(2,2), strides=(2,2))  
 flax.linen.max_pool(x, window_shape=(2,2), strides=(2,2))

循环层:

 flax.linen.LSTMCell()  
 flax.linen.GRUCell()  
 flax.linen.OptimizedLSTMCell()

下面是一个简单的多层感知机(MLP)例子:

 import jax  
import jax.numpy as jnp  
from flax import linen as nn  

class MLP(nn.Module):  
    features: list  

    @nn.compact  
    def __call__(self, x):  
        for f in self.features[:-1]:  
            x = nn.Dense(f)(x)
            x = nn.relu(x)

        x = nn.Dense(self.features[-1])(x)  
        return x  

model = MLP([32, 16, 10])  
key = jax.random.PRNGKey(0)  

# 输入:batch_size=1, 特征数=4
x = jnp.ones((1, 4))  

# 初始化参数
params = model.init(key, x)  

# 前向传播
y = model.apply(params, x)  

print("Input:", x)  
# Input: [[1. 1. 1. 1.]]  

print("Input shape:", x.shape)  
# Input shape: (1, 4)   

print("Output:", y)  
# Output: [[ 0.51415515  0.36979797  0.6212194  -0.74496573 -0.8318489   0.6590691 0.89224255  0.00737424  0.33062232  0.34577468]]  

print("Output shape:", y.shape)  
 # Output shape: (1, 10)

Flax 用

@nn.compact

装饰器,让你在

__call__

方法里直接定义层。参数是独立于模型对象存储的,需要通过

init

方法显式初始化,然后在

apply

方法中使用。

总结

JAX 的出现解决了一个长期存在的问题:如何让 Python 科学计算既保持灵活性,又能获得接近 C/CUDA 的性能。

不过 JAX 的学习曲线确实比 PyTorch 陡。纯函数的约束、不可变数组的特性、显式密钥管理等细节起初会有些别扭。但一旦习惯会发现它带来的优雅和灵活性。

https://avoid.overfit.cn/post/a16194fdc3ea450f858515d7cb3d49c4

作者:Ashish Bamania

目录
相关文章
|
4月前
|
PyTorch 算法框架/工具
JAX核心设计解析:函数式编程让代码更可控
JAX采用函数式编程,参数与模型分离,随机数需显式传递key,确保无隐藏状态。这使函数行为可预测,便于自动微分、编译优化与分布式训练,虽初学略显繁琐,但在科研、高精度仿真等场景下更具可控性与可复现优势。
444 115
|
监控 网络协议 Ubuntu
Linux网络监控工具 - iftop
Linux网络监控工具 - iftop
632 1
|
6月前
|
人工智能 Java Nacos
基于 Spring AI Alibaba + Nacos 的分布式 Multi-Agent 构建指南
本文将针对 Spring AI Alibaba + Nacos 的分布式多智能体构建方案展开介绍,同时结合 Demo 说明快速开发方法与实际效果。
4829 92
|
6月前
|
缓存 PyTorch API
TensorRT-LLM 推理服务实战指南
`trtllm-serve` 是 TensorRT-LLM 官方推理服务工具,支持一键部署兼容 OpenAI API 的生产级服务,提供模型查询、文本与对话补全等接口,并兼容多模态及分布式部署,助力高效推理。
885 155
|
7月前
|
存储 人工智能 运维
AI 网关代理 RAG 检索:Dify 轻松对接外部知识库的新实践
Higress AI 网关通过提供关键桥梁作用,支持 Dify 应用便捷对接业界成熟的 RAG 引擎。通过 AI 网关将 Dify 的高效编排能力与专业 RAG 引擎的检索效能结合,企业可在保留现有 Dify 应用资产的同时,有效规避其内置 RAG 的局限,显著提升知识驱动型 AI 应用的生产环境表现。
3235 130
|
6月前
|
机器学习/深度学习 算法 数据可视化
6G时代的新型延迟多普勒通信范式:正交时频空间(OTFS)综述
本文综述正交时频空间(OTFS)技术,一种面向6G高移动性场景的新型延迟-多普勒域通信范式。OTFS通过在延迟-多普勒域调制信号,克服传统OFDM在高速移动下的多普勒扩展难题,具备信道稳定性强、抗干扰能力优、峰均比低等优势。文章系统阐述OTFS的信道模型、调制原理、收发机设计、ISAC一体化及在卫星、水声、可见光等新兴场景的应用前景,为其在6G空天地海一体化网络中的应用提供理论支撑与技术路径。
1308 0
6G时代的新型延迟多普勒通信范式:正交时频空间(OTFS)综述
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
5397 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
机器学习/深度学习 自然语言处理
不是RNN的锅!清华团队深入分析长上下文建模中的状态崩溃,Mamba作者点赞
清华大学团队发表论文,深入分析RNN在长上下文建模中的状态崩溃现象,并提出四种缓解方法:减少记忆与增加遗忘、状态归一化、滑动窗口机制及训练更长序列。实验表明,这些方法显著提升Mamba-2模型处理超过1M tokens的能力。尽管存在局限性,该研究为RNN长上下文建模提供了新思路,得到Mamba作者认可。
328 6
|
网络协议 安全 CDN
你的连接不是专用连接 攻击者可能试图从 github.com 窃取你的信息 通过修改DNS连接解决无法连接问题
你的连接不是专用连接 攻击者可能试图从 github.com 窃取你的信息 通过修改DNS连接解决无法连接问题
2525 0
|
缓存 编译器 C++
第十五问:volatile是什么?有什么用?
本文深入探讨了C/C++中的`volatile`关键字,解释了其防止编译器不当优化、保证多线程间可见性和确保硬件状态正确读写的作用。同时,文章也指出了使用`volatile`可能带来的性能影响,并强调了它在多线程同步中的局限性。通过具体示例,帮助读者更好地理解和应用这一强大工具。