JAX 中文文档(九)(4)

简介: JAX 中文文档(九)

JAX 中文文档(九)(3)https://developer.aliyun.com/article/1559675


在 JAX 中编写自定义 Jaxpr 解释器

原文:jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html

[外链图片转存中…(img-VdY1f7HH-1718950445125)]

JAX 提供了几个可组合的函数转换(jitgradvmap等),可以编写简洁且加速的代码。

这里我们展示了如何通过编写自定义 Jaxpr 解释器来向系统添加自己的函数转换。而且我们将自动获得与所有其他转换的可组合性。

此示例使用了内部 JAX API,可能随时会中断。任何不在API 文档中的内容都应视为内部内容。

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random 

JAX 在做什么?

JAX 为数值计算提供了类似 NumPy 的 API,可以直接使用,但 JAX 真正的强大之处在于可组合的函数转换。例如jit函数转换接受一个函数并返回一个语义上相同的函数,但由 XLA 进行惰性编译以加速器。

x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
  return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f) 

当我们调用fast_f时,会发生什么?JAX 会追踪函数并构建一个 XLA 计算图。然后将图进行即时编译(JIT)并执行。其他转换类似,它们首先会追踪函数并以某种方式处理输出追踪。要了解更多关于 JAX 追踪机制的信息,您可以参考 README 中的“How it works”部分。

Jaxpr 追踪器

Jax 中一个特别重要的追踪器是 Jaxpr 追踪器,它将操作记录到一个 Jaxpr(Jax 表达式)中。Jaxpr 是一种数据结构,可以像小型函数式编程语言一样进行评估,因此 Jaxprs 是函数转换的有用中间表示。

要首次查看 Jaxprs,可以考虑make_jaxpr转换。make_jaxpr本质上是一个“漂亮打印”转换:它将一个函数转换为一个函数,给定示例参数,生成其计算的 Jaxpr 表示。make_jaxpr对于调试和内省非常有用。让我们使用它来查看一些示例 Jaxprs 的结构。

def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)
def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))
print()
def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10))) 
foo
=====
invars: [Var(id=140117887103104):int32[]]
outvars: [Var(id=140117887103296):int32[]]
constvars: []
equation: [Var(id=140117887103104):int32[], 1] add [Var(id=140117887103296):int32[]] {}
jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
bar
=====
invars: [Var(id=140117843771968):float32[5,10], Var(id=140117843772032):float32[5], Var(id=140117843772096):float32[10]]
outvars: [Var(id=140117843772352):float32[5], Var(id=140117843772096):float32[10]]
constvars: []
equation: [Var(id=140117843771968):float32[5,10], Var(id=140117843772096):float32[10]] dot_general [Var(id=140117843772160):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [Var(id=140117843772160):float32[5], Var(id=140117843772032):float32[5]] add [Var(id=140117843772224):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=140117843772288):float32[5]] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [Var(id=140117843772224):float32[5], Var(id=140117843772288):float32[5]] add [Var(id=140117843772352):float32[5]] {}
jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
    d:f32[5] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a c
    e:f32[5] = add d b
    f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
    g:f32[5] = add e f
  in (g, c) } 
  • jaxpr.invars - Jaxpr 的invars是一个输入变量列表,类似于 Python 函数的参数。
  • jaxpr.outvars - Jaxpr 的outvars是由 Jaxpr 返回的变量。每个 Jaxpr 都有多个输出。
  • jaxpr.constvars - constvars是一个变量列表,它们也是 Jaxpr 的输入之一,但对应于跟踪中的常量(我们稍后会更详细地讨论这些内容)。
  • jaxpr.eqns - 一个方程列表,实质上是 let 绑定。每个方程包含输入变量列表、输出变量列表和一个原语,用于评估输入以生成输出。每个方程还有一个 params,即参数字典。

总的来说,一个 Jaxpr 封装了一个简单的程序,可以使用输入进行评估以生成输出。稍后我们将详细介绍如何做到这一点。现在需要注意的重要事项是,Jaxpr 是一个可以按我们想要的方式操作和评估的数据结构。

Jaxprs 有什么用处?

Jaxprs 是简单的程序表示,易于转换。由于 Jax 允许我们从 Python 函数中分离出 Jaxprs,它为我们提供了一种转换用 Python 编写的数值程序的方法。

您的第一个解释器:invert

让我们尝试实现一个简单的函数“inverter”,它接收原始函数的输出,并返回产生这些输出的输入。现在,让我们专注于由其他可逆的一元函数组成的简单一元函数。

目标:

def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0) 

我们将通过 (1) 将 f 追踪到 Jaxpr 中,然后 (2) 反向解释 Jaxpr 的方式来实现这一点。在反向解释 Jaxpr 过程中,对于每个方程,我们将在表中查找原语的逆,并应用它。

1. 追踪一个函数

让我们使用 make_jaxpr 来追踪一个函数到 Jaxpr 中。

# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps
from jax import core
from jax import lax
from jax._src.util import safe_map 

jax.make_jaxpr 返回一个封闭的 Jaxpr,即一个已经与跟踪中的常量(literals)捆绑在一起的 Jaxpr。

def f(x):
  return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals) 
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[] 

2. 评估 Jaxpr

在编写自定义 Jaxpr 解释器之前,让我们首先实现“默认”解释器 eval_jaxpr,它按原样评估 Jaxpr,计算与未转换的原始 Python 函数相同的值。

为此,我们首先创建一个环境来存储每个变量的值,并在评估 Jaxpr 中的每个方程时更新该环境。

def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}
  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]
  def write(var, val):
    env[var] = val
  # Bind args and consts to environment
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)
  # Loop through equations and evaluate primitives using `bind`
  for eqn in jaxpr.eqns:
    # Read inputs to equation from environment
    invals = safe_map(read, eqn.invars)  
    # `bind` is how a primitive is called
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # Primitives may return multiple outputs or not
    if not eqn.primitive.multiple_results: 
      outvals = [outvals]
    # Write the results of the primitive into the environment
    safe_map(write, eqn.outvars, outvals) 
  # Read the final result of the Jaxpr from the environment
  return safe_map(read, jaxpr.outvars) 
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5)) 
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)] 

注意,即使原始函数不返回平坦列表,eval_jaxpr 也将始终返回一个平坦列表。

此外,这个解释器不处理高阶原语(如 jitpmap),这些内容不在本指南讨论范围内。您可以参考 core.eval_jaxpr (链接) 来查看此解释器不涵盖的边界情况。

自定义inverse Jaxpr 解释器

inverse 解释器看起来与 eval_jaxpr 并无太大不同。我们首先设置注册表,将原语映射到它们的逆。然后编写一个自定义解释器,在注册表中查找原语。

结果表明,这个解释器看起来也类似于反向模式自动微分中使用的“转置”解释器,可以在此处找到:链接

inverse_registry = {} 

现在我们将为一些原语注册它们的逆。按照惯例,Jax 中的原语以 _p 结尾,而其中许多流行的原语位于 lax 中。

inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh 

inverse 将首先跟踪函数,然后自定义解释 Jaxpr。让我们建立一个简单的框架。

def inverse(fun):
  @wraps(fun)
  def wrapped(*args, **kwargs):
    # Since we assume unary functions, we won't worry about flattening and
    # unflattening arguments.
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
    return out[0]
  return wrapped 

现在我们只需要定义 inverse_jaxpr,它将反向遍历 Jaxpr 并在可能时反转原语。

def inverse_jaxpr(jaxpr, consts, *args):
  env = {}
  def read(var):
    if type(var) is core.Literal:
      return var.val
    return env[var]
  def write(var, val):
    env[var] = val
  # Args now correspond to Jaxpr outvars
  safe_map(write, jaxpr.outvars, args)
  safe_map(write, jaxpr.constvars, consts)
  # Looping backward
  for eqn in jaxpr.eqns[::-1]:
    #  outvars are now invars 
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have registered inverse.")
    # Assuming a unary function 
    outval = inverse_registryeqn.primitive
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars) 

就是这样!

def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0) 

重要的是,你可以通过 Jaxpr 解释器进行跟踪。

jax.make_jaxpr(inverse(f))(f(1.)) 
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) } 

这就是向系统添加新转换所需的全部内容,而且你可以免费获得所有其他转换的组合!例如,我们可以在 inverse 中使用 jitvmapgrad

jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.) 
Array([-3.1440797, 15.584931 ,  2.2551253,  1.3155028,  1\.       ],      dtype=float32, weak_type=True) 

读者的练习

  • 处理具有多个参数的原语,其中输入部分已知,例如 lax.add_plax.mul_p
  • 处理 xla_callxla_pmap 原语,这些原语不会与 eval_jaxprinverse_jaxpr 一样正常工作。

使用 C++ 和 CUDA 进行 GPU 自定义操作

原文:jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html

JAX 预装有大量内置操作,但用户偶尔会遇到需要新操作但 JAX 不支持的情况。

为了适应这些情况,JAX 允许用户定义自定义操作,本教程旨在解释如何为 GPU 定义并在单 GPU 和多 GPU 环境中使用它们。

本教程包含来自 使用自定义 C++ 和 CUDA 代码扩展 JAX 的信息,并假设您熟悉 JAX 原语

RMS 标准化

本教程将 RMS 标准化作为 JAX 中的自定义操作添加。请注意,可以直接使用 jax.numpy 表达 RMS 标准化。但是,我们使用它作为示例来展示如何为 GPU 创建自定义操作的过程。此操作在 gpu_ops/rms_norm_kernels.cu 中的 CUDA 代码已从 Apex 借用,并进行了修改,以消除对 PyTorch 的任何依赖。

高级步骤

本教程展示了如何编写自定义操作及其梯度。

在 C 中:每个新的 JAX 原语都需要按照以下步骤进行操作。

  • 具有 CUDA 核心(核心)。
  • 创建分派 CUDA 核心的 C 函数,该函数将由 XLA 调用。
  • 创建描述符以传达计算所需的信息。
  • 类型、形状和其他属性。
  • 将 C 函数绑定到 Python
  • 以创建描述符并在执行期间调用原语。

在 Python 中:您需要按照以下步骤进行操作。

  • 定义新的 JAX 原语(指令/操作)
  • 编写 Python 函数以使用原语构建图节点。
  • 定义其抽象评估。
  • 定义其降低到 MLIR。
  • [可选] 定义梯度。
  • [可选] 使用 custom_partitioningshard_map 函数实现快速多 GPU 支持。

C 代码

参见 gpu_ops 代码列表,其中包含完整的 C++ 和 CUDA 文件代码列表。gpu_ops/rms_norm_kernels.cu 定义了以下函数,这些函数使用给定的 buffers 在指定的 stream 上启动 RMS 标准化核心。

namespace  gpu_ops  {
void  rms_forward_affine_mixed_dtypes(cudaStream_t  stream,  void  **buffers,
  const  char  *opaque,
  std::size_t  opaque_len);
void  rms_backward_affine(cudaStream_t  stream,  void  **buffers,
  const  char  *opaque,  std::size_t  opaque_len);
}  // namespace gpu_ops 
  • stream 是用于在 GPU 上执行任何核心的 CUDA 流。
  • buffers 包含所有指向输入缓冲区的指针,后跟所有指向输出缓冲区的指针。
  • opaque 是传递给自定义函数的任何额外信息的缓冲区,而 opaque_lenopaque 的长度。

在本教程中,我们将通过opaque将一个RMSNormDescriptor对象传递给这些函数。

namespace  gpu_ops  {
enum  ElementType  {  BF16,  F16,  F32,  F64  };
struct  RMSNormDescriptor  {
  int  n1;
  int  n2;
  double  eps;
  ElementType  x_type;
  ElementType  w_type;
  int  part_grad_size;
};
}  // namespace gpu_ops 

现在,我们需要通过pybind11将这些函数以及ElementTypeRMSNormDescriptor作为 Python 模块gpu_ops公开。

pybind11::dict  RMSNormRegistrations()  {
  pybind11::dict  dict;
  dict["rms_forward_affine_mixed_dtype"]  =
  gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes);
  dict["rms_backward_affine"]  =
  gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine);
  return  dict;
}
PYBIND11_MODULE(gpu_ops,  m)  {
  m.def("get_rms_norm_registrations",  &RMSNormRegistrations);
  m.def("create_rms_norm_descriptor",
  [](int  n1,  int  n2,  double  eps,  gpu_ops::ElementType  x_type,
  gpu_ops::ElementType  w_type,  int  part_grad_size)  {
  return  gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
  n1,  n2,  eps,  x_type,  w_type,  part_grad_size});
  });
  pybind11::enum_<gpu_ops::ElementType>(m,  "ElementType")
  .value("BF16",  gpu_ops::ElementType::BF16)
  .value("F16",  gpu_ops::ElementType::F16)
  .value("F32",  gpu_ops::ElementType::F32)
  .value("F64",  gpu_ops::ElementType::F64);
} 

构建gpu_ops扩展模块

我们使用上述代码构建了gpu_ops Python 扩展模块。(请参阅 C++和 CUDA 文件的完整代码清单,查看gpu_ops代码列表。)

python  -m  pip  install  pybind11==2.10.1
mkdir  -p  build
pybind_include_path=$(python  -c  "import pybind11; print(pybind11.get_include())")
python_executable=$(python  -c  'import sys; print(sys.executable)')
nvcc  --threads  4  -Xcompiler  -Wall  -ldl  --expt-relaxed-constexpr  -O3  -DNDEBUG  -Xcompiler  -O3  --generate-code=arch=compute_70,code=[compute_70,sm_70]  --generate-code=arch=compute_75,code=[compute_75,sm_75]  --generate-code=arch=compute_80,code=[compute_80,sm_80]  --generate-code=arch=compute_86,code=[compute_86,sm_86]  -Xcompiler=-fPIC  -Xcompiler=-fvisibility=hidden  -x  cu  -c  gpu_ops/rms_norm_kernels.cu  -o  build/rms_norm_kernels.cu.o
c++  -I/usr/local/cuda/include  -I$pybind_include_path  $(${python_executable}-config  --cflags)  -O3  -DNDEBUG  -O3  -fPIC  -fvisibility=hidden  -flto  -fno-fat-lto-objects  -o  build/gpu_ops.cpp.o  -c  gpu_ops/gpu_ops.cpp
c++  -fPIC  -O3  -DNDEBUG  -O3  -flto  -shared  -o  build/gpu_ops$(${python_executable}-config  --extension-suffix)  build/gpu_ops.cpp.o  build/rms_norm_kernels.cu.o  -L/usr/local/cuda/lib64  -lcudadevrt  -lcudart_static  -lrt  -lpthread  -ldl
strip  build/gpu_ops$(${python_executable}-config  --extension-suffix) 

将 RMS 归一化添加到 JAX 作为自定义调用

gpu_ops只是一个 Python 扩展模块,我们需要更多工作来将其插入到 JAX 中。

创建原语

我们首先创建了原语_rms_norm_fwd_p_rms_norm_bwd_p,这些原语可以映射到自定义函数。我们为这些操作设置了multiple_results属性为True,表示该操作作为元组产生多个输出。当设置为False时,该操作将产生单个输出而不是元组。有关更多详细信息,请参见How JAX primitives work

from functools import partial
import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from build import gpu_ops
from jax import core, dtypes
from jax.interpreters import xla
from jax.lib import xla_client
# Create _rms_norm_fwd_p for forward operation.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output
# Create _rms_norm_bwd_p for backward operation.
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))
def rms_norm_bwd(g, invvar, x, weight, eps):
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight 

降低到 MLIR 自定义调用

为了将自定义函数映射到新原语_rms_norm_fwd_p_rms_norm_bwd_p,我们需要:

  • 使用xla_client.register_custom_call_target注册自定义函数作为自定义调用目标,并且
  • 注册将原语降低为 MLIR 自定义调用的降低函数,并使用注册的自定义调用目标。

下面的函数_rms_norm_fwd_cuda_lowering_rms_norm_bwd_cuda_lowering通过gpu_ops中的自定义目标将原语降低为 MLIR 自定义调用操作。这些函数已经注册到jax.interpreters.mlir.register_lowering中。

注意,在降低函数中创建了一个RMSNormDescriptor对象,并将其作为opaque传递给自定义调用。

from functools import reduce
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call
# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
def element_type_to_descriptor_type_mapping(element_type):
    _element_type_to_descriptor_type_mapping = {
        ir.BF16Type.get(): gpu_ops.ElementType.BF16,
        ir.F16Type.get(): gpu_ops.ElementType.F16,
        ir.F32Type.get(): gpu_ops.ElementType.F32,
        ir.F64Type.get(): gpu_ops.ElementType.F64,
    }
    return _element_type_to_descriptor_type_mapping.get(element_type)
def default_layouts(*shapes):
    return [range(len(shape) - 1, -1, -1) for shape in shapes]
def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_element_type = (
        ir.F32Type.get()
        if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()]
        else x_type.element_type
    )
    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2
    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        0,  # unused
    )
    out = custom_call(
        b"rms_forward_affine_mixed_dtype",
        result_types=[
            ir.RankedTensorType.get(x_shape, w_type.element_type),
            ir.RankedTensorType.get((n1,), iv_element_type),
        ],
        operands=[x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, w_shape),
        result_layouts=default_layouts(x_shape, (n1,)),
    ).results
    return out
mlir.register_lowering(
    _rms_norm_fwd_p,
    _rms_norm_fwd_cuda_lowering,
    platform="gpu",
)
def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_type = ir.RankedTensorType(invvar.type)
    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2
    part_grad_shape = ctx.avals_out[-1].shape
    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        part_grad_shape[0],
    )
    out = custom_call(
        b"rms_backward_affine",
        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
        ],
        operands=[grad_output, invvar, x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
        result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
    ).results
    return out
mlir.register_lowering(
    _rms_norm_bwd_p,
    _rms_norm_bwd_cuda_lowering,
    platform="gpu",
) 


JAX 中文文档(九)(5)https://developer.aliyun.com/article/1559677

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
3天前
|
存储 机器学习/深度学习 编译器
JAX 中文文档(九)(1)
JAX 中文文档(九)
12 0
|
3天前
|
机器学习/深度学习 并行计算 安全
JAX 中文文档(七)(1)
JAX 中文文档(七)
8 0
|
2天前
|
并行计算 Linux 异构计算
JAX 中文文档(一)(1)
JAX 中文文档(一)
8 0
|
3天前
|
存储 Python
JAX 中文文档(十)(3)
JAX 中文文档(十)
7 0
|
2天前
|
机器学习/深度学习 异构计算 Python
JAX 中文文档(四)(3)
JAX 中文文档(四)
8 0
|
2天前
|
编译器 异构计算 Python
JAX 中文文档(四)(2)
JAX 中文文档(四)
7 0
|
3天前
|
并行计算 编译器
JAX 中文文档(六)(4)
JAX 中文文档(六)
6 0
|
3天前
|
机器学习/深度学习 缓存 编译器
JAX 中文文档(二)(1)
JAX 中文文档(二)
8 0
|
2天前
|
缓存 PyTorch API
JAX 中文文档(一)(3)
JAX 中文文档(一)
8 0
|
2天前
|
存储 缓存 索引
JAX 中文文档(五)(3)
JAX 中文文档(五)
7 0