TorchRec大量使用Jagged Tensor

简介: Jagged Tensor(锯齿张量)是专为变长序列设计的紧凑存储格式,用values+lengths/offsets替代padding,显著节省内存与计算。广泛应用于推荐系统中用户行为、多值标签等不等长特征处理,如HSTU模型中的拼接、拆分与矩阵乘法操作。

什么是 Jagged Tensor

Jagged Tensor(锯齿张量)是一种变长嵌套张量,用于表示每个样本长度不同的特征。与普通的 padded 2D tensor 不同,它用 values(拼接的值)+ offsets/lengths(每个样本的长度)来紧凑存储,避免大量 padding 浪费。


例子 1:用户历史点击序列(长度不一致)

假设一个 batch 有 3 个用户,他们的历史点击 item_id 分别为:

plaintext

用户A: 点击了 [101, 202, 303]         → 长度 3
用户B: 点击了 [404]                    → 长度 1  
用户C: 点击了 [505, 606, 707, 808]    → 长度 4

Padded 方式(浪费内存):

python

# 需要 pad 到最长 4,大量 0 浪费
tensor([[101, 202, 303,   0],
        [404,   0,   0,   0],
        [505, 606, 707, 808]])   # shape: (3, 4)

Jagged Tensor 方式(紧凑):

python

values  = [101, 202, 303, 404, 505, 606, 707, 808]   # 所有值拼接,长度 8
lengths = [3, 1, 4]                                     # 每个用户的序列长度
offsets = [0, 3, 4, 8]                                  # 累加偏移量


例子 2:多值标签特征(tag)

用户打的标签数量不同:

plaintext

用户A: tags = ["科技", "体育"]           → 长度 2
用户B: tags = ["美食", "旅行", "摄影"]   → 长度 3
用户C: tags = ["音乐"]                   → 长度 1

python

values  = [科技, 体育, 美食, 旅行, 摄影, 音乐]   # 拼接
lengths = [2, 3, 1]


例子 3:jagged_tensors.py 中三个 op 的实际场景

这三个 op 来源于 generative-recommenders(HSTU 模型),用于处理用户行为序列 + 候选 item 拼接的场景:

concat_2D_jagged — 拼接用户历史和候选 item

plaintext

用户A历史: [e1, e2, e3](3个embedding)    候选item: [c1, c2](2个embedding)
用户B历史: [e4](1个embedding)             候选item: [c3, c4, c5](3个embedding)

拼接后:

plaintext

用户A: [e1, e2, e3, c1, c2]       → 长度 3+2=5
用户B: [e4, c3, c4, c5]           → 长度 1+3=4

用 padded tensor 需要 shape (2, 5, dim),用户B 有 1 行 padding。

用 jagged tensor 只存 values: (9, dim) + offsets,零浪费。

split_2D_jagged — 从拼接结果中拆回用户/候选

HSTU 的 Transformer 对拼接后的序列做完 self-attention 后,需要把用户部分和候选部分拆回来:

plaintext

拼接的attention输出: [a1, a2, a3, a4, a5, a6, a7, a8, a9]
offsets_left  = [0, 3, 4]      ← 用户历史的偏移
offsets_right = [0, 2, 5]      ← 候选item的偏移
→ left:  用户表征
→ right: 候选打分

jagged_dense_bmm_broadcast_add — 变长序列的矩阵乘法

对每个用户的变长序列做投影:out = jagged × dense + bias

plaintext

jagged: (sum_B(M_i), K)  ← 所有用户序列拼接,总共 sum 个 token
dense:  (B, K, N)         ← 每个用户一个投影矩阵
bias:   (B, N)            ← 每个用户一个偏置
out:    (sum_B(M_i), N)   ← 变长输出

传统做法需要 pad → bmm → unpad,jagged 版本直接在紧凑格式上计算,省内存且更快。


为什么推荐系统特别需要 Jagged Tensor?

推荐系统中到处都是变长特征:

  • 用户点击历史(有人点了 5 个,有人点了 500 个)
  • 多值 ID 特征(用户标签、商品类目)
  • 序列特征中的多值子特征(每个行为关联多个属性)

如果全部 pad 到最大长度,内存和计算浪费巨大。Jagged Tensor 是推荐系统处理这类数据的标准做法,torchrec 的 KeyedJaggedTensor 也是基于同样的理念

相关文章
|
5天前
|
人工智能 JSON 监控
Claude Code 源码泄露:一份价值亿元的 AI 工程公开课
我以为顶级 AI 产品的护城河是模型。读完这 51.2 万行泄露的源码,我发现自己错了。
4054 12
|
16天前
|
人工智能 JSON 机器人
让龙虾成为你的“公众号分身” | 阿里云服务器玩Openclaw
本文带你零成本玩转OpenClaw:学生认证白嫖6个月阿里云服务器,手把手配置飞书机器人、接入免费/高性价比AI模型(NVIDIA/通义),并打造微信公众号“全自动分身”——实时抓热榜、AI选题拆解、一键发布草稿,5分钟完成热点→文章全流程!
11635 137
让龙虾成为你的“公众号分身” | 阿里云服务器玩Openclaw
|
4天前
|
人工智能 数据可视化 安全
王炸组合!阿里云 OpenClaw X 飞书 CLI,开启 Agent 基建狂潮!(附带免费使用6个月服务器)
本文详解如何用阿里云Lighthouse一键部署OpenClaw,结合飞书CLI等工具,让AI真正“动手”——自动群发、生成科研日报、整理知识库。核心理念:未来软件应为AI而生,CLI即AI的“手脚”,实现高效、安全、可控的智能自动化。
1420 7
王炸组合!阿里云 OpenClaw X 飞书 CLI,开启 Agent 基建狂潮!(附带免费使用6个月服务器)
|
6天前
|
人工智能 自然语言处理 数据挖掘
零基础30分钟搞定 Claude Code,这一步90%的人直接跳过了
本文直击Claude Code使用痛点,提供零基础30分钟上手指南:强调必须配置“工作上下文”(about-me.md+anti-ai-style.md)、采用Cowork/Code模式、建立标准文件结构、用提问式提示词驱动AI理解→规划→执行。附可复制模板与真实项目启动法,助你将Claude从聊天工具升级为高效执行系统。
|
5天前
|
人工智能 定位技术
Claude Code源码泄露:8大隐藏功能曝光
2026年3月,Anthropic因配置失误致Claude Code超51万行源码泄露,意外促成“被动开源”。代码中藏有8大未发布功能,揭示其向“超级智能体”演进的完整蓝图,引发AI编程领域震动。(239字)
2324 9