PyTorch & MMCV Dispatcher 机制解析

本文涉及的产品
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
简介: 假设一个团队有一个项目经理和三个程序员,甲方正在疯狂地提各种需求,然后项目经理要做的就是根据每位程序员的专长,将不同的需求分配给不同的程序员来做,但是项目经理自己不会去实现需求,此时我们可以说,项目经理就是一个 Dispatcher。

最受欢迎的 MMCV 教程系列如约而至!

话不多说,进入正题~


1. 什么是 Dispatcher



假设一个团队有一个项目经理和三个程序员,甲方正在疯狂地提各种需求,然后项目经理要做的就是根据每位程序员的专长,将不同的需求分配给不同的程序员来做,但是项目经理自己不会去实现需求,此时我们可以说,项目经理就是一个 Dispatcher


Dispatcher,中文译为分派器,它的工作可以简单理解为只负责任务的分发,而不会参与到任务本身中来。但思考一下,如果这个团队的规模扩大,现在有三十位程序员,那这位项目经理可能就不得不加班工作了。这说明上述的 Dispatcher 方案不够优雅,那么有没有更优雅一点的方案呢?我们会在第三节注册 + 分发部分给大家介绍一种思路。


PyTorch 和 MMCV 的 Dispatcher 同样只负责任务的分发,他们将高层 API 分发到合适的底层实现。举个例子,相信很多人都会有疑问,在 PyTorch 中,当我们写下 torch.add(inputs) 时,PyTorch 是怎样让该函数根据 inputs 所在设备、数据布局、数据类型找到其实际应当执行的 kernel 的呢?我们会在第四节和第五节给大家简单解析 PyTorch 和 MMCV 的 Dispatcher 机制,并让大家了解 Dispatcher 在上述过程中发挥的重要作用。

2.  为什么需要 Dispatcher



从上面的描述来看,Dispathcer 只是一个美化的 if 语句:根据 inputs 的一些信息,决定应该调用哪一段代码,那么我们为什么需要 Dispatcher 呢?答案显而易见,是为了维持代码的可维护性和优雅性。


想想如果没有 Dispatcher,对于一个简单的 torch.add,我们就需要针对 CPU、GPU、TPU、FPGA 等不同设备,普通张量、稀疏张量(主要用于解决 3D 视觉、图神经网络等领域数据的稀疏性,以减少内存和计算量浪费 )等不同张量布局,float32、float16、int8 等不同数据类型,写下一大堆 if else,这实在是不可接受的。每一个 operator 我们都这样编写的话,将不可避免地造成代码的混乱、冗余、难以 Debug。因此,我们需要一个 Dispatcher,让它来统一管理分派工作。

3. 注册 + 分发



PyTorch 和 MMCV Dispatcher 都采用的是一种非常常见的架构设计方法——注册 + 分发,或者从设计模式的角度来讲可以叫注册器模式 + 工厂模式。使用这种模式可以实现强大的解耦性和扩展性,实际上 MMCV 的 Register 机制也可以看作这种模式的应用 MMCV 核心组件分析(五): Registry。笔者以为,除了解耦和扩展,这种模式还有一个重要的优点在于责任的下放。


《MMCV 核心组件分析:Registry 》全文可见 OpenMMLab 知乎:

https://zhuanlan.zhihu.com/p/355271993


我们举一个例子来解释什么是注册 + 分发以及什么是责任的下放:


已知复数有两种表示形式,一种是实部和虚部表示(rectangular),一种是模和幅角表示(polar)。那么如果我们要实现一个系统,同时支持这两种表示形式的加法运算,我们应该怎么做呢?我们会先把接口(add)和实现(rectangular-add、polar-add)分离,并给数据打上 tag,最终接口代码会是这样:

# 接口
add(z1, z2):
    if (z1.tag != z2.tag) 
        raise ValueError
    if(z1.tag == "rectangular"):
        return rectangular-add(z1, z2)
    else if(z1.tag == "polar"):
        return polar-add(z1, z2)
    else:
        raise ValueError

这里的 add 接口就像是一个 manager,他负责任务的分发。此时如果有一个人叫小明,他说我有另一种复数表示形式叫 ming 表示法,那么我们就需要在 add 里再加一个判断语句,那么这个 manager 的责任就会越来越重。


而采用注册 + 分发的机制,实际上会存在一张支持 register 和 get 操作的虚拟表格,像是这样:

640.png

每次有新的方法时,通过 register 注册到表里,然后我们的接口通过 get 取到对应的实现函数,此时接口就不再是 manager 了,它只需要从表中取出对应的函数并执行,而向表中注册的事情,交由编写具体的新方法的人来负责,最终接口代码会是这样:

add(z1, z2):
    if (z1.tag != z2.tag) 
        raise ValueError
    method = get(z1.tag)
    return method(z1, z2)

上述的例子来自讲座 Lec4b:通用运算符,讲座的前半段重点讲述类似注册 + 分发的机制,后半段则更进一步,将这种机制与递归结合,展示了在复杂系统中的强大威力,推荐大家观看。


4. PyTorch Dispatcher



设备、布局和类型


PyTorch 的核心数据结构是 Tensor,Tensor 可以简单看做由一些元数据描述的 n 维数据结构。其中最核心的三个元数据是 device、layout 和 dtype。


- device:表明 Tensor 的物理内存实际存储,CPU 或 GPU 或 TPU 或其他


- layout:决定如何在逻辑上解释这块 Tensor 的数据所在的这块连续的物理内存,普通张量或稀疏张量或其他


- dtype:表明数据的类型,float32 或者 float16 或者 int8 或者其他


Tensor 的数据本身和这三个元数据一起决定了这个张量在用户眼中到底是什么,而这三个元数据的笛卡尔积的大小也就决定了一个 torch operator 对应的 kernel 数量。


算子注册


PyTorch 的算子注册是通过一个二维表,竖轴是 PyTorch 中支持的每个运算符。横轴则是 PyTorch 支持的每个分派键(除了 opHandle),可以理解为不同硬件平台 + 不同张量布局 + 其他,一个分派键对应着一类 Tensor。算子注册就是填充单元格。


当我们在单个分派键上为单个运算符注册 kernel 时,我们填充单个单元格,比如:

TORCH_LIBRARY_IMPL(aten, CPU, m) {
  m.impl("mul", cpu_mul);
}

640.png

当我们在单个分派键上为所有运算符注册 kernel 时,我们填充该分派键的列,比如:

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
  m.fallback(makeFallthrough());
}

640.png

其中,在单个分派键上为单个运算符注册的 kernel 具有更高的优先级。


PyTorch Dispatcher 实现


以下代码部分均基于 PyTorch 1.5,不过有些代码是编译时生成的,之后的 PyTorch 版本可能修改了许多源码,但核心思想应当是不变的。


当调用 torch.add 时,会发生两次分派。第一次分派针对 Tensor 的设备类型和布局,例如,它是 CPU Tensor 还是 CUDA Tensor,它是 Strided Tensor 还是 Sparse Tensor;第二次分派则是针对 Tensor 的类型。其中第二次分派只是用类似 AT_DISPATCH_ALL_TYPES 这样的宏包装了一个 switch 语句来完成分派,MMCV 中同样借用了 PyTorch 的这个轮子,比较简单,在此不再赘述。我们重点关注第一次分派。


当执行 torch.add() 时,通过 pybind11 (连接 Python 和 C++ 的桥梁) 来到 THPVariable_add 函数;然后经过多次跳转来到 at::add。

static inline Tensor add(const Tensor & self, const Tensor & other, Scalar alpha) {
    ...
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::add", "Tensor");
    return op.callUnboxed<Tensor, const Tensor &, const Tensor &, Scalar>(self, other, alpha);
}

在这里,我们通过 findSchemaOrThrow 得到了和 at::add  相对应的 OperatorHandle (即 3.3 节表中的 opHandle)。


每个 operator 和 OperatorHandle 是一一对应的,在每个 OperatorHandle 中都包含一张 "KF" 表,这张表相当于 3.3 节表中的一行,存储着分派键和对应的算子。然后我们会调用 callUnboxed,并跳转到 Dispatcher::callUnboxed。

inline Return Dispatcher::callUnboxedWithDispatchKey(const OperatorHandle& op, DispatchKey dispatchKey, Args... args) const {
    ...
    const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
    const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey);
    return kernel.template callUnboxed<Return, Args...>(op, std::forward<Args>(args)...);
}
inline Return Dispatcher::callUnboxed(const OperatorHandle& op, Args... args) const {
    ...
    const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
    auto dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed<Args...>(backendsWithoutFallthrough_, args...);
    return callUnboxedWithDispatchKey<Return, Args...>(op, dispatchKey, args...);
}

在 callUnboxed 中,我们通过  "KF" 表获取到对应的分派键,随后在 callUnboxedWithDispatchKey  中通过分派键索引到了相应的 KernelFunction ,之后 KernelFunction 内部会调用算子函数。


5. MMCV Dispatcher



算子扩展技术选择



MMCV 作为基于 PyTorch 和 Parrots(商汤自研框架)的框架,可以利用 PyTorch 提供的扩展方式,实现许多高质量的算子。而 PyTorch 主要提供了三种底层算子扩展的方式,native_functions 方式、Extension 方式和 OP Register 方式,那么为什么 MMCV 选择了 Extension 方式呢?下面我们分别分析一下 Pytorch 的三种扩展方式就能得到答案了。


native_functions


我们只需要按照 PyTorch 官方方式组织新算子的 kernel 实现,然后在 native_functions.yaml 和 derivatives.yaml 添加配置信息,就可以自动生成:torch.xxx()、torch.nn.functional.xxx()、tensor.xxx() 方法。


看起来很方便,但是该方法与 PyTorch 的耦合度过高,实际上是修改 PyTorch 源码,所以每次增加或者修改算子都需要重新编译 PyTorch,成本较高。而且 MMCV 的定位是基于 PyTorch,而不是 PyTorch 的魔改版,所以不采用该方案。


OP register


该方式与 PyTorch 源码解耦,增加和修改算子不需要重新编译 PyTorch 源码。我们只需要先编写新算子的 kernel 实现,然后通过 PyTorch 底层的注册接口torch::RegisterOperators 将该算子注册即可。如下所示:

namespace {
Tensor my_kernel_cpu(const Tensor& a, const Tensor& b) {...}
Tensor my_kernel_cuda(const Tensor& a, const Tensor& b) {...}
}
static auto registry = torch::RegisterOperators()
   .op("my_namespace::my_op",  torch::RegisterOperators::options()
       .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(CPU()))
   .op("my_namespace::my_op",  torch::RegisterOperators::options()
       .kernel<decltype(my_kernel_cuda), &my_kernel_cuda>(CUDA()));

但是该方法也有一个问题,就是如果要增加新硬件平台对应的算子,那么需要首先在 PyTorch 源码中增加对新硬件的支持,之后才能借助torch::RegisterOperators 实现算子注册。MMCV 正积极与各硬件厂商合作,已支持曙光 DCU,也正在与寒武纪 MLU 开展合作,被动地等待 Pytorch 官方加入这些新硬件显然不是一个好的策略,因此不采用该方案。


Extension


该方案同样与 PyTorch 源码解耦,增加和修改算子不需要重新编译 PyTorch 源码。它的原理是通过 pybind11,将 C++(CUDA) 编译为 PyTorch 的一个模块,更多内容可见 揭秘 C++/CUDA 算子实现和调用全流程。同时硬件厂商也可以通过和 MMCV 合作,提供自家的 Extension(如寒武纪的 MLUExtension)模块,较为简单地实现在 MMCV 中支持新硬件。


MMCV Dispatcher 实现


为什么要讲 MMCV 算子扩展技术选择呢?主要还是为了让大家理解 MMCV 的特点——基于 PyTorch 和 Parrots、提供算子扩展、支持多后端设备。MMCV 的特点决定了我们会尽量利用 PyTorch 的设计思想和"轮子",但不会使用 PyTorch 那样复杂的 Dispatcher 机制,力求简单、灵活。MMCV 和 PyTorch 一样会发生两次分派,我们也是主要关注第一次分派。


MMCV Dispathcer 的核心就是 pytorch_device_registry.hpp 这个文件,核心内容有三部分:


第一部分是 DeviceRegistry 类,主要包括一个单例模式的方法 instance()、实现设备和算子函数绑定的 Register()、根据设备查找对应算子的Find()。

// Registry
template <typename F, F f>
class DeviceRegistry;
template <typename Ret, typename... Args, Ret (*f)(Args...)>
class DeviceRegistry<Ret (*)(Args...), f> {
 public:
  using FunctionType = Ret (*)(Args...);
  static const int MAX_DEVICE_TYPES =
      int8_t(at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
  void Register(at::DeviceType device, FunctionType function) {
    funcs_[int8_t(device)] = function;
  }
  FunctionType Find(at::DeviceType device) const {
    return funcs_[int8_t(device)];
  }
  static DeviceRegistry& instance() {
    static DeviceRegistry inst;
    return inst;
  }
 private:
  DeviceRegistry() {
    for (size_t i = 0; i < MAX_DEVICE_TYPES; ++i) {
      funcs_[i] = nullptr;
    }
  };
  FunctionType funcs_[MAX_DEVICE_TYPES];
};

第二部分是 Dispatch 函数,主要逻辑为首先获取第一个 Tensor 的设备,然后检查全部 Tensor 的设备一致性,之后根据设备找到对应的函数(指针),最后执行函数,中间会通过 TORCH_CHECK 做检查工作:

// dispatch
template <typename R, typename... Args>
auto Dispatch(const R& registry, const char* name, Args&&... args) {
  auto device = GetFirstTensorDevice(std::forward<Args>(args)...);
  auto inconsist =
      CheckDeviceConsistency(device, 0, std::forward<Args>(args)...);
  TORCH_CHECK(inconsist.first >= int(sizeof...(Args)), name, ": at param ",
              inconsist.first,
              ", inconsistent device: ", GetDeviceStr(inconsist.second).c_str(),
              " vs ", GetDeviceStr(device).c_str(), "\n")
  auto f_ptr = registry.Find(device.type());
  TORCH_CHECK(f_ptr != nullptr, name, ": implementation for device ",
              GetDeviceStr(device).c_str(), " not found.\n")
  return f_ptr(std::forward<Args>(args)...);
}

第三部分主要是一些宏,根据前两部分的代码就能轻松理解这一部分,在此不再赘述。

// helper macro
#define DEVICE_REGISTRY(key) DeviceRegistry<decltype(&(key)), key>::instance()
#define REGISTER_DEVICE_IMPL(key, device, value)           \
  struct key##_##device##_registerer {                     \
    key##_##device##_registerer() {                        \
      DEVICE_REGISTRY(key).Register(at::k##device, value); \
    }                                                      \
  };                                                       \
  static key##_##device##_registerer _##key##_##device##_registerer;
#define DISPATCH_DEVICE_IMPL(key, ...) \
  Dispatch(DEVICE_REGISTRY(key), #key, __VA_ARGS__)

通过使用 Dispatcher,MMCV 重构了 CSRC 目录的代码,详情见 Refactor csrc with device dispatcher

重构前:

void roi_align_forward(Tensor input, Tensor rois, Tensor output,
                       Tensor argmax_y, Tensor argmax_x, int aligned_height,
                       int aligned_width, float spatial_scale,
                       int sampling_ratio, int pool_mode, bool aligned) {
  if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(input);
    CHECK_CUDA_INPUT(rois);
    CHECK_CUDA_INPUT(output);
    CHECK_CUDA_INPUT(argmax_y);
    CHECK_CUDA_INPUT(argmax_x);
    roi_align_forward_cuda(input, rois, output, argmax_y, argmax_x,
                           aligned_height, aligned_width, spatial_scale,
                           sampling_ratio, pool_mode, aligned);
#else
    AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
  } else {
    CHECK_CPU_INPUT(input);
    CHECK_CPU_INPUT(rois);
    CHECK_CPU_INPUT(output);
    CHECK_CPU_INPUT(argmax_y);
    CHECK_CPU_INPUT(argmax_x);
    roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
                          aligned_height, aligned_width, spatial_scale,
                          sampling_ratio, pool_mode, aligned);
  }
}

重构后:

// 注册算子的cuda实现
REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
// roi_align.cpp
// 使用dispatcher根据参数中的Tensor device类型对实现进行分发
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y,
                       argmax_x, aligned_height, aligned_width, spatial_scale,
                       sampling_ratio, pool_mode, aligned);
}

文章来源:公众号【OpenMMLab】

2021-12-30 19:58



目录
相关文章
|
2月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
103 2
|
3月前
|
存储 缓存 算法
分布式锁服务深度解析:以Apache Flink的Checkpointing机制为例
【10月更文挑战第7天】在分布式系统中,多个进程或节点可能需要同时访问和操作共享资源。为了确保数据的一致性和系统的稳定性,我们需要一种机制来协调这些进程或节点的访问,避免并发冲突和竞态条件。分布式锁服务正是为此而生的一种解决方案。它通过在网络环境中实现锁机制,确保同一时间只有一个进程或节点能够访问和操作共享资源。
129 3
|
3天前
|
机器学习/深度学习 自然语言处理 搜索推荐
自注意力机制全解析:从原理到计算细节,一文尽览!
自注意力机制(Self-Attention)最早可追溯至20世纪70年代的神经网络研究,但直到2017年Google Brain团队提出Transformer架构后才广泛应用于深度学习。它通过计算序列内部元素间的相关性,捕捉复杂依赖关系,并支持并行化训练,显著提升了处理长文本和序列数据的能力。相比传统的RNN、LSTM和GRU,自注意力机制在自然语言处理(NLP)、计算机视觉、语音识别及推荐系统等领域展现出卓越性能。其核心步骤包括生成查询(Q)、键(K)和值(V)向量,计算缩放点积注意力得分,应用Softmax归一化,以及加权求和生成输出。自注意力机制提高了模型的表达能力,带来了更精准的服务。
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
65 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
1月前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
1月前
|
PHP 开发者 UED
PHP中的异常处理机制解析####
本文深入探讨了PHP中的异常处理机制,通过实例解析try-catch语句的用法,并对比传统错误处理方式,揭示其在提升代码健壮性与可维护性方面的优势。文章还简要介绍了自定义异常类的创建及其应用场景,为开发者提供实用的技术参考。 ####
|
2月前
|
存储 缓存 监控
后端开发中的缓存机制:深度解析与最佳实践####
本文深入探讨了后端开发中不可或缺的一环——缓存机制,旨在为读者提供一份详尽的指南,涵盖缓存的基本原理、常见类型(如内存缓存、磁盘缓存、分布式缓存等)、主流技术选型(Redis、Memcached、Ehcache等),以及在实际项目中如何根据业务需求设计并实施高效的缓存策略。不同于常规摘要的概述性质,本摘要直接点明文章将围绕“深度解析”与“最佳实践”两大核心展开,既适合初学者构建基础认知框架,也为有经验的开发者提供优化建议与实战技巧。 ####
|
2月前
|
缓存 NoSQL Java
千万级电商线上无阻塞双buffer缓冲优化ID生成机制深度解析
【11月更文挑战第30天】在千万级电商系统中,ID生成机制是核心基础设施之一。一个高效、可靠的ID生成系统对于保障系统的稳定性和性能至关重要。本文将深入探讨一种在千万级电商线上广泛应用的ID生成机制——无阻塞双buffer缓冲优化方案。本文从概述、功能点、背景、业务点、底层原理等多个维度进行解析,并通过Java语言实现多个示例,指出各自实践的优缺点。希望给需要的同学提供一些参考。
55 7
|
1月前
|
Java 数据库连接 开发者
Java中的异常处理机制:深入解析与最佳实践####
本文旨在为Java开发者提供一份关于异常处理机制的全面指南,从基础概念到高级技巧,涵盖try-catch结构、自定义异常、异常链分析以及最佳实践策略。不同于传统的摘要概述,本文将以一个实际项目案例为线索,逐步揭示如何高效地管理运行时错误,提升代码的健壮性和可维护性。通过对比常见误区与优化方案,读者将获得编写更加健壮Java应用程序的实用知识。 --- ####
|
2月前
|
Java 开发者 Spring
深入解析:Spring AOP的底层实现机制
在现代软件开发中,Spring框架的AOP(面向切面编程)功能因其能够有效分离横切关注点(如日志记录、事务管理等)而备受青睐。本文将深入探讨Spring AOP的底层原理,揭示其如何通过动态代理技术实现方法的增强。
85 8

推荐镜像

更多