如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门(4)

简介: 如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门

Block 级优化


Block 在 GPU 上基本等同于不同的 kernel 在 GPU 上运行了,所以它们之间的联系并不是特别强烈。而它们之间的相互关系在 GEMM 语境下基本就只有 wave 和 L2 cache(一个 wave 里的 Block 共享这一块 cache)了,良好的 Block Tiling 能提升相当可观的 L2 cache 命中率。


但这一部分属于 sgemm 并不是特别关心的部分,因为本身 FFMA 单元算的就不是很快,所以 Block Tiling 随便搞搞就能够满足 FFMA 单元的带宽和延迟需求了。因此,这一节的内容主要是为了有些有用到 tensor core 的同学提一个需要注意的性能提升点,其次就是有些同学可能会发现自己写的 kernel 可能会比本文中的示例慢一点(大约 10% 左右),因此在此提一下在 sgemm 中应该怎么随便搞搞 Block Tiling。


Wave & L2 cache Hit Rate


首先明确一下 wave 的概念,即一个 GPU 上能够同时运行的 Block 数量。关于 GPU 是如何决定一个 wave 由哪些 Block 组成的我并没有找到非常明确的文档说明,但我一拍脑袋想,说不定就是朴素的按顺序决定的,即 index 处于 范围内的 Block 处在第一个 wave 中,后面的 Block 依此类推。后面试了试好像的确是这样划分的。


在明确了 wave 的概念后,我们便可以对 L2 cache 命中率做一个简单的分析了。我们指定 代表一个 wave 同时运行的 Block 数量,假设一个 wave 刚好能计算 C 矩阵的整数行,那么我们不难发现对于一个 wave 而言,它需要从 Global Memory 中读取 个 float。但由于有 L2 cache 的存在,假设一个 wave 读取的数据全能被 L2 cache 装下,那么实际只读取了 数据。最终 L2 cache 的命中率为:



差距越大,L2 cache 的命中就越低。那么如果想要去优化 L2 cache 命中,一个比较直接的想法就是尽可能把一个 wave 的 Block 变成方的。但就算不搞,sgemm 也不在乎,因为其实对性能来讲并没有什么区别,所以就没搞。



SGEMM Block Tiling


而在 sgemm 的语境下,假设最坏的情况即一个 wave 都不能覆盖目标矩阵 C 的一行,且 RTX 2080 有 46 个 SM,一个 SM 能跑两个 Block,此时




带入上式可得,此时 L2 cache 命中率大概是 49.4%。这里我们并没有考虑访问 C 矩阵的影响,在实践中会把 L2 cache 的命中率拉低一点。但即便是如此,前文我们分析过只要 L2 cache 命中达到 20%,在带宽上就不会造成性能瓶颈了。因此发现,就算我们采用朴素的 Block Tiling,Global Memory 访问也不会成为访存瓶颈。


但事实真的是这样吗?


细心的同学可能会发现,上图所采用的 tiling 方式并不是直觉上的用 blockIdx.x 表示 Block 在 M 维上的位置,而是用 blockIdx.y 表示 Block 在 M 上的位置。而我们简单调换一下代码中 blockIdx.x 与 blockIdx.y 的顺序,瞬间就有了 10% 左右的性能差距。目前网上并没有针对这个现象的分析(因为几乎所有人都是用的 col major 的 data layout,而且李少侠直接就在代码里使用了更优的 tiling 方式,所以没有人遇到这个问题),因此我这里提出一点个人的猜想,如果猜的不对还请指正。



L2 cache


首先我们看一下这两种 tiling 方式的区别在哪。最为直观的区别就是当 N 或 M 足够大时,采用上图中的 tiling 方式的 wave 形状是横着的,而另一种 tiling 方式的 wave 形状是竖着的,而这种竖着的形状看起来就不是 cache 友好的访存方式。


为什么这么说呢?因为我采用的是行主序的方式存储的矩阵,因此如果一个 wave 的形状是扁平的,那么每个 Block 在每一次循环遍历 B 矩阵时只会有 次 cache miss。这是由于 L2 cache 的 cache line 大小为 128 bytes,因此当数据从 Global Memory 中移动到 L2 cache 后,许多 Block 就能直接从 L2 cache 中读取数据了。然而如果一个 wave 的形状是狭长的。那么每个 Block 在第一次访问 A 矩阵的每一行时都会有 cache miss 的情况出现,即产生 次 cache miss,而后 31 次的遍历都不会有 cache miss。虽然两种 tiling 方式最终 cache miss 的次数是一样的,但这种短时间爆发的 cache miss 所带来的延迟是非常难被各种优化手段覆盖的,因为这种延迟不仅短时间内有很多次,同时每一次的延迟都很长,所以会造成性能损失。因此以后高性能代码的开发中,也要注意合理的把 cache miss 分配到 kernel 运行的各个阶段。


Bank Conflict


在查看两种 Tiling 方式的 profile 我发现,采用横着 Tiling 方式的 kernel bank conflict 更低一些。




等等,既然我们之前已经处理过 bank conflict 了,那么为什么这里还会有 bank conflict 呢?这个现象其实我也不是很清楚。但目前已知的是,在没有加 double buffer 情况下是没有 bank conflict 的,但加了 double buffer 之后或多或少会出现一些 bank conflict。那么至于为什么横着 Tiling 方式的 bank conflict 更低,我就更不知道了,因此这里还请各位 dalao 赐教。


最终版本的代码在这:https://github.com/AyakaGEMM/Hands-on-GEMM/blob/main/src/cuda/double_buffer_yhs_refine_gemm.cu


Epilogue


回顾本文,也基本达成了文章开头所立的各种 flag。当然现在还是有很多问题没有解决的,如 split K、长尾问题、分块细调等等,这些权当是一些未来展望了。近期也在尝试写一下 int8 tensor core 的矩阵乘,在较小形状上(M、N、K<=2048)能有比 cublas 更高的性能,但在更大形状上就只有 80% 左右了(这还是 L2 cache 命中率为 90% 的结果,可能还有啥别的没做好),所以就没有写 int8 tensor core 的部分。不过好歹是写完了!

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习基础之Tensor对象及其应用的讲解及实战(附源码 简单易懂 包括分段 映射 矩阵乘法 随机数等等)
PyTorch深度学习基础之Tensor对象及其应用的讲解及实战(附源码 简单易懂 包括分段 映射 矩阵乘法 随机数等等)
34 1
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习基础之Tensor的索引和切片讲解及实战(附源码 简单易懂)
PyTorch深度学习基础之Tensor的索引和切片讲解及实战(附源码 简单易懂)
79 0
|
4月前
|
存储 Windows
R 语言数值实验中常见技巧整理
R 语言数值实验中常见技巧整理
57 0
R 语言数值实验中常见技巧整理
|
11月前
|
机器学习/深度学习 缓存 自然语言处理
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门
|
11月前
|
存储 并行计算 异构计算
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门(3)
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门
132 0
|
11月前
|
存储 缓存 并行计算
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门(2)
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门
134 0
|
11月前
|
存储 编解码 边缘计算
【CUDA学习笔记】第十一篇:FPS1000+的背景减法之目标跟踪(附实践源码下载)(一)
【CUDA学习笔记】第十一篇:FPS1000+的背景减法之目标跟踪(附实践源码下载)(一)
151 0
|
11月前
|
存储 并行计算 算法
【CUDA学习笔记】第十一篇:FPS1000+的背景减法之目标跟踪(附实践源码下载)(二)
【CUDA学习笔记】第十一篇:FPS1000+的背景减法之目标跟踪(附实践源码下载)(二)
92 0
|
11月前
|
机器学习/深度学习 存储 人工智能
【Pytorch神经网络理论篇】 28 DGLGraph图的基本操作(缺一部分 明天补)
DGLGraph图按照边的方向将度分为两种:连接其他顶点的度(out)和被其他顶点连接的度。
296 0
|
人工智能 缓存 移动开发
通用矩阵乘算法从入门到实践
通用矩阵乘算法从入门到实践
260 0