JAX 中文文档(三)(4)

简介: JAX 中文文档(三)

JAX 中文文档(三)(3)https://developer.aliyun.com/article/1559704


checkify转换

原文:jax.readthedocs.io/en/latest/debugging/checkify_guide.html

TL;DR checkify允许您向您的 JAX 代码添加可jit的运行时错误检查(例如越界索引)。使用checkify.checkify转换与类似断言的checkify.check函数一起向 JAX 代码添加运行时检查:

from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
  y = x[i]
  z = jnp.sin(y)
  return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f)) 

您还可以使用checkify来自动添加常见的检查:

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
err, z = checked_f(jnp.array([5, 1]), 0)
err.throw()  # if no error occurred, throw does nothing! 

功能化检查

与 assert 类似的检查 API 本身不是函数纯粹的:它可以作为副作用引发 Python 异常,就像 assert 一样。因此,它不能与jitpmappjitscan分阶段执行:

jax.jit(f)(jnp.ones((5,)), -1)  # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized. 

但是checkify转换功能化(或卸载)这些效果。一个经过checkify转换的函数将错误作为新输出返回,并保持函数纯粹。这种功能化意味着checkify转换的函数可以与我们喜欢的任何分阶段/转换进行组合:

err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
..  at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
..  at mapped index 2: out-of-bounds indexing at <..>:7 (f)
""" 

JAX 为什么需要checkify

在某些 JAX 转换下,您可以使用普通的 Python 断言表达运行时错误检查,例如仅使用jax.gradjax.numpy时。

def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)
jax.grad(f)(0.)
# ValueError: "must be positive!" 

但是普通的断言在jitpmappjitscan中不起作用。在这些情况下,数值计算是在 Python 执行期间被分阶段地进行评估,因此数值值不可用:

jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..." 

在组合多个转换时,JAX 转换语义依赖于函数纯度,因此我们如何在不干扰所有这些的情况下提供一个错误机制?除了需要一个新的 API  之外,情况还更加棘手:XLA HLO 不支持断言或抛出错误,因此即使我们有一个能够分阶段断言的 JAX API,我们如何将这些断言降低到 XLA  呢?

您可以想象手动向函数添加运行时检查并通过值来传递表示错误:

def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result
err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: "must be positive!" 

错误是由函数计算出的常规值,并且错误是在f_checked外部引发的。f_checked是函数式纯粹的,因此我们知道通过构造,它已经可以与jitpmappjitscan以及所有 JAX 的转换一起工作。唯一的问题是这些管道可能会很麻烦!

checkify为您完成了这个重写工作:包括通过函数传递错误值、将检查重写为布尔操作并将结果与跟踪的错误值合并,并将最终错误值作为检查函数的输出返回:

def f(x):
  checkify.check(x > 0., "{} must be positive!", x)  # convenient but effectful API
  return jnp.log(x)
f_checked = checkify(f)
err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1\. must be positive! (check failed at <...>:2 (f)) 

我们称这个过程为功能化或者通过调用检查引入的效果。 (在上面的“手动”示例中,错误值只是一个布尔值。checkify的错误值在概念上类似,但还跟踪错误消息并公开抛出和获取方法;参见jax.experimental.checkify)。checkify.check还允许您通过将其作为格式参数提供给错误消息来将运行时值添加到您的错误消息中。

您现在可以手动为您的代码添加运行时检查,但 checkify 也可以自动添加常见错误的检查!考虑这些错误情况:

jnp.arange(3)[5]                # out of bounds
jnp.sin(jnp.inf)                # NaN generated
jnp.ones((5,)) / jnp.arange(5)  # division by zero 

默认情况下,checkify 仅释放 checkify.check,不会捕获类似上述的错误。但如果您要求,checkify 也会自动在您的代码中添加检查。

def f(x, i):
  y = x[i]        # i could be out of bounds.
  z = jnp.sin(y)  # z could become NaN
  return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f) 

基于 Sets 的 API,用于选择要启用的自动检查。详见 jax.experimental.checkify 获取更多详情。

在 JAX 变换下的 checkify

如上例所示,checkified 函数可以愉快地进行 jitted 处理。以下是 checkify 与其他 JAX 变换的几个示例。请注意,checkified 函数在功能上是纯粹的,并且应与所有 JAX 变换轻松组合!

jit

您可以安全地向 checkified 函数添加 jax.jit,或者 checkify 一个 jitted 函数,两者都可以正常工作。

def f(x, i):
  return x[i]
checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ =  checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f) 

vmap/pmap

您可以 vmappmap checkified 函数(或 checkify 映射函数)。映射一个 checkified 函数将为您提供一个映射的错误,该错误可以包含映射维度的每个元素的不同错误。

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
 at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
 at mapped index 2: out-of-bounds indexing at <...>:3 (f)
""" 

然而,checkify-of-vmap 将产生单个(未映射)的错误!

@jax.vmap
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f)) 

pjit

对于 checkified 函数的 pjit 可以正常工作,您只需为错误值输出的 out_axis_resources 指定额外的 None

def f(x):
  return x / x
f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
  f,
  in_shardings=PartitionSpec('x', None),
  out_shardings=(None, PartitionSpec('x', None)))
with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
 err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f) 

grad

如果您使用 checkify-of-grad,还将对您的梯度计算进行检查:

def f(x):
 return x / (1 + jnp.sqrt(x))
grad_f = jax.grad(f)
err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f) 

请注意,f 中没有乘法,但在其梯度计算中有乘法(这就是生成 NaN 的地方!)。因此,请使用 checkify-of-grad 为前向和后向传递操作添加自动检查。

checkify.check 仅应用于函数的主值。如果您想在梯度值上使用 check,请使用 custom_vjp

@jax.custom_vjp
def assert_gradient_negative(x):
 return x
def fwd(x):
 return assert_gradient_negative(x), None
def bwd(_, grad):
 checkify.check(grad < 0, "gradient needs to be negative!")
 return (grad,)
assert_gradient_negative.defvjp(fwd, bwd)
jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative! 

jax.experimental.checkify 的优势和限制

优势

  • 您可以在任何地方使用它(错误只是“值”,并在像其他值一样的转换下直观地表现)。
  • 自动插装:您无需对代码进行本地修改。相反,checkify 可以为其所有部分添加插装!

限制

  • 添加大量运行时检查可能很昂贵(例如,对每个原语添加 NaN 检查将增加计算中的许多操作)。
  • 需要将错误值从函数中线程化并手动抛出错误。如果未显式抛出错误,则可能会错过错误!
  • 抛出一个错误值将在主机上实现该错误值,这意味着它是一个阻塞操作,这会打败 JAX 的异步先行运行。


JAX 中文文档(三)(5)https://developer.aliyun.com/article/1559706

相关文章
|
数据库连接 数据库 Python
SQLAlchemy映射表结构和对数据的CRUD
SQLAlchemy映射表结构和对数据的CRUD
|
关系型数据库 MySQL C语言
gcc版本过低导致charconv: No such file or directory
gcc版本过低导致charconv: No such file or directory
2171 0
|
5月前
|
机器学习/深度学习 数据采集 人工智能
指令微调是什么:让大模型听懂人话的关键技术
指令微调(Instruction Tuning)是提升大模型“听懂人话”能力的关键技术:通过高质量指令-响应对训练,使模型从“会说话”进阶为“懂意图、会回应”,显著增强零样本泛化、任务适应与安全性,已成为大模型落地的必备环节。
|
人工智能 JavaScript 前端开发
从零开始,国内实现调用Open Ai
从零开始,国内实现调用Open Ai
2165 0
|
Linux Shell Python
-bash: pip: command not found pip命令报错 解决方法(Centos版)
-bash: pip: command not found pip命令报错 解决方法(Centos版)
4905 0
|
5月前
|
存储 物联网 数据中心
拒绝玄学炼丹:大模型微调显存需求精确计算指南,全参数微调与LoRA对比全解析
本文揭秘大模型微调显存消耗的本质,系统拆解模型权重、梯度、优化器状态、激活值四大组成部分的计算逻辑,推导可复用的显存估算公式;对比全量微调、LoRA、QLoRA等方案的显存需求,提供实用工具与配置建议,助开发者告别“玄学估算”,精准规划GPU资源。
|
自然语言处理 数据采集 运维
高质量行业大模型数据集构建的实战路径
一文讲透高质量行业大模型数据集从预训练、指令微调到合成数据的全流程实战构建路径。
|
4月前
|
人工智能 JavaScript Linux
立即吃上AI龙虾!OpenClaw 1分钟阿里云上/本地部署+MiniMax/Claude/百炼免费模型配置解析
本文详细介绍了2026年OpenClaw在阿里云轻量服务器与本地多系统的部署流程,以及MiniMax、Claude、阿里云百炼三大免费模型的配置方法,针对新手常见问题提供了避坑指南与解决方案。通过阿里云部署可实现7×24小时稳定运行,本地部署则兼顾隐私与灵活性,用户可根据自身需求选择合适方案。后续可进一步探索OpenClaw的技能市场、任务调度等高级功能,打造专属AI助理,提升工作效率。
1215 0
|
6月前
|
人工智能 缓存 物联网
从0到1:大模型算力配置不需要人,保姆级选卡与显存计算手册
本文深入解析大模型算力三阶段:训练、微调与推理,类比为“教育成长”过程,详解各阶段技术原理与GPU选型策略,涵盖显存计算、主流加速技术(如LoRA/QLoRA)、性能评估方法及未来趋势,助力开发者高效构建AI模型。
1246 2
|
6月前
|
存储 SQL 供应链
什么是PDF_417编码?
PDF417是一种高容量二维条码,支持文本、数字、字节等多种数据类型,具备强纠错能力与良好兼容性,广泛应用于政务、交通、医疗、物流等领域。通过HCreateLabelView软件可轻松生成与打印,适用于航空票务等场景,实现信息离线存储与高效识别。
388 3