为什么使用 TorchRec 训练和推理更快

简介: 本文结合TorchEasyRec实践,从四大维度解析推荐系统加速:1)KeyedJaggedTensor统一变长特征,实现Embedding批量融合查找;2)自动分布式分片突破单卡显存瓶颈;3)TrainPipelineSparseDist流水线并行,重叠通信与计算;4)fbgemm-gpu融合优化器,减少显存访问。端到端提升训练效率与扩展性。

结合 TorchEasyRec 代码中的实际使用,从四个核心维度分析。

一、KeyedJaggedTensor(KJT)— 数据层加速

问题:推荐系统的特征天然是变长的

以 dbmtl_taobao.config 为例,16 个特征中有多值 ID 特征(如 brand、cate_id),不同样本的特征长度不同。

标准 PyTorch 做法(nn.EmbeddingBag)

# 每个特征独立处理,需要逐个调用
emb_user_id = self.emb_user_id(user_id_tensor)      # 调用 1
emb_brand   = self.emb_brand(brand_values, offsets)   # 调用 2
emb_cate_id = self.emb_cate_id(cate_values, offsets)  # 调用 3
# ... 16 个特征 = 16 次独立的 Embedding 查找
# 最后手动 torch.cat([emb_user_id, emb_brand, ...])
  • 16 次独立的 CUDA kernel launch,每次 kernel launch 有固定开销(~10μs)
  • 变长特征需要手动管理 padding 或 offsets
  • CPU→GPU 数据传输也是 16 次独立 tensor

TorchRec 的 KJT + EmbeddingBagCollection 做法

# 所有特征打包成一个 KJT
kjt = KeyedJaggedTensor(
    keys=["user_id", "brand", "cate_id", ...],  # 16 个特征名
    values=torch.tensor([...]),                    # 所有 ID 拼接
    lengths=torch.tensor([...]),                   # 每个样本每个特征的长度
)
# 一次调用完成所有 Embedding 查找
result = ebc(kjt)  # 单次 fused kernel

embedding.py 可以看到:

self.ebc = EmbeddingBagCollection(list(emb_bag_configs.values()), device=device)

加速对比

维度

标准 PyTorch

TorchRec KJT + EBC

Kernel launch

N 次(N 个特征)

1 次(fused)

数据传输

N 个 tensor 独立传输

1 个 KJT 批量传输

内存布局

分散,cache 不友好

连续紧凑,cache 友好

变长处理

手动 padding/offsets

内置 lengths 支持


二、分布式 Embedding 分片(Sharding)— 突破单卡显存瓶颈

推荐模型的 Embedding 表极大,以 config 中为例:

user_id:     1,141,730 × 16 = ~17MB
brand:         461,498 × 16 = ~7MB
adgroup_id:    846,812 × 16 = ~13MB

这只是 demo 数据。生产环境 Embedding 表可达数十 GB 甚至 TB 级别,远超单卡显存。

TorchRec 的自动分片

plan_util.py 可以看到 TorchEasyRec 使用 EmbeddingShardingPlanner 自动规划:

planner = EmbeddingShardingPlanner(
    topology=topology,                    # 集群拓扑(GPU 数量、显存、带宽)
    enumerator=EmbeddingEnumerator(...),  # 枚举所有可能的分片方案
    proposer=[DynamicProgrammingProposer(), UniformProposer()],  # DP 求最优
)
plan = planner.collective_plan(model, sharders, ...)
model = DistributedModelParallel(model, plan=plan, ...)

分片策略

  • table_wise:整张表放在一个 GPU 上
  • row_wise:大表按行切分到多个 GPU
  • column_wise:按 embeddingdim 切分
  • data_parallel:每个 GPU 全量复制

Planner 用动态规划算法自动找到最优分片方案,使得:

  • 显存均衡分布在各卡上
  • 通信量最小化
  • 各卡计算负载均衡

三、TrainPipelineSparseDist — 流水线并行加速

这是 TorchRec 的杀手级特性。从 dist_util.py 可以看到 TorchEasyRec 直接使用了这个流水线。

无流水线(标准 PyTorch DDP)

Batch 1: [数据加载] → [Embedding 分发] → [前向] → [反向] → [梯度同步]
Batch 2:                                                    → [数据加载] → ...

每个阶段串行等待,GPU 大量空闲。

TrainPipelineSparseDist(3 阶段流水线)

时间 →  T1        T2        T3        T4        T5
Batch1: [数据加载] [Emb分发]  [前向+反向]
Batch2:           [数据加载] [Emb分发]  [前向+反向]
Batch3:                     [数据加载] [Emb分发]  [前向+反向]
  • Batch N 的 Embedding 分发 与 Batch N-1 的前向/反向 同时进行
  • Batch N+1 的数据加载 与 Batch N 的 Embedding 分发 同时进行
  • 通信和计算完全重叠,GPU 利用率接近 100%

四、fbgemm-gpu Fused Optimizer — 优化器层加速

optimizer.py 可以看到:

from fbgemm_gpu import split_table_batched_embeddings_ops_training

标准 PyTorch 的 Adagrad/Adam 优化器对 Embedding 的更新流程:

前向 → 反向得到梯度 → 优化器读取梯度 → 更新权重(3 次显存读写)

fbgemm-gpu 的 fused optimizer 将梯度计算和权重更新融合在一个 kernel 中:

前向 → 反向直接在 kernel 内更新权重(1 次显存读写)

结合 apply_optimizer_in_backward,实现了 backward 和 optimizer.step 的融合,避免梯度的中间存储。


总结:端到端加速对比


优化层

加速来源

典型提升

数据表示

KJT 紧凑存储 + fused kernel

减少 kernel launch 开销

Embedding 查找

EBC 批量 fused 查找

10x+ vs 逐特征查找

分布式分片

自动 DP 最优分片

突破单卡显存,线性扩展

流水线

通信/计算重叠

30-50% 吞吐提升

优化器

fbgemm fused backward

减少 2/3 显存带宽

这就是为什么 TorchEasyRec 的核心数据流全部围绕 TorchRec 构建——它不只是一个 Embedding 库,而是一套从数据表示到分布式训练到优化器的完整加速栈。

相关文章
|
13天前
|
分布式计算 MaxCompute iOS开发
TorchEasyRec 在 macOS 上的功能限制总结
本文总结tzrec在macOS上的功能限制:核心依赖(如torchrec、fbgemm-gpu、graphlearn等)无法安装;分布式训练、原生数据管线、Embedding模块、Triton/CUDA算子、TDM树模型等功能完全不可用;优化器与模型导出部分失效;单元测试大多因强依赖而失败。
108 15
|
12天前
|
人工智能 安全 JavaScript
OpenClaw(小龙虾)Windows 11 一键部署教程 2026 最新版
OpenClaw(小龙虾)是GitHub星标28W+的开源本地AI智能体,支持Win11全版本,零代码、免配置、解压即用!自动操控电脑、整理文件、浏览器/办公自动化,数据全程本地运行,隐私安全。一键部署包v2.6.0专为Win11优化,10分钟搞定!
587 129
|
23天前
|
人工智能 弹性计算 数据可视化
阿里云OpenClaw部署实操教程:轻量应用服务器+百炼免费大模型
OpenClaw(“小龙虾”)是一款开源AI智能体,不仅能聊天,更能自动处理文件、运行代码、收发邮件等任务。本教程教你用阿里云轻量服务器+百炼免费大模型,零代码10分钟部署专属AI数字员工!
568 26
|
9天前
|
安全 Java 索引
java工具:《对Collections.sort排序后我想制定查询几条,比如list有10条,我只想获取前4条》
java工具:《对Collections.sort排序后我想制定查询几条,比如list有10条,我只想获取前4条》
78 12
|
12天前
|
Kubernetes 网络协议 文件存储
Docker镜像拉了一下午还没完?我受够了,花了一周找替代方案
上周拉镜像卡在47%两小时?试遍阿里云、高校源、GitHub清单全失效。直到发现「毫秒镜像」——宝塔、爱快、绿联NAS已原生集成,金融级客户背书。一行命令安装,3秒拉完nginx,全仓库加速(Docker Hub/gcr/ghcr/k8s等),含DNS自诊。免费版够用,稳定不跑路。
416 19
|
5天前
|
人工智能 安全 数据可视化
Windows 全版本 OpenClaw 搭建教程 零代码可视化一键部署
OpenClaw(小龙虾)是2026年热门开源AI自动化工具,支持Win10/11本地离线运行。零代码、全图形化、内置依赖、多模型切换、大Token额度,5–10分钟一键部署。数据不出设备,安全可控,适配办公全场景。(239字)
|
14天前
|
机器学习/深度学习 搜索推荐 数据处理
PAI-Rec推荐开发平台:企业级智能推荐解决方案,驱动业务全域增长
PAI-Rec是阿里云一站式推荐系统平台,集成多路召回、多目标精排(如DBMTL)、GPU加速推理与灵活迭代能力,已助力电商、直播、音视频等多行业提升点击率、转化率与ROI,实现高效、低成本、可自主演进的智能推荐。
156 16
|
12天前
|
人工智能
【钉钉会议 | 日程 Skill】让 Agent 真正帮你「把时间排进钉钉」
钉钉日程助手技能,打通“找人→约时→订室→发邀→跟进”全链路。支持查空闲、抢会议室、一键建会(含视频)、签到链接推送、周期例会自动排期,让AI真正驱动协作闭环。(239字)
152 15
|
9天前
|
Windows Python
SBTI 人格测试人一多网站就崩?试试这个本机就能轻松下载的 SBTI 测试
SBTI人格测试火爆致官网崩坏?这款Windows桌面版解压即用,离线答题不卡顿、不抢带宽,支持单机多测、随时分享。源自开源项目,尊重原作者,GitHub可下载或联系作者秒发包。(239字)
1278 11
|
12天前
|
缓存 人工智能 文字识别
阿里云Qwen3.6-Plus收费价格:输入、输出、显式缓存收费标准,2026最新
阿里云Qwen3.6-Plus是2026年推出的原生视觉语言大模型,阿里云大模型官网:https://t.aliyun.com/U/JbblVp 代码(Agentic/Vibe/前端)、OCR、多模态识别与物体定位能力显著超越3.5系列。输入2元/百万tokens,输出12元/百万tokens,显式缓存命中仅0.2元;新用户可领7000万免费Tokens。
1034 17
下一篇
开通oss服务