飞桨x昇腾生态适配方案:06_算子适配举例

简介: 本节详细解析了Paddle-API与CANN-Kernel之间的差异及适配策略,涵盖三种主要场景:参数缺失或不对应、数据类型不匹配以及layout转换。针对不同问题提出具体解决方案,如通过默认赋值或计算补充参数、使用`Cast`操作转换数据类型、借助`Transpose`调整数据布局等。同时,以ReluGrad和nll_loss算子为例,深入说明参数对齐、数据类型转换及转置操作的实现流程,为开发者提供清晰的适配指导。

本节介绍aclnn算子的三种适配场景。

Paddle-API 与 CANN-Kernel 差异剖析及适配策略

对于Paddle-API与CANN-Kernel两者中常见的差别与适配方法如下:

Paddle参数缺失或者参数无法直接对应

  • 如果Paddle算子只需要CANN提供的某个参数为默认值的功能,则可通过默认赋值的方式完成
  • 考虑通过计算取得需要参数

    CANN参数缺失

  • CANN算子没有某个Paddle有的参数,一般是此算子CANN支持的模式少于Paddle
  • 可通过多个算子分别完成算子的部分功能(如max_pool + avg_pool)
  • 如果CANN只能支持部分功能,则可以在调用处抛出参数值判断异常

    数据类型不支持

  • 输入数据类型不匹配时需要在计算前插入 Cast 操作,并且要更改输出的数据类型,在计算后对输出数据进行 Cast 操作返回原数据类型

    layout转换

  • NPU算子基本不支持NHWC,但是部分Paddle算子支持,如果遇到这样的情况需要在计算前后插入 Transpose

    小算子拼接

  • 部分Paddle-API的功能在NPU中没法直接完成,但可通过多个小算子拼接完成,一般会少许影响性能

    加入缺少的参数

    以ReluGrad算子为例,通过计算或者默认赋值方式加入缺少的参数:
    01_加入缺少的参数.png

在进行参数对齐时,需要检查是否存在需要默认参数的情况。以 Paddle 的 relugrad 算子为例,其对应的 aclnn 的 ThresholdBackward 算子包含额外参数 threshold。
在实际操作中,可通过默认赋值的方式实现参数对齐,如图所示,代码为phi::Scalar threshold = 0.0。完成参数对齐后,即可直接调用 NPU 的 aclnnThresholdBackward 算子 。

数据类型转换

以 nll_loss 算子为例,Paddle API 与 CANN API 所支持的数据类型存在差异。Paddle API 中,输入 x 的数据类型为 double,而 CANN 的对应算子仅支持 float32 这一特定数据类型,具体情况如下图所示:
Paddle侧:
02_Paddle侧数据类型.png

CAAN侧:
03_CANN侧数据类型.png

此情形需进行数据类型转换:

  • 首先,对输入数据执行 cast 操作,将其转换为 CANN 算子支持的数据类型;
  • 完成转换后,执行 NPU 的 aclnn 算子;
  • 算子运算结束后,再将计算结果的数据类型由 float32 转换回输入 x 原本的数据类型。
    具体流程如下图所示:
    04_数据类型转换流程.png

数据类型转换需要Cast_kernel算子,下面为Cast_kernel算子的声明:
05_Cast_kernel算子声明.png

以下介绍将变量 x 的数据类型从 double 转换为 float32(转换后的变量记为 x_cast)的流程:

对象声明

phi::DenseTensor x_cast;                // 声明目标张量 x_cast
phi::DenseTensorMeta x_cast_meta;       // 声明张量的元数据对象

phi::DenseTensor:深度学习框架中表示多维数组的核心数据结构,包含数据和元信息(如形状、数据类型)。
phi::DenseTensorMeta:用于存储张量的元信息(metadata)。

元数据初始化

x_cast_meta = {phi::DataType::FLOAT32, x.dims();

phi::DataType::FLOAT32:明确将目标张量的数据类型设为 float32。
x.dims():继承输入张量 x 的维度信息(如 [batch_size, channels, height, width])

绑定元数据

x_cast.set_meta(x_cast_meta);

作用:将初始化后的元数据绑定到目标张量 x_cast

执行类型转换

custom_kernel::CastKernel<T, Context>(dev_ctx, x, phi::DataType::FLOAT32, &x_cast);

核心参数:
dev_ctx: 设备上下文(如 CPU/GPU 资源管理)
x: 输入张量
phi::DataType::FLOAT32: 目标数据类型
&x_cast: 输出的目标张量指针
功能:将输入张量 x 的数据类型转换为 float32,结果写入 x_cast

执行NPU aclnn算子计算

把所有需要进行数据类型转换的参数转换完成后,使用算子执行宏EXEC_NPU_CMD执行aclnn算子:
06_EXEC_NPU_CMD宏.png

aclnn算子输出结果原类型恢复

07_aclnn算子输出结果原类型恢复.png

  • 将 out_cast(NPU计算结果)转换为paddle api的 out 张量类型(如 float32 → double)。
  • 将损失计算中累计的权重值 total_weight_cast 数据类型转换为目标类型后写入 total_weight

    转置操作

    在 Pool2dGradKernel 算子中,若输入数据格式data_format为NHWC,即高度、宽度、通道数位于最后,鉴于 NPU 的操作要求,需将数据转换为NCHW格式。此转换通过Transpose操作达成。
    Transpose操作的核心功能是实现张量维度重排,旨在适配 NPU 计算特性所规定的数据布局需求。在本场景中,其主要作用是完成从NHWC(Channel Last)到NCHW(Channel First)这两种内存布局的转换 。
    08_Pool2dGradKernel算子.png

流程图如下:
09_转置操作流程.png

变量声明

phi::DenseTensor transformed_out_grad;
通过临时变量隔离布局转换过程,保证原始数据的完整性

布局判断逻辑

接下来,程序执行条件判断if (channel_last)。该判断旨在检查输入数据是否采用NHWC格式。
if (channel_last)条件成立,即NHWC(Channel Last)格式时,程序将执行转置操作,把数据格式转换为通道优先Channel First的NCHW格式。

维度置换规则

std::vector<int> perm = {0, 3, 1, 2};
数学原理:对应张量维度[N,H,W,C]->[N,C,H,W]
定义了一个perm向量{0, 3, 1, 2},这应该是用来重新排列维度的顺序。原来的维度假设是NHWC(0,1,2,3),转置后变为NCHW(0,3,1,2)。

新形状构建

接着构造了out_grad_tensor_shape,调整形状以匹配新的维度顺序。调整后的形状是通过重新排列out_grad的维度得到的,例如将原维度[0]、[3]、[1]、[2]组合成新的形状。

std::vector<int> out_grad_tensor_shape = {
    out_grad.dims()[0],
    out_grad.dims()[3], 
    out_grad.dims()[1],
    out_grad.dims()[2],
};

通过维度复制而非引用保证形状独立性。

内存分配

然后将transformed_out_grad调整大小,分配内存,并通过TransposeKernel进行转置操作。

transformed_out_grad.Resize(phi::make_ddim(out_grad_tensor_shape));
dev_ctx.template Alloc<T>(&transformed_out_grad);

转置运算

custom_kernel::TransposeKernel<T, Context>(dev_ctx, out_grad, perm, &transformed_out_grad);
custom_kernel::TransposeKernel是调用NPU的转置内核函数,需在算子开发代码中进行函数声明,如图所示:
10_TransposeKernel声明.png

目录
相关文章
|
12月前
|
物联网
物联网卡:物联网卡停机多久会被注销
物联网卡(IoT SIM卡)的停机与注销政策通常取决于具体的服务提供商(如电信运营商、物联网平台提供商等)以及用户与这些服务提供商之间签订的合同条款。因此,没有一个统一的、适用于所有情况的规则来规定物联网卡停机多久后会被注销。 然而,一般来说,物联网卡的停机与注销可能遵循以下一些常见的逻辑或规定:
|
11月前
|
机器学习/深度学习 数据采集 传感器
使用Python实现深度学习模型:智能土壤质量监测与管理
使用Python实现深度学习模型:智能土壤质量监测与管理
623 69
|
12月前
|
运维 安全 Linux
IDC服务器故障排除思路
本文详细介绍了服务器维修流程,包括维修前的工具和备件准备,以及不拆机情况下的初步检查步骤。文中还提供了拆机维修的具体方法,如最小化测试法、替换法和交叉比较法,并针对CPU、主板、内存、硬盘、电源、风扇、网卡及BMC等主要配件的故障排除进行了说明,强调了注意事项,旨在帮助技术人员快速准确地定位并解决问题。
476 13
|
数据处理 UED
Axure中继器教程及案例详解
Axure RP 是一款强大的原型设计工具,广泛应用于产品设计、UI/UX 设计及交互设计中。中继器(Repeater)作为 Axure 中的一个重要元件,以其强大的数据处理和动态交互能力,成为设计师们不可或缺的工具。本文将从中继器基础、进阶、高级应用,以及分页控制、合计、列表拖动、列表滑动删除、表内修改等方面,详细介绍中继器的使用方法和案例。
306 6
Axure中继器教程及案例详解
|
12月前
|
Kubernetes 应用服务中间件 调度
k8s的Pod常见的几种调度形式
k8s的Pod常见的几种调度形式
179 0
|
缓存 负载均衡 架构师
优化大型数据处理系统的性能:从设计到实施
在数据驱动的世界中,大型数据处理系统的性能对企业运营至关重要。本文将探讨如何通过优化设计、选择合适的技术栈以及实施高效的策略来提升数据处理系统的性能。我们将深入分析数据库设计优化、并发处理、数据缓存策略、和数据流管理等关键领域,提供实际案例和技术建议,以帮助开发人员和系统架构师构建高效、可扩展的数据处理系统。
|
Kubernetes 并行计算 数据挖掘
构建高可用的数据分析平台:Dask 集群管理与部署
【8月更文第29天】随着数据量的不断增长,传统的单机数据分析方法已无法满足大规模数据处理的需求。Dask 是一个灵活的并行计算库,它能够帮助开发者轻松地在多核 CPU 或分布式集群上运行 Python 代码。本文将详细介绍如何搭建和管理 Dask 集群,以确保数据分析流程的稳定性和可靠性。
1139 3
|
数据建模 程序员
程序员必知:ZVS振荡电路工作原理分析
程序员必知:ZVS振荡电路工作原理分析
236 1
|
缓存 监控 安全
云服务器公网流量异常排查指南
云服务器公网流量异常排查指南
922 1
|
机器学习/深度学习 存储 人工智能
多模态系统的技术挑战
【1月更文挑战第18天】多模态系统的技术挑战
386 1
多模态系统的技术挑战