如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门(3)

简介: 如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门

Global Memory


前面提到 GPU 访存时以 32 Byte 为粒度进行访问的,那么一个 32 Byte 访问被称为一个 sector。那么值得注意的就是在搬运数据时,尽可能的让同一个 warp 搬运同一行的数据来避免使用额外的 sector(本文采用现代的行主序来存储矩阵)。



这里借用一下 Nvidia 的图。如果同一个 warp 内的 thread 都访问每一行的开头,那么如果一行超过 8 个 float,那么每一个 thread 都需要一个 sector 去请求它们需要的数据,这就造成了 sector 浪费。而实际中每一行的元素往往都会大于 8 个 float,因此会有非常大的性能损失。下图为一个 warp 在拷贝时,每个 thread 之间间隔的大小,单位为 float。可以看到在间隔为 2 时就已经有一半的性能损失了,这很不好。



因此我们采用下图所示的访问方式。即尽可能的让一个 warp 中的 thread 连续的读取 Global Memory 中的元素。



Shared Memory


前文已经讲过,shared memory 在图灵架构之后可以完全被看作是 L1 cache。而在此基础之上,shared memory 的访问粒度是 32 bit 也就是 4 Byte,刚好是一个 float 数据的大小。而后 shared memory 按照 4 Byte 连续的划分为一个个 bank。对于 bank 可以简单的理解为双通道内存中通道的概念,即在不同的 bank 中的数据可以并行访问,同一个 bank 内不同地址的数据只能串行访问。在 Compute Capability 5.x 及之后的卡上,shared memory 具有 32 个 bank,刚好是一个 warp 中线程的数量。而如果同一个 warp 中不同 thread 均只访问 4 Byte 数据且希望同时访问同一个 bank 的数据将会有两种结果。(对于每一个 thread 访问更多数据的行为将在后面提到)


1. 两个或多个 thread 访问的刚好是同一个地址内的数据,那么此时将会触发 broadcast 机制,即实际只读取一次数据,而后广播到这些 thread 中。

2. 两个或多个 thread 访问的是同一个 bank 内的数据,那么此时这些 thread 的访问将会被强制安排为串行执行。这种访问情况被称为 bank conflict。


这里给出 cuda programming guide 的两张图来直观的体现 broadcast 和 bank conflict。



这张图表示同一个 warp 中的 thread 按不间隔、隔一个、隔两个 bank 对 shared memory 访问。中间的访问每两个 thread 都会发生一次 bank conflict,而其他两种访问都不会发生 bank conflict。值得注意的一点是这张图最右侧的图的访问方式刚好可以达到每一个 thread 都访问了不同的 bank 的效果。


同时考虑到 shared memory 是按照 bank 来访问的,且与 Load/Store 单元直连,并没有中间商赚差价,所以对于 shared memory 的访存并不讲究连续访存,而只需要考虑是否有 bank conflict 就足够了。因此理论上最左和最右两列图的访问性能是一样的,这与访问全局内存有一点区别。同理,每一个 warp 连续的多次访存也并不要求连续访存,而在拷贝数据到 shared memory 时对 A 矩阵做矩阵转置的目的是为了向量化访存,而不是为了连续访存。



这张图则展示了 broadcast 机制,没啥好说的。


128-bit conflict-free store


而前文中提到,我们使用 float4 来做数据传输来缓解 GPU 聚合访问的压力,使得每一个指令都更加高效。而又因为前文所述,每个线程需要使用向量外积的方法计算矩阵乘,因此我们需要在 A 矩阵转存到 shared memory 时做一次转置。


但细心的同学可能注意到,如果就这么平铺直叙的做转置那么将会发生非常严重的 bank conflict,因为一个 warp 内的奇数 thread 和偶数 thread 使用同一个 bank。那么此时解决 bank conflict 的方法有两种,第一种便是将 shared memory 的 k 维度缩小,然后直接把奇数 thread 所取的数据直接并到 M 维上,就不会有 bank conflict 的问题了。这种方法通过 index 变换,直接就能避免 bank conflict,非常巧妙,而我当时没有想到,就没有用这种方法。值得注意的是,尽管图是按行隔开的,但那只是为了表示数据是如何在一个 thread 里保存的,实际写到 shared memory 中是以一个 float 为单位,按列主序存储到 shared memory 中。



而第二种方法就非常简单粗暴了,直接往 lda 上加 4,然后就不会有 bank conflict 了。当然这种方法的弊端也是有的,那就是会造成一部分 shared memory 的浪费。但对 sgemm 来说倒也还好, shared memory 的占用也不是导致 Occupancy 降低的原因,所以我就用了这个方法。


128-bit conflict-free load


而我们把数据存储到 shared memory 之后,下一步便是考虑如何在没有 bank conflict 的情况下将数据读取出来。在本文中,我们取为 8,在采用向量化存取时,直接按照 Warp Tiling 采用朴素的存取方法就能在没有 bank conflict 的情况下把数据读出来了。



当然有的同学可能会问:既然访存是按照一个 warp 为单位进行的,而图中明显读取 B 矩阵时,t16 会和 t0 发生 bank conflict,那为什么又说不会有 bank conflict 呢?那么答案就是在做 128-bit 访存时,warp 并不是同时读取数据的。这里还是借用 Nvidia 在 GTC 2018 上的分享来做说明。



当 warp 中每个 thread 只读取 4B 或更小数据时,warp 才是同时读取的。而本文中采用 128-bit 也就是 16B 读取,那么一个 warp 会分成 4 次操作读取,每次操作只有 1/4 warp 工作。那么只要同一次操作内的 thread 没有发生 bank conflict,那么就没有 bank conflict。而上图中 t0-t7 同时操作,它们之间并没有 bank conflict,后面的 thread 依此类推,那么也就不会有 bank conflict。那么朴素的 warp tiling 实现代码在这:https://github.com/AyakaGEMM/Hands-on-GEMM/blob/main/src/cuda/warp_tile_gemm.cu


而李少侠在代码中采用了一种更高级的排布方式,即 z 字排布。与之相对应的,他将一个 thread 负责的小型矩阵乘拆分成四个更小的矩阵乘。同时这个拆分虽然是在地址上做的拆分,但在运算中依然可以看作是一个整体,即运算部分不用更改任何代码而只需要在 index 上做一些变换即可。而他这么做的理由是为了更快的 broadcast。但说实话,我不是很理解,也没搜到为什么这样能有更快的 broadcast 性能。(而且我这么试了一下,发现确实是快了,这实在是太神奇了,欢迎大家提供一些看法。)



这里我们跑一个 profile 发现,确实是没有 bank conflict 的,挺好。代码在这:

https://github.com/AyakaGEMM/Hands-on-GEMM/blob/main/src/cuda/z_thread_map_gemm.cu



相关文章
|
6月前
|
Python 开发工具
2024年Python最全使用Python实现音频双通道分离,2024年最新阿里p7面试难度
2024年Python最全使用Python实现音频双通道分离,2024年最新阿里p7面试难度
2024年Python最全使用Python实现音频双通道分离,2024年最新阿里p7面试难度
|
机器学习/深度学习 算法 搜索推荐
一文读懂FM算法优势,并用python实现!(附代码)
介绍 我仍然记得第一次遇到点击率预测问题时的情形,在那之前,我一直在学习数据科学,对自己取得的进展很满意,在机器学习黑客马拉松活动中也开始建立了自信,并决定好好迎接不同的挑战。 为了做得更好,我购买了一台内存16GB,i7处理器的机器,但是当我看到数据集的时候却感到非常不安,解压缩之后的数据大概有50GB - 我不知道基于这样的数据集要怎样进行点击率预测。
15177 0
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习基础之Tensor对象及其应用的讲解及实战(附源码 简单易懂 包括分段 映射 矩阵乘法 随机数等等)
PyTorch深度学习基础之Tensor对象及其应用的讲解及实战(附源码 简单易懂 包括分段 映射 矩阵乘法 随机数等等)
77 1
|
机器学习/深度学习 自然语言处理 算法
【机器学习实战】10分钟学会Python怎么用EM期望最大化进行参数估计(十五)
【机器学习实战】10分钟学会Python怎么用EM期望最大化进行参数估计(十五)
215 0
|
存储 缓存 并行计算
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门(2)
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门
285 0
|
存储 并行计算 异构计算
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门(4)
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门
234 0
|
机器学习/深度学习 缓存 自然语言处理
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门
如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门
|
API C语言 开发者
【精选】对随机粒子玩法的简单探索(C语言简单版本)
【精选】对随机粒子玩法的简单探索(C语言简单版本)
93 0
|
人工智能 缓存 移动开发
通用矩阵乘算法从入门到实践
通用矩阵乘算法从入门到实践
340 0
|
算法 计算机视觉
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(三)
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(三)
258 0
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(三)
下一篇
无影云桌面