写在前面
这篇文章提到了绝对位置编码和相对位置编码,但是他们都有局限性,比如绝对位置编码不能直接表征token的相对位置关系;相对位置编码过于复杂,影响效率。于是诞生了一种用绝对位置编码的方式实现相对位置编码的编码方式——旋转位置编码(Rotary Position Embedding, RoPE),兼顾效率和相对位置关系。
RoPE的核心思想是通过旋转的方式将位置信息编码到每个维度,从而使得模型能够捕捉到序列中元素的相对位置信息。现在已经在很多大模型证明了其有效性,比如ChatGLM、LLaMA等。
一、RoPE的优点
1.真正的旋转位置编码
Transformer的原版位置编码也使用了三角函数,但它生成的是每个位置的绝对编码,三角函数的主要用途是生成具有可区分性的周期性模式,也没有应用旋转变换的概念,因此属于绝对位置编码。同时原版的编码使用加法,在多层传递后导致位置信息的稀释,如下图 (没想到这张图也有被当做反面典型的时候吧 ):
RoPE不是简单的加法,而是通过复数乘法实现旋转变换,这种旋转是将位置信息融入到token表示中的关键机制。RoPE在实现过程中通过乘法操作融入位置信息,与模型中的Q和K深度融合,将旋转操作真正植入Attention机制内部,强化了位置编码信息的作用。
2.更好的相对位置信息编码
注意力机制通过计算Embedding的内积来确定它们之间的关系强度。
使用RoPE时,两个位置的编码通过旋转变换后的内积,自然地包含了它们之间的相对位置信息。这是因为旋转操作保持了内积的性质,使得内积计算不仅反映了token的内容相似性,还反映了它们的位置关系。
3.更适用于多维输入
这点很有意思,传统的Transformer位置编码主要针对一维序列,如文本序列。然而,在某些任务中,输入可能是二维或更高维的数据,如图像或视频数据。旋转位置编码可以更灵活地应用于多维输入数据,通过对不同维度的位置信息进行编码,使得模型能够更好地理解多维数据中的位置关系。
4. 更善于处理长序列
RoPE可以减少位置信息的损失。在深层网络中,RoPE通过乘法操作融入位置信息,乘法操作有助于在深层网络中保持位置信息的完整性。在处理一个长文本时,RoPE通过在每一层的自注意力计算中使用旋转变换,确保了位置信息能够被有效保留和利用,即使是在模型的较深层次。
二、公式
既然旋转的位置编码有这么多优点,那怎么实现位置编码的旋转呢,其实网上有很多介绍的文章。大概意思就是复数可以通过乘以e的幂来旋转角度,其中幂就是角度,再结合欧拉公式推出三角函数的表达,大致流程如下。
三、代码实现
我们以ChatGLM的代码为例,展示一下RoPE的使用,以下代码都在modeling_chatglm.py文件中,一条训练数据:
{"context": "你好", "target": "你好,我是大白话"}
1.字符串转换成token_ids
[ 5, 74874, 130001, 130004, 5, 74874, 6, 65806, 63850, 95351, 130005]
2.计算position_ids
根据上面的token_ids计算出position_ids:
[[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]]
解释一下position_ids:第一行表示序列中每个元素的全局位置,第一个“2”表明context结束了,target要开始了,后面所有的2都是target部分;第二行则细化到更具体的局部位置,从1开始表征整个target的内容,这样用两个维度的编码很优雅的体现了context和target,这种层次化处理对于理解上下文非常重要。
代码如下:
def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): """ 根据token_ids生成position_ids :param input_ids: 这里是[[ 5, 74874, 130001, 130004, 5, 74874, 6, 65806, 63850, 95351, 130005]] :param mask_positions: 2 输出的第1维mask掉几位,即这一位及其前面都是0,后面是1,2... :param device: :param use_gmasks: :return: [[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]] """ batch_size, seq_length = input_ids.shape if use_gmasks is None: use_gmasks = [False] * batch_size context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] if self.position_encoding_2d: # 会走这一分支 position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) for i, context_length in enumerate(context_lengths): position_ids[i, context_length:] = mask_positions[i] block_position_ids = [torch.cat(( torch.zeros(context_length, dtype=torch.long, device=device), torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 )) for context_length in context_lengths] block_position_ids = torch.stack(block_position_ids, dim=0) position_ids = torch.stack((position_ids, block_position_ids), dim=1) else: position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) for i, context_length in enumerate(context_lengths): if not use_gmasks[i]: position_ids[i, context_length:] = mask_positions[i] return position_ids
3.角度序列Embedding
接下来,将position_ids转换成角度序列Embedding,下表中每个格的公式为
m | i | 0 | 31 | 0 | 31 | ||
0 | 0/(10000^(2*0/64)) | ... | 0/(10000^(2*31/64)) | 0/(10000^(2*0/64)) | ... | 0/(10000^(2*31/64)) |
1 | 1/(10000^(2*0/64)) | 1/(10000^(2*31/64)) | 1/(10000^(2*0/64)) | 1/(10000^(2*31/64)) |
2 | 2/(10000^(2*0/64)) | ... | 2/(10000^(2*31/64)) | 2/(10000^(2*0/64)) | ... | 2/(10000^(2*31/64)) |
... | ... |
class RotaryEmbedding(torch.nn.Module): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): pass def __init__(self, dim, base=10000, precision=torch.half, learnable=False): """ 根据position_ids计算旋转角度的Embedding :param dim: 这里hidden_size // (num_attention_heads * 2)=46,其中hidden_size=4096 num_attention_heads=32 :param base: :param precision: :param learnable: """ super().__init__() # 初始化“频率”,可以理解为position_id每增加1,增加的角度,是Embedding形式的。 inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = inv_freq.half() self.learnable = learnable if learnable: self.inv_freq = torch.nn.Parameter(inv_freq) self.max_seq_len_cached = None else: self.register_buffer('inv_freq', inv_freq) self.max_seq_len_cached = None self.cos_cached = None self.sin_cached = None self.precision = precision def forward(self, x, seq_dim=1, seq_len=None): if seq_len is None: seq_len = x.shape[seq_dim] if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): self.max_seq_len_cached = None if self.learnable else seq_len # 1.对position_ids去重并正序排列得到t,如:[[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]] --> t=[[0, 1, 2]] t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) # 2.t与初始化好的“频率”做外积,得到每个position_id的角度,是Embedding freqs = torch.einsum('i,j->ij', t, self.inv_freq) # 3.每个Embedding重复叠加一次 emb = torch.cat((freqs, freqs), dim=-1).to(x.device) if self.precision == torch.bfloat16: emb = emb.float() # 4.算cos和sin,并增加维度 cos_cached = emb.cos()[:, None, :] sin_cached = emb.sin()[:, None, :] if self.precision == torch.bfloat16: cos_cached = cos_cached.bfloat16() sin_cached = sin_cached.bfloat16() if self.learnable: return cos_cached, sin_cached self.cos_cached, self.sin_cached = cos_cached, sin_cached return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] # 类似于查表,根据每个position_id获取相应的Embedding cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) ......
4.截取拼接Q和K
这一步对Q或者K做截断,并将第二段取反拼在第一段的前面,拼接成公式第二项的q部分。
上述3、4流程示意图:
代码如下:
def rotate_half(x): x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=x1.ndim - 1)
5.旋转位置编码融合
代码如下:
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] # 类似于查表,根据每个position_id获取相应的Embedding cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) # 执行旋转位置编码与QK的融合 q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) return q, k # 整体流程如下 # 1.拆分出Q1、Q2、K1、K2 q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) # 2.计算旋转Embedding cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ position_ids[:, 1, :].transpose(0, 1).contiguous() # 3.旋转位置编码融合 q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) # 4.将拆分出的Q1、Q2、K1、K2合并成新的Q、K query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
位置编码对于Transformer的重要性毋庸置疑,旋转位置编码也确实解决了一些问题。最有意思的就是它是一个二维编码,将旋转信息通过乘法操作融入Attention机制内部,强化了位置编码信息,现在已经有很多开源大模型都使用了旋转位置编码,可见其效果不俗。
旋转位置编码就介绍到这里,关注不迷路(#^.^#)