RoPE中旋转位置编码的全部过程如图所示:
这里我以自己的理解解释一下这张图以及等式
首先我们以二维为例子,为了方便我们令 m=1,再把 Wqxm用 (x1,x2)表示, q m q_m qm用 (x1′,x2′)表示,就有了如下等式:
这里我觉得不应该弄成等式,我转化成这样好理解一些:
第一种:
我们有 ,采取的处理方式是先 ,单数取实部,双数取虚部这里有
即
都知道欧拉函数 从欧拉函数中我们可以发现 是可以对应平面直角坐标系的,即 ;从这里的公式来说x1,x2表示的是标量,如果使用 ,从某种意义上拓展了其维度,这也是我使用 <=表示的原因;
第二种:转化为指数形式再展开
上面说了,标量直接与 从某种意义上来说相当于拓维,所以我们可以两两构成 ,再来与 进行计算,最后再把 转化为 ;有
得到的旋转矩阵就是
引入到多维有旋转矩阵为:
这里直接贴的原文,转置的原因为我这里的顺序与原文相反;
同时可以发现:
通过这个结论拓展到多维就可以得到:
这里是原文中出现的错误 应该改为
也就是相当于
这里插入介绍一下旋转矩阵的快速计算技巧:
结合下面代码:
def rotate_every_two(x): x1 = x[:, :, :, ::2] x2 = x[:, :, :, 1::2] x = torch.stack((-x2, x1), dim=-1) return x.flatten(-2) def theta_shift(x, sin, cos): return (x * cos) + (rotate_every_two(x) * sin)
结束!