如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门(4)

简介: 如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门

Block 级优化


Block 在 GPU 上基本等同于不同的 kernel 在 GPU 上运行了,所以它们之间的联系并不是特别强烈。而它们之间的相互关系在 GEMM 语境下基本就只有 wave 和 L2 cache(一个 wave 里的 Block 共享这一块 cache)了,良好的 Block Tiling 能提升相当可观的 L2 cache 命中率。


但这一部分属于 sgemm 并不是特别关心的部分,因为本身 FFMA 单元算的就不是很快,所以 Block Tiling 随便搞搞就能够满足 FFMA 单元的带宽和延迟需求了。因此,这一节的内容主要是为了有些有用到 tensor core 的同学提一个需要注意的性能提升点,其次就是有些同学可能会发现自己写的 kernel 可能会比本文中的示例慢一点(大约 10% 左右),因此在此提一下在 sgemm 中应该怎么随便搞搞 Block Tiling。


Wave & L2 cache Hit Rate


首先明确一下 wave 的概念,即一个 GPU 上能够同时运行的 Block 数量。关于 GPU 是如何决定一个 wave 由哪些 Block 组成的我并没有找到非常明确的文档说明,但我一拍脑袋想,说不定就是朴素的按顺序决定的,即 index 处于 范围内的 Block 处在第一个 wave 中,后面的 Block 依此类推。后面试了试好像的确是这样划分的。


在明确了 wave 的概念后,我们便可以对 L2 cache 命中率做一个简单的分析了。我们指定 代表一个 wave 同时运行的 Block 数量,假设一个 wave 刚好能计算 C 矩阵的整数行,那么我们不难发现对于一个 wave 而言,它需要从 Global Memory 中读取 个 float。但由于有 L2 cache 的存在,假设一个 wave 读取的数据全能被 L2 cache 装下,那么实际只读取了 数据。最终 L2 cache 的命中率为:



差距越大,L2 cache 的命中就越低。那么如果想要去优化 L2 cache 命中,一个比较直接的想法就是尽可能把一个 wave 的 Block 变成方的。但就算不搞,sgemm 也不在乎,因为其实对性能来讲并没有什么区别,所以就没搞。



SGEMM Block Tiling


而在 sgemm 的语境下,假设最坏的情况即一个 wave 都不能覆盖目标矩阵 C 的一行,且 RTX 2080 有 46 个 SM,一个 SM 能跑两个 Block,此时




带入上式可得,此时 L2 cache 命中率大概是 49.4%。这里我们并没有考虑访问 C 矩阵的影响,在实践中会把 L2 cache 的命中率拉低一点。但即便是如此,前文我们分析过只要 L2 cache 命中达到 20%,在带宽上就不会造成性能瓶颈了。因此发现,就算我们采用朴素的 Block Tiling,Global Memory 访问也不会成为访存瓶颈。


但事实真的是这样吗?


细心的同学可能会发现,上图所采用的 tiling 方式并不是直觉上的用 blockIdx.x 表示 Block 在 M 维上的位置,而是用 blockIdx.y 表示 Block 在 M 上的位置。而我们简单调换一下代码中 blockIdx.x 与 blockIdx.y 的顺序,瞬间就有了 10% 左右的性能差距。目前网上并没有针对这个现象的分析(因为几乎所有人都是用的 col major 的 data layout,而且李少侠直接就在代码里使用了更优的 tiling 方式,所以没有人遇到这个问题),因此我这里提出一点个人的猜想,如果猜的不对还请指正。



L2 cache


首先我们看一下这两种 tiling 方式的区别在哪。最为直观的区别就是当 N 或 M 足够大时,采用上图中的 tiling 方式的 wave 形状是横着的,而另一种 tiling 方式的 wave 形状是竖着的,而这种竖着的形状看起来就不是 cache 友好的访存方式。


为什么这么说呢?因为我采用的是行主序的方式存储的矩阵,因此如果一个 wave 的形状是扁平的,那么每个 Block 在每一次循环遍历 B 矩阵时只会有 次 cache miss。这是由于 L2 cache 的 cache line 大小为 128 bytes,因此当数据从 Global Memory 中移动到 L2 cache 后,许多 Block 就能直接从 L2 cache 中读取数据了。然而如果一个 wave 的形状是狭长的。那么每个 Block 在第一次访问 A 矩阵的每一行时都会有 cache miss 的情况出现,即产生 次 cache miss,而后 31 次的遍历都不会有 cache miss。虽然两种 tiling 方式最终 cache miss 的次数是一样的,但这种短时间爆发的 cache miss 所带来的延迟是非常难被各种优化手段覆盖的,因为这种延迟不仅短时间内有很多次,同时每一次的延迟都很长,所以会造成性能损失。因此以后高性能代码的开发中,也要注意合理的把 cache miss 分配到 kernel 运行的各个阶段。


Bank Conflict


在查看两种 Tiling 方式的 profile 我发现,采用横着 Tiling 方式的 kernel bank conflict 更低一些。




等等,既然我们之前已经处理过 bank conflict 了,那么为什么这里还会有 bank conflict 呢?这个现象其实我也不是很清楚。但目前已知的是,在没有加 double buffer 情况下是没有 bank conflict 的,但加了 double buffer 之后或多或少会出现一些 bank conflict。那么至于为什么横着 Tiling 方式的 bank conflict 更低,我就更不知道了,因此这里还请各位 dalao 赐教。


最终版本的代码在这:https://github.com/AyakaGEMM/Hands-on-GEMM/blob/main/src/cuda/double_buffer_yhs_refine_gemm.cu


Epilogue


回顾本文,也基本达成了文章开头所立的各种 flag。当然现在还是有很多问题没有解决的,如 split K、长尾问题、分块细调等等,这些权当是一些未来展望了。近期也在尝试写一下 int8 tensor core 的矩阵乘,在较小形状上(M、N、K<=2048)能有比 cublas 更高的性能,但在更大形状上就只有 80% 左右了(这还是 L2 cache 命中率为 90% 的结果,可能还有啥别的没做好),所以就没有写 int8 tensor core 的部分。不过好歹是写完了!

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
8月前
|
Python 开发工具
2024年Python最全使用Python实现音频双通道分离,2024年最新阿里p7面试难度
2024年Python最全使用Python实现音频双通道分离,2024年最新阿里p7面试难度
2024年Python最全使用Python实现音频双通道分离,2024年最新阿里p7面试难度
|
8月前
|
算法 搜索推荐 图计算
图计算中的社区发现算法是什么?请解释其作用和常用算法。
图计算中的社区发现算法是什么?请解释其作用和常用算法。
162 0
|
2月前
|
算法 数据安全/隐私保护 开发者
马特赛特旋转算法:Python的随机模块背后的力量
马特赛特旋转算法是Python `random`模块的核心,由松本真和西村拓士于1997年提出。它基于线性反馈移位寄存器,具有超长周期和高维均匀性,适用于模拟、密码学等领域。Python中通过设置种子值初始化状态数组,经状态更新和输出提取生成随机数,代码简单高效。
122 63
|
2月前
R 语言教程 之 R 基础运算 4
本章《R基础运算》介绍了R语言中的简单运算,重点讲解了赋值运算符的使用方法,包括向左、向右及等于赋值,并通过实例演示了不同赋值方式的效果。
35 1
|
7月前
|
自然语言处理 搜索推荐 机器人
只需几个演示就能对齐大模型,杨笛一团队提出的DITTO竟如此高效
【6月更文挑战第22天】斯坦福团队推出DITTO,一种只需少量演示即可高效对齐大型语言模型的新技术。DITTO借助用户演示生成在线比较数据,实现模型对齐,无需大规模数据集。在用户研究中,DITTO表现优于传统方法,平均胜出19%,开创了LLMs对齐的简洁途径,适用于个性化助手和聊天机器人等场景。然而,它可能不适用于需要大量数据的任务,训练速度较慢,且可能无法完全匹配用户意图。[论文链接](https://arxiv.org/pdf/2406.00888)
102 10
R语言笔记丨从零学起?环境安装、基础知识、运算法则、数据类型(下)
R语言笔记丨从零学起?环境安装、基础知识、运算法则、数据类型(下)
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习基础之Tensor对象及其应用的讲解及实战(附源码 简单易懂 包括分段 映射 矩阵乘法 随机数等等)
PyTorch深度学习基础之Tensor对象及其应用的讲解及实战(附源码 简单易懂 包括分段 映射 矩阵乘法 随机数等等)
94 1
|
机器学习/深度学习 数据挖掘 Linux
R语言笔记丨从零学起?环境安装、基础知识、运算法则、数据类型(上)
R语言笔记丨从零学起?环境安装、基础知识、运算法则、数据类型
结构体全解,适合初学者的一条龙深度讲解(附手绘图详解)下
结构体全解,适合初学者的一条龙深度讲解(附手绘图详解)
71 0
|
存储 编译器 C语言
结构体全解,适合初学者的一条龙深度讲解(附手绘图详解)上
结构体全解,适合初学者的一条龙深度讲解(附手绘图详解)
121 0