JAX 训练加速指南:8 个让 TPU 满跑的工程实战习惯

简介: 本文总结8条JAX在TPU上高效训练的工程实践:固定Shape、使用bfloat16+FP32主权重、显式pjit切分、jit/vmap/scan融合、优化数据管道、PRNG与Step/Device绑定、Remat与梯度累积、善用Profiler。遵循这些原则可避免重编译与内存瓶颈,最大化TPU算力利用率,实现高效稳定训练。

TPU 训练的真实效率往往取决于两个核心要素:Shape 的稳定性算子的融合度

很多时候,JAX 任务之所以出现严重的性能瓶颈,并非算法本身设计有问题,而是忽视了 XLA 编译器与底层硬件对“确定性”的极度偏好。基于大量实战调优经验,本文总结了八条能让 JAX 训练任务从“甚至跑不通”蜕变为“跑满 TPU 算力”的工程经验。

1、尽早锁定 Shape

TPU 喜欢静态 Shape,JAX 也是,所以动态 Shape 是性能杀手,它会触发重新编译(Recompile)。一旦发生重编译,Step time 和内存占用都会直接炸裂。所以解决方法也很简单,选定几个规范的尺寸,剩下的全填(Pad)满。

全局 Batch Size 要能被 TPU 核心数整除,然后就是对于变长序列,别指望它原本多长就多长,把它 Pad 到几个固定的“桶(Bucket)”里,比如 128、256 或 512,这步工作最好在输入(Input Pipeline)里就做完。

Python层面的条件判断尽量别依赖 Shape,真要分支逻辑,就老老实实让

lax.cond

lax.switch

来接管。

     # Example: bucketing & padding (conceptual)  
    def pad_to_length(arr, L):  
        pad = L - arr.shape[0]  
        return jnp.pad(arr, ((0, pad), (0, 0)), mode='constant')  

    bucket_sizes = [128, 256, 512]  
    def bucket_len(n):   
        return next(b for b in bucket_sizes if n <= b)  

    def preprocess_batch(batch):  
        L = bucket_len(batch["tokens"].shape[1])  
        batch["tokens"] = pad_to_length(batch["tokens"], L)  
        batch["mask"]   = pad_to_length(batch["mask"], L)  
         return batch

每个 Step 喂给 TPU 的 Shape 只要是固定的,XLA 编译器就不会找麻烦。

2、激活值默认用 bfloat16,主权重要 FP32

在 TPU 上

bfloat16

(bf16) 是个好东西,兼顾了速度、内存和数值稳定性。

工程上的常规操作是:激活(Activations)和梯度(Gradients)存成 bf16。但是,优化器状态里的权重必须保留一份 FP32 的“主副本”,不然跑久了数值就会漂移。所欲需要在模型边界做类型转换(Cast)的时候小心点。

     class MLP(nn.Module):  
        features: int  
        @nn.compact  
        def __call__(self, x):  
            x = x.astype(jnp.bfloat16)     # fast path on TPUs  
            x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)  
            x = nn.gelu(x)  
            x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)  
            return x  

    # Optimizer state stays in FP32 (conceptual)  
    params_fp32 = params.astype(jnp.float32)  
    grads_bf16  = compute_grads_bf16(...)  
     updates_fp32 = opt.update(grads_bf16.astype(jnp.float32), opt_state, params_fp32)

3、pjit和命名网格:切分要明确,别靠猜

JAX 在 TPU 上最强的一点就是通过

pjit

实现了 GSPMD。你通过 PartitionSpecs 告诉它想要什么切分方式,XLA 负责搞定如何在设备间搬运数据。

在 TPU 核心上建个命名网格(Mesh)。做数据并行(Data Parallelism)时,用

PartitionSpec('data', None)

这种模式。如果模型太大需要张量并行(Tensor Model Parallelism),就加个

'model'

轴。

     import numpy as np  
    import jax  
    import jax.numpy as jnp  
    from jax.sharding import Mesh, PartitionSpec as P  
    from jax.experimental import pjit  

    devices = np.array(jax.devices()).reshape(1, -1)  # 1 x N mesh  
    mesh = Mesh(devices, ('data',))  

    def loss_fn(params, batch):  
        logits = model_apply(params, batch['x'])  
        return cross_entropy(logits, batch['y'])  

    @pjit.pjit(  
        in_shardings=(P(None), P('data')),   # params replicated, batch sharded on 'data'  
        out_shardings=P(None),               # scalar loss replicated  
    )  
    def step(params, batch):  
        grads = jax.grad(loss_fn)(params, batch)  
        # aggregate grads across cores  
        grads = jax.tree.map(lambda g: jax.lax.pmean(g, axis_name='data'), grads)  
        return grads  

    with mesh:  
         grads = step(params, sharded_batch)

切分(Sharding)这事必须显式。如果偷懒依赖自动推导,等到后期 debug 那些悄无声息的跨设备数据传输时,绝对会很痛苦。

4、jit, vmap, scan 三件套

TPU 喜欢大块头的 Kernel,讨厌成千上万个细碎的小算子。训练 Step 和任何中大型计算逻辑,必须用

jit

包起来。遇到 Python 循环,如果是时间步逻辑就换成

lax.scan

,如果是批次并行就用

vmap

把 Loss 计算、梯度计算和参数更新塞进同一个 jitted 函数里,这样编译器才有机会把它们融合成一个大算子。

     import optax  
    import jax  

    optimizer = optax.adamw(3e-4)  

    def loss_and_grads(params, batch):  
        def _loss(p):  
            logits = model_apply(p, batch['x'])  
            return cross_entropy(logits, batch['y'])  
        loss, grads = jax.value_and_grad(_loss)(params)  
        return loss, grads  

    @jax.jit  
    def train_step(state, batch):  
        loss, grads = loss_and_grads(state.params, batch)  
        grads = jax.lax.pmean(grads, axis_name='data')  
        updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)  
        new_params = optax.apply_updates(state.params, updates)  
         return state.replace(params=new_params, opt_state=new_opt_state), loss

5、别让输入管道拖后腿

Host 到 Device 的数据传输一旦停顿,吞吐量就掉下来了,所以永远别让计算单元等数据。

tf.data

或者高效的 NumPy loader 配合 prefetch。数据预取到设备(Stage to device) 最好做双重缓冲。全局 Batch 尽量大(当然要能被核心数整除),数据增强这种脏活累活在 Host 上一次性做完。

     # tf.data pipeline (conceptual)  
    ds = (tf.data.TFRecordDataset(files)  
          .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)  
          .batch(global_batch_size, drop_remainder=True)  
          .prefetch(tf.data.AUTOTUNE))  

    # Convert to NumPy and prefetch onto devices  
    from flax.jax_utils import prefetch_to_device  
    it = prefetch_to_device(map(npify, ds.as_numpy_iterator()), size=2)  

    with mesh:  
        for step_i in range(num_steps):  
            batch = next(it)     # already sharded/prefetched  
             state, loss = train_step(state, batch)

6、PRNG要Fold 进 Step 和 Device ID

JAX 的 PRNG 是无状态的,这意味如果不小心,很容易在不同 Step 或者不同设备上用了一样的随机数 Key。

每个 Step 都要 Split 一次绝对别复用。所以说为了保证独立性必须把 Global StepDevice IndexFold 进去。数据增强/Dropout 的 Key 和参数初始化的 Key 得分开管理。

     def make_step_rng(rng, step):  
        step_key = jax.random.fold_in(rng, step)  
        dev_key  = jax.random.fold_in(step_key, jax.lax.axis_index('data'))  
        return jax.random.split(dev_key, 1)[0]  

    @jax.jit  
    def train_step(state, batch, base_rng):  
        rng = make_step_rng(base_rng, state.step)  
        logits = model_apply(state.params, batch['x'], rngs={'dropout': rng})  
         ...

7、Remat,智能 Checkpoint,梯度累积

TPU 内存看着大,模型一跑起来就不够用。深层网络可以直接用 Activation Checkpointing(

jax.checkpoint

nn.remat

),用计算换显存。想跑大 Batch 但显存不够,就用梯度累积(Gradient Accumulation) 把它切成小的 micro-step。

存盘的时候,推荐用 Orbax 做异步、分片(Sharded)的 Checkpoint,稳。

     from flax import linen as nn  

    class DeepBlock(nn.Module):  
        @nn.compact  
        def __call__(self, x):  
            # recompute on backward to trim activation memory  
            f = nn.remat(lambda y: nn.gelu(nn.Dense(x.shape[-1])(y)))  
            return f(x)  

    # Gradient accumulation (conceptual)  
    @jax.jit  
    def accum_step(state, batch_slices):  
        def body(carry, micro):  
            state, grad_sum = carry  
            _, grads = loss_and_grads(state.params, micro)  
            return (state, jax.tree_util.tree_map(jnp.add, grad_sum, grads)), None  
        init_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)  
        (state, grad_sum), _ = jax.lax.scan(body, (state, init_grads), batch_slices)  
        grads = jax.tree_map(lambda g: g / len(batch_slices), grad_sum)  
         ...

8、一定要跑 Profiler

把关键代码段用 Profiler Annotations 包起来,看 Step Timeline。重点找 Host Waits、Recompiles 和那些没融合好的细碎算子(Small op soup)。

稳态运行的时候,盯着 Tokens/sec 或者Images/sec,还有硬件利用率。

     from jax.experimental import host_callback as hcb  
    from jax import profiler  

    def tagged(name, fn, *a, **k):  
        profiler.annotate_function(name=name)  
        return fn(*a, **k)  

    @jax.jit  
    def train_step(state, batch):  
        profiler.annotate_function(name="train_step")  
        # do work...  
         return state, loss

一定要在锁定 Shape 并且 JIT 完热点路径之后再做 Profile,不然全是噪音,根本看不到真正的瓶颈。

极简 TPU 训练示例

这基本包含了上面所有的内容

     # Pseudo-skeleton (Flax + JAX + TPU)  
    mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data',))  

    @pjit.pjit(in_shardings=(P(None), P('data'), P(None)), out_shardings=(P(None), P(None)))  
    def train_step(state, batch, base_rng):  
        rng = jax.random.fold_in(base_rng, state.step)  
        rng = jax.random.fold_in(rng, jax.lax.axis_index('data'))  
        def loss_fn(p):  
            logits = model_apply(p, batch['x'].astype(jnp.bfloat16),  
                                 rngs={'dropout': rng})  
            return cross_entropy(logits, batch['y'])  
        loss, grads = jax.value_and_grad(loss_fn)(state.params)  
        grads = jax.tree_map(lambda g: jax.lax.pmean(g, 'data'), grads)  
        updates, opt_state = optimizer.update(grads, state.opt_state, state.params)  
        params = optax.apply_updates(state.params, updates)  
        return state.replace(params=params, opt_state=opt_state, step=state.step+1), loss  

    with mesh:  
        for step_i, batch in enumerate(prefetched_iterator):  
            state, loss = train_step(state, batch, base_rng)  
            if step_i % log_every == 0:  
                # Pull back just tiny scalars; keep big tensors on device  
                host_loss = jax.device_get(loss)  
                 print(f"[{step_i}] loss={host_loss:.4f}")

总结

TPU 需要的是 一致性:稳定的 Shape,融合的 Kernel,目的明确的切分,不掉链子的数据管道,把上面的这八件事做好,写 JAX 训练循环就非常顺畅了。

https://avoid.overfit.cn/post/16b582a493ba4eca8333314859665dd2

作者:Modexa

目录
相关文章
|
3月前
|
XML 机器学习/深度学习 监控
高级检索增强生成系统:LongRAG、Self-RAG 和 GraphRAG 的实现与选择
检索增强生成(RAG)已超越简单向量匹配,迈向LongRAG、Self-RAG与GraphRAG等高级形态。LongRAG通过大块重叠分片保留长上下文,提升连贯性;Self-RAG引入反思机制,动态判断检索必要性与内容相关性,增强可信度;GraphRAG构建知识图谱,支持多跳推理与复杂关系挖掘。三者分别应对上下文断裂、检索盲目性与关系表达缺失难题,代表2025年RAG工程化核心进展,可依场景组合使用以平衡准确性、成本与复杂度。
363 57
高级检索增强生成系统:LongRAG、Self-RAG 和 GraphRAG 的实现与选择
|
3月前
|
数据采集 人工智能 自然语言处理
Meta SAM3开源:让图像分割,听懂你的话
Meta发布并开源SAM 3,首个支持文本或视觉提示的统一图像视频分割模型,可精准分割“红色条纹伞”等开放词汇概念,覆盖400万独特概念,性能达人类水平75%–80%,推动视觉分割新突破。
1555 59
Meta SAM3开源:让图像分割,听懂你的话
|
2月前
|
人工智能 运维 安全
SOC 2.0 来了:不是加人加班,而是加“智能”!——智能化安全运营中心的建设之道
SOC 2.0 来了:不是加人加班,而是加“智能”!——智能化安全运营中心的建设之道
243 15
|
2月前
|
弹性计算 搜索推荐 应用服务中间件
今非昔比:看完阿里云服务器租赁价格,沉默了~
阿里云服务器优惠汇总:轻量应用服务器200M带宽38元起/年,ECS云服务器2核2G仅99元/年,4核16G 89元/月,8核32G 160元/月,香港轻量服务器25元/月起,爆款低至1折,新老用户同享,续费同价,限时抢购!
423 14
|
3月前
|
人工智能 前端开发 算法
大厂CIO独家分享:AI如何重塑开发者未来十年
在 AI 时代,若你还在紧盯代码量、执着于全栈工程师的招聘,或者仅凭技术贡献率来评判价值,执着于业务提效的比例而忽略产研价值,你很可能已经被所谓的“常识”困住了脚步。
1856 89
大厂CIO独家分享:AI如何重塑开发者未来十年
|
2月前
|
机器学习/深度学习 人工智能 运维
别只盯着 CPU 爆了!一篇文章带你看懂:从指标到根因的 AIOps 自动化故障定位流水线
别只盯着 CPU 爆了!一篇文章带你看懂:从指标到根因的 AIOps 自动化故障定位流水线
359 15
|
3月前
|
机器学习/深度学习 传感器 算法
BipedalWalker实战:SAC算法如何让机器人学会稳定行走
本文探讨基于Soft Actor-Critic(SAC)算法的下肢假肢自适应控制。传统方法依赖精确建模,难以应对复杂环境变化。SAC通过最大熵强化学习,使假肢在仿真中自主探索、学习稳定步态,具备抗干扰与容错能力。结合生物工程视角,将神经网络映射为神经系统,奖励函数关联代谢效率,实现从试错到自然行走的演化。相位图分析显示极限环形成,标志动态稳定步态建立,能效曲线表明后期动作更节能。研究为智能假肢迈向临床应用提供新思路。
339 117
BipedalWalker实战:SAC算法如何让机器人学会稳定行走
|
2月前
|
人工智能 BI 开发工具
适合个人开发者的5款开发工具,开发者必须知道
2025年,个人开发者迎来工具黄金时代。本文精选5款高效开发利器:GitHub Copilot(AI智能编程)、Trae(中文友好)、Cursor(项目级理解)、VS Code(开源全能)和Zoho Creator(低代码平台),覆盖从代码生成到应用搭建,助你提升效率,快速实现创意。
753 2
|
2月前
|
机器学习/深度学习 人工智能 缓存
CALM自编码器:用连续向量替代离散token,生成效率提升4倍
近年来语言模型效率优化多聚焦参数规模与注意力机制,却忽视了自回归生成本身的高成本。CALM提出新思路:在token之上构建潜在空间,通过变分自编码器将多个token压缩为一个连续向量,实现“一次前向传播生成多个token”。该方法大幅减少计算次数,提升推理速度与吞吐量,同时引入无似然训练与BrierLM评估体系,突破传统语言建模范式,为高效大模型提供新路径。
160 7
CALM自编码器:用连续向量替代离散token,生成效率提升4倍
|
3月前
|
人工智能 安全 API
FastMCP 入门:用 Python 快速搭建 MCP 服务器接入 LLM
MCP协议为大语言模型连接外部工具与数据提供标准化方案,FastMCP是其Python最佳实践框架。本文详解MCP核心概念,演示如何用FastMCP快速搭建支持工具调用、资源访问与身份认证的MCP服务器,并集成至LLM应用,实现AI智能体与真实世界的高效交互。
1517 2
FastMCP 入门:用 Python 快速搭建 MCP 服务器接入 LLM