什么是 Jagged Tensor
Jagged Tensor(锯齿张量)是一种变长嵌套张量,用于表示每个样本长度不同的特征。与普通的 padded 2D tensor 不同,它用 values(拼接的值)+ offsets/lengths(每个样本的长度)来紧凑存储,避免大量 padding 浪费。
例子 1:用户历史点击序列(长度不一致)
假设一个 batch 有 3 个用户,他们的历史点击 item_id 分别为:
plaintext
用户A: 点击了 [101, 202, 303] → 长度 3 用户B: 点击了 [404] → 长度 1 用户C: 点击了 [505, 606, 707, 808] → 长度 4
Padded 方式(浪费内存):
python
# 需要 pad 到最长 4,大量 0 浪费 tensor([[101, 202, 303, 0], [404, 0, 0, 0], [505, 606, 707, 808]]) # shape: (3, 4)
Jagged Tensor 方式(紧凑):
python
values = [101, 202, 303, 404, 505, 606, 707, 808] # 所有值拼接,长度 8 lengths = [3, 1, 4] # 每个用户的序列长度 offsets = [0, 3, 4, 8] # 累加偏移量
例子 2:多值标签特征(tag)
用户打的标签数量不同:
plaintext
用户A: tags = ["科技", "体育"] → 长度 2 用户B: tags = ["美食", "旅行", "摄影"] → 长度 3 用户C: tags = ["音乐"] → 长度 1
python
values = [科技, 体育, 美食, 旅行, 摄影, 音乐] # 拼接 lengths = [2, 3, 1]
例子 3:jagged_tensors.py 中三个 op 的实际场景
这三个 op 来源于 generative-recommenders(HSTU 模型),用于处理用户行为序列 + 候选 item 拼接的场景:
concat_2D_jagged — 拼接用户历史和候选 item
plaintext
用户A历史: [e1, e2, e3](3个embedding) 候选item: [c1, c2](2个embedding) 用户B历史: [e4](1个embedding) 候选item: [c3, c4, c5](3个embedding)
拼接后:
plaintext
用户A: [e1, e2, e3, c1, c2] → 长度 3+2=5 用户B: [e4, c3, c4, c5] → 长度 1+3=4
用 padded tensor 需要 shape (2, 5, dim),用户B 有 1 行 padding。
用 jagged tensor 只存 values: (9, dim) + offsets,零浪费。
split_2D_jagged — 从拼接结果中拆回用户/候选
HSTU 的 Transformer 对拼接后的序列做完 self-attention 后,需要把用户部分和候选部分拆回来:
plaintext
拼接的attention输出: [a1, a2, a3, a4, a5, a6, a7, a8, a9] offsets_left = [0, 3, 4] ← 用户历史的偏移 offsets_right = [0, 2, 5] ← 候选item的偏移 → left: 用户表征 → right: 候选打分
jagged_dense_bmm_broadcast_add — 变长序列的矩阵乘法
对每个用户的变长序列做投影:out = jagged × dense + bias
plaintext
jagged: (sum_B(M_i), K) ← 所有用户序列拼接,总共 sum 个 token dense: (B, K, N) ← 每个用户一个投影矩阵 bias: (B, N) ← 每个用户一个偏置 out: (sum_B(M_i), N) ← 变长输出
传统做法需要 pad → bmm → unpad,jagged 版本直接在紧凑格式上计算,省内存且更快。
为什么推荐系统特别需要 Jagged Tensor?
推荐系统中到处都是变长特征:
- 用户点击历史(有人点了 5 个,有人点了 500 个)
- 多值 ID 特征(用户标签、商品类目)
- 序列特征中的多值子特征(每个行为关联多个属性)
如果全部 pad 到最大长度,内存和计算浪费巨大。Jagged Tensor 是推荐系统处理这类数据的标准做法,torchrec 的 KeyedJaggedTensor 也是基于同样的理念