PyTorch & MMCV Dispatcher 机制解析

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 假设一个团队有一个项目经理和三个程序员,甲方正在疯狂地提各种需求,然后项目经理要做的就是根据每位程序员的专长,将不同的需求分配给不同的程序员来做,但是项目经理自己不会去实现需求,此时我们可以说,项目经理就是一个 Dispatcher。

最受欢迎的 MMCV 教程系列如约而至!

话不多说,进入正题~


1. 什么是 Dispatcher



假设一个团队有一个项目经理和三个程序员,甲方正在疯狂地提各种需求,然后项目经理要做的就是根据每位程序员的专长,将不同的需求分配给不同的程序员来做,但是项目经理自己不会去实现需求,此时我们可以说,项目经理就是一个 Dispatcher


Dispatcher,中文译为分派器,它的工作可以简单理解为只负责任务的分发,而不会参与到任务本身中来。但思考一下,如果这个团队的规模扩大,现在有三十位程序员,那这位项目经理可能就不得不加班工作了。这说明上述的 Dispatcher 方案不够优雅,那么有没有更优雅一点的方案呢?我们会在第三节注册 + 分发部分给大家介绍一种思路。


PyTorch 和 MMCV 的 Dispatcher 同样只负责任务的分发,他们将高层 API 分发到合适的底层实现。举个例子,相信很多人都会有疑问,在 PyTorch 中,当我们写下 torch.add(inputs) 时,PyTorch 是怎样让该函数根据 inputs 所在设备、数据布局、数据类型找到其实际应当执行的 kernel 的呢?我们会在第四节和第五节给大家简单解析 PyTorch 和 MMCV 的 Dispatcher 机制,并让大家了解 Dispatcher 在上述过程中发挥的重要作用。

2.  为什么需要 Dispatcher



从上面的描述来看,Dispathcer 只是一个美化的 if 语句:根据 inputs 的一些信息,决定应该调用哪一段代码,那么我们为什么需要 Dispatcher 呢?答案显而易见,是为了维持代码的可维护性和优雅性。


想想如果没有 Dispatcher,对于一个简单的 torch.add,我们就需要针对 CPU、GPU、TPU、FPGA 等不同设备,普通张量、稀疏张量(主要用于解决 3D 视觉、图神经网络等领域数据的稀疏性,以减少内存和计算量浪费 )等不同张量布局,float32、float16、int8 等不同数据类型,写下一大堆 if else,这实在是不可接受的。每一个 operator 我们都这样编写的话,将不可避免地造成代码的混乱、冗余、难以 Debug。因此,我们需要一个 Dispatcher,让它来统一管理分派工作。

3. 注册 + 分发



PyTorch 和 MMCV Dispatcher 都采用的是一种非常常见的架构设计方法——注册 + 分发,或者从设计模式的角度来讲可以叫注册器模式 + 工厂模式。使用这种模式可以实现强大的解耦性和扩展性,实际上 MMCV 的 Register 机制也可以看作这种模式的应用 MMCV 核心组件分析(五): Registry。笔者以为,除了解耦和扩展,这种模式还有一个重要的优点在于责任的下放。


《MMCV 核心组件分析:Registry 》全文可见 OpenMMLab 知乎:

https://zhuanlan.zhihu.com/p/355271993


我们举一个例子来解释什么是注册 + 分发以及什么是责任的下放:


已知复数有两种表示形式,一种是实部和虚部表示(rectangular),一种是模和幅角表示(polar)。那么如果我们要实现一个系统,同时支持这两种表示形式的加法运算,我们应该怎么做呢?我们会先把接口(add)和实现(rectangular-add、polar-add)分离,并给数据打上 tag,最终接口代码会是这样:

# 接口
add(z1, z2):
    if (z1.tag != z2.tag) 
        raise ValueError
    if(z1.tag == "rectangular"):
        return rectangular-add(z1, z2)
    else if(z1.tag == "polar"):
        return polar-add(z1, z2)
    else:
        raise ValueError

这里的 add 接口就像是一个 manager,他负责任务的分发。此时如果有一个人叫小明,他说我有另一种复数表示形式叫 ming 表示法,那么我们就需要在 add 里再加一个判断语句,那么这个 manager 的责任就会越来越重。


而采用注册 + 分发的机制,实际上会存在一张支持 register 和 get 操作的虚拟表格,像是这样:

640.png

每次有新的方法时,通过 register 注册到表里,然后我们的接口通过 get 取到对应的实现函数,此时接口就不再是 manager 了,它只需要从表中取出对应的函数并执行,而向表中注册的事情,交由编写具体的新方法的人来负责,最终接口代码会是这样:

add(z1, z2):
    if (z1.tag != z2.tag) 
        raise ValueError
    method = get(z1.tag)
    return method(z1, z2)

上述的例子来自讲座 Lec4b:通用运算符,讲座的前半段重点讲述类似注册 + 分发的机制,后半段则更进一步,将这种机制与递归结合,展示了在复杂系统中的强大威力,推荐大家观看。


4. PyTorch Dispatcher



设备、布局和类型


PyTorch 的核心数据结构是 Tensor,Tensor 可以简单看做由一些元数据描述的 n 维数据结构。其中最核心的三个元数据是 device、layout 和 dtype。


- device:表明 Tensor 的物理内存实际存储,CPU 或 GPU 或 TPU 或其他


- layout:决定如何在逻辑上解释这块 Tensor 的数据所在的这块连续的物理内存,普通张量或稀疏张量或其他


- dtype:表明数据的类型,float32 或者 float16 或者 int8 或者其他


Tensor 的数据本身和这三个元数据一起决定了这个张量在用户眼中到底是什么,而这三个元数据的笛卡尔积的大小也就决定了一个 torch operator 对应的 kernel 数量。


算子注册


PyTorch 的算子注册是通过一个二维表,竖轴是 PyTorch 中支持的每个运算符。横轴则是 PyTorch 支持的每个分派键(除了 opHandle),可以理解为不同硬件平台 + 不同张量布局 + 其他,一个分派键对应着一类 Tensor。算子注册就是填充单元格。


当我们在单个分派键上为单个运算符注册 kernel 时,我们填充单个单元格,比如:

TORCH_LIBRARY_IMPL(aten, CPU, m) {
  m.impl("mul", cpu_mul);
}

640.png

当我们在单个分派键上为所有运算符注册 kernel 时,我们填充该分派键的列,比如:

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
  m.fallback(makeFallthrough());
}

640.png

其中,在单个分派键上为单个运算符注册的 kernel 具有更高的优先级。


PyTorch Dispatcher 实现


以下代码部分均基于 PyTorch 1.5,不过有些代码是编译时生成的,之后的 PyTorch 版本可能修改了许多源码,但核心思想应当是不变的。


当调用 torch.add 时,会发生两次分派。第一次分派针对 Tensor 的设备类型和布局,例如,它是 CPU Tensor 还是 CUDA Tensor,它是 Strided Tensor 还是 Sparse Tensor;第二次分派则是针对 Tensor 的类型。其中第二次分派只是用类似 AT_DISPATCH_ALL_TYPES 这样的宏包装了一个 switch 语句来完成分派,MMCV 中同样借用了 PyTorch 的这个轮子,比较简单,在此不再赘述。我们重点关注第一次分派。


当执行 torch.add() 时,通过 pybind11 (连接 Python 和 C++ 的桥梁) 来到 THPVariable_add 函数;然后经过多次跳转来到 at::add。

static inline Tensor add(const Tensor & self, const Tensor & other, Scalar alpha) {
    ...
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::add", "Tensor");
    return op.callUnboxed<Tensor, const Tensor &, const Tensor &, Scalar>(self, other, alpha);
}

在这里,我们通过 findSchemaOrThrow 得到了和 at::add  相对应的 OperatorHandle (即 3.3 节表中的 opHandle)。


每个 operator 和 OperatorHandle 是一一对应的,在每个 OperatorHandle 中都包含一张 "KF" 表,这张表相当于 3.3 节表中的一行,存储着分派键和对应的算子。然后我们会调用 callUnboxed,并跳转到 Dispatcher::callUnboxed。

inline Return Dispatcher::callUnboxedWithDispatchKey(const OperatorHandle& op, DispatchKey dispatchKey, Args... args) const {
    ...
    const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
    const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey);
    return kernel.template callUnboxed<Return, Args...>(op, std::forward<Args>(args)...);
}
inline Return Dispatcher::callUnboxed(const OperatorHandle& op, Args... args) const {
    ...
    const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
    auto dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed<Args...>(backendsWithoutFallthrough_, args...);
    return callUnboxedWithDispatchKey<Return, Args...>(op, dispatchKey, args...);
}

在 callUnboxed 中,我们通过  "KF" 表获取到对应的分派键,随后在 callUnboxedWithDispatchKey  中通过分派键索引到了相应的 KernelFunction ,之后 KernelFunction 内部会调用算子函数。


5. MMCV Dispatcher



算子扩展技术选择



MMCV 作为基于 PyTorch 和 Parrots(商汤自研框架)的框架,可以利用 PyTorch 提供的扩展方式,实现许多高质量的算子。而 PyTorch 主要提供了三种底层算子扩展的方式,native_functions 方式、Extension 方式和 OP Register 方式,那么为什么 MMCV 选择了 Extension 方式呢?下面我们分别分析一下 Pytorch 的三种扩展方式就能得到答案了。


native_functions


我们只需要按照 PyTorch 官方方式组织新算子的 kernel 实现,然后在 native_functions.yaml 和 derivatives.yaml 添加配置信息,就可以自动生成:torch.xxx()、torch.nn.functional.xxx()、tensor.xxx() 方法。


看起来很方便,但是该方法与 PyTorch 的耦合度过高,实际上是修改 PyTorch 源码,所以每次增加或者修改算子都需要重新编译 PyTorch,成本较高。而且 MMCV 的定位是基于 PyTorch,而不是 PyTorch 的魔改版,所以不采用该方案。


OP register


该方式与 PyTorch 源码解耦,增加和修改算子不需要重新编译 PyTorch 源码。我们只需要先编写新算子的 kernel 实现,然后通过 PyTorch 底层的注册接口torch::RegisterOperators 将该算子注册即可。如下所示:

namespace {
Tensor my_kernel_cpu(const Tensor& a, const Tensor& b) {...}
Tensor my_kernel_cuda(const Tensor& a, const Tensor& b) {...}
}
static auto registry = torch::RegisterOperators()
   .op("my_namespace::my_op",  torch::RegisterOperators::options()
       .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(CPU()))
   .op("my_namespace::my_op",  torch::RegisterOperators::options()
       .kernel<decltype(my_kernel_cuda), &my_kernel_cuda>(CUDA()));

但是该方法也有一个问题,就是如果要增加新硬件平台对应的算子,那么需要首先在 PyTorch 源码中增加对新硬件的支持,之后才能借助torch::RegisterOperators 实现算子注册。MMCV 正积极与各硬件厂商合作,已支持曙光 DCU,也正在与寒武纪 MLU 开展合作,被动地等待 Pytorch 官方加入这些新硬件显然不是一个好的策略,因此不采用该方案。


Extension


该方案同样与 PyTorch 源码解耦,增加和修改算子不需要重新编译 PyTorch 源码。它的原理是通过 pybind11,将 C++(CUDA) 编译为 PyTorch 的一个模块,更多内容可见 揭秘 C++/CUDA 算子实现和调用全流程。同时硬件厂商也可以通过和 MMCV 合作,提供自家的 Extension(如寒武纪的 MLUExtension)模块,较为简单地实现在 MMCV 中支持新硬件。


MMCV Dispatcher 实现


为什么要讲 MMCV 算子扩展技术选择呢?主要还是为了让大家理解 MMCV 的特点——基于 PyTorch 和 Parrots、提供算子扩展、支持多后端设备。MMCV 的特点决定了我们会尽量利用 PyTorch 的设计思想和"轮子",但不会使用 PyTorch 那样复杂的 Dispatcher 机制,力求简单、灵活。MMCV 和 PyTorch 一样会发生两次分派,我们也是主要关注第一次分派。


MMCV Dispathcer 的核心就是 pytorch_device_registry.hpp 这个文件,核心内容有三部分:


第一部分是 DeviceRegistry 类,主要包括一个单例模式的方法 instance()、实现设备和算子函数绑定的 Register()、根据设备查找对应算子的Find()。

// Registry
template <typename F, F f>
class DeviceRegistry;
template <typename Ret, typename... Args, Ret (*f)(Args...)>
class DeviceRegistry<Ret (*)(Args...), f> {
 public:
  using FunctionType = Ret (*)(Args...);
  static const int MAX_DEVICE_TYPES =
      int8_t(at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
  void Register(at::DeviceType device, FunctionType function) {
    funcs_[int8_t(device)] = function;
  }
  FunctionType Find(at::DeviceType device) const {
    return funcs_[int8_t(device)];
  }
  static DeviceRegistry& instance() {
    static DeviceRegistry inst;
    return inst;
  }
 private:
  DeviceRegistry() {
    for (size_t i = 0; i < MAX_DEVICE_TYPES; ++i) {
      funcs_[i] = nullptr;
    }
  };
  FunctionType funcs_[MAX_DEVICE_TYPES];
};

第二部分是 Dispatch 函数,主要逻辑为首先获取第一个 Tensor 的设备,然后检查全部 Tensor 的设备一致性,之后根据设备找到对应的函数(指针),最后执行函数,中间会通过 TORCH_CHECK 做检查工作:

// dispatch
template <typename R, typename... Args>
auto Dispatch(const R& registry, const char* name, Args&&... args) {
  auto device = GetFirstTensorDevice(std::forward<Args>(args)...);
  auto inconsist =
      CheckDeviceConsistency(device, 0, std::forward<Args>(args)...);
  TORCH_CHECK(inconsist.first >= int(sizeof...(Args)), name, ": at param ",
              inconsist.first,
              ", inconsistent device: ", GetDeviceStr(inconsist.second).c_str(),
              " vs ", GetDeviceStr(device).c_str(), "\n")
  auto f_ptr = registry.Find(device.type());
  TORCH_CHECK(f_ptr != nullptr, name, ": implementation for device ",
              GetDeviceStr(device).c_str(), " not found.\n")
  return f_ptr(std::forward<Args>(args)...);
}

第三部分主要是一些宏,根据前两部分的代码就能轻松理解这一部分,在此不再赘述。

// helper macro
#define DEVICE_REGISTRY(key) DeviceRegistry<decltype(&(key)), key>::instance()
#define REGISTER_DEVICE_IMPL(key, device, value)           \
  struct key##_##device##_registerer {                     \
    key##_##device##_registerer() {                        \
      DEVICE_REGISTRY(key).Register(at::k##device, value); \
    }                                                      \
  };                                                       \
  static key##_##device##_registerer _##key##_##device##_registerer;
#define DISPATCH_DEVICE_IMPL(key, ...) \
  Dispatch(DEVICE_REGISTRY(key), #key, __VA_ARGS__)

通过使用 Dispatcher,MMCV 重构了 CSRC 目录的代码,详情见 Refactor csrc with device dispatcher

重构前:

void roi_align_forward(Tensor input, Tensor rois, Tensor output,
                       Tensor argmax_y, Tensor argmax_x, int aligned_height,
                       int aligned_width, float spatial_scale,
                       int sampling_ratio, int pool_mode, bool aligned) {
  if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(input);
    CHECK_CUDA_INPUT(rois);
    CHECK_CUDA_INPUT(output);
    CHECK_CUDA_INPUT(argmax_y);
    CHECK_CUDA_INPUT(argmax_x);
    roi_align_forward_cuda(input, rois, output, argmax_y, argmax_x,
                           aligned_height, aligned_width, spatial_scale,
                           sampling_ratio, pool_mode, aligned);
#else
    AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
  } else {
    CHECK_CPU_INPUT(input);
    CHECK_CPU_INPUT(rois);
    CHECK_CPU_INPUT(output);
    CHECK_CPU_INPUT(argmax_y);
    CHECK_CPU_INPUT(argmax_x);
    roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
                          aligned_height, aligned_width, spatial_scale,
                          sampling_ratio, pool_mode, aligned);
  }
}

重构后:

// 注册算子的cuda实现
REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
// roi_align.cpp
// 使用dispatcher根据参数中的Tensor device类型对实现进行分发
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y,
                       argmax_x, aligned_height, aligned_width, spatial_scale,
                       sampling_ratio, pool_mode, aligned);
}

文章来源:公众号【OpenMMLab】

2021-12-30 19:58



目录
相关文章
|
14天前
|
传感器 C# Android开发
深度解析Uno Platform中的事件处理机制与交互设计艺术:从理论到实践的全方位指南,助您构建响应迅速、交互流畅的跨平台应用
Uno Platform 是一款开源框架,支持使用 C# 和 XAML 开发跨平台原生 UI 应用,兼容 Windows、iOS、Android 及 WebAssembly。本文将介绍 Uno Platform 中高效的事件处理方法,并通过示例代码展示交互设计的核心原则与实践技巧,帮助提升应用的用户体验。事件处理让应用能响应用户输入,如点击、触摸及传感器数据变化。通过 XAML 或 C# 添加事件处理器,可确保及时反馈用户操作。示例代码展示了一个按钮点击事件处理过程。此外,还可运用动画和过渡效果进一步增强应用交互性。
127 57
|
5天前
|
移动开发 Android开发 数据安全/隐私保护
移动应用与系统的技术演进:从开发到操作系统的全景解析随着智能手机和平板电脑的普及,移动应用(App)已成为人们日常生活中不可或缺的一部分。无论是社交、娱乐、购物还是办公,移动应用都扮演着重要的角色。而支撑这些应用运行的,正是功能强大且复杂的移动操作系统。本文将深入探讨移动应用的开发过程及其背后的操作系统机制,揭示这一领域的技术演进。
本文旨在提供关于移动应用与系统技术的全面概述,涵盖移动应用的开发生命周期、主要移动操作系统的特点以及它们之间的竞争关系。我们将探讨如何高效地开发移动应用,并分析iOS和Android两大主流操作系统的技术优势与局限。同时,本文还将讨论跨平台解决方案的兴起及其对移动开发领域的影响。通过这篇技术性文章,读者将获得对移动应用开发及操作系统深层理解的钥匙。
|
5天前
|
存储 关系型数据库 MySQL
深入解析MySQL数据存储机制:从表结构到物理存储
深入解析MySQL数据存储机制:从表结构到物理存储
14 1
|
9天前
|
Java 开发者
Java中的异常处理机制深度解析
在Java编程中,异常处理是保证程序稳定性和健壮性的重要手段。本文将深入探讨Java的异常处理机制,包括异常的分类、捕获与处理、自定义异常以及一些最佳实践。通过详细讲解和代码示例,帮助读者更好地理解和应用这一机制,提升代码质量。
12 1
|
15天前
|
存储 缓存 Android开发
Android RecyclerView 缓存机制深度解析与面试题
本文首发于公众号“AntDream”,详细解析了 `RecyclerView` 的缓存机制,包括多级缓存的原理与流程,并提供了常见面试题及答案。通过本文,你将深入了解 `RecyclerView` 的高性能秘诀,提升列表和网格的开发技能。
39 8
|
18天前
|
Java 程序员 开发者
Java中的异常处理机制深度解析
本文旨在深入探讨Java中异常处理的核心概念与实际应用,通过剖析异常的本质、分类、捕获及处理方法,揭示其在程序设计中的关键作用。不同于常规摘要,本文将直接切入主题,以简明扼要的方式概述异常处理的重要性及其在Java编程中的应用策略,引导读者快速把握异常处理的精髓。
|
17天前
|
安全 Java 开发者
Java并发编程中的锁机制解析
本文深入探讨了Java中用于管理多线程同步的关键工具——锁机制。通过分析synchronized关键字和ReentrantLock类等核心概念,揭示了它们在构建线程安全应用中的重要性。同时,文章还讨论了锁机制的高级特性,如公平性、类锁和对象锁的区别,以及锁的优化技术如锁粗化和锁消除。此外,指出了在高并发环境下锁竞争可能导致的问题,并提出了减少锁持有时间和使用无锁编程等策略来优化性能的建议。最后,强调了理解和正确使用Java锁机制对于开发高效、可靠并发应用程序的重要性。
16 3
|
21天前
|
Java 开发者
深入解析Java中的异常处理机制
本文将深入探讨Java中异常处理的核心概念和实际应用,包括异常的分类、捕获、处理以及最佳实践。我们将通过具体示例展示如何有效使用try-catch块、throws关键字和自定义异常类,以帮助读者更好地理解和应用Java异常处理机制。
12 1
|
21天前
|
Java 程序员 开发者
Java中的异常处理机制深度解析
本文旨在深入探讨Java中异常处理的机制,包括异常的分类、如何捕获和处理异常,以及自定义异常的最佳实践。通过实例讲解,帮助读者更好地理解如何在Java编程中有效管理和利用异常处理来提高代码的健壮性和可维护性。
|
2月前
|
机器学习/深度学习 人工智能 PyTorch
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
PyTorch提供了几种张量乘法的方法,每种方法都是不同的,并且有不同的应用。我们来详细介绍每个方法,并且详细解释这些函数有什么区别:
44 4
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析

热门文章

最新文章

下一篇
无影云桌面