为什么使用 TorchRec 训练和推理更快

简介: 本文结合TorchEasyRec实践,从四大维度解析推荐系统加速:1)KeyedJaggedTensor统一变长特征,实现Embedding批量融合查找;2)自动分布式分片突破单卡显存瓶颈;3)TrainPipelineSparseDist流水线并行,重叠通信与计算;4)fbgemm-gpu融合优化器,减少显存访问。端到端提升训练效率与扩展性。

结合 TorchEasyRec 代码中的实际使用,从四个核心维度分析。

一、KeyedJaggedTensor(KJT)— 数据层加速

问题:推荐系统的特征天然是变长的

以 dbmtl_taobao.config 为例,16 个特征中有多值 ID 特征(如 brand、cate_id),不同样本的特征长度不同。

标准 PyTorch 做法(nn.EmbeddingBag)

# 每个特征独立处理,需要逐个调用
emb_user_id = self.emb_user_id(user_id_tensor)      # 调用 1
emb_brand   = self.emb_brand(brand_values, offsets)   # 调用 2
emb_cate_id = self.emb_cate_id(cate_values, offsets)  # 调用 3
# ... 16 个特征 = 16 次独立的 Embedding 查找
# 最后手动 torch.cat([emb_user_id, emb_brand, ...])
  • 16 次独立的 CUDA kernel launch,每次 kernel launch 有固定开销(~10μs)
  • 变长特征需要手动管理 padding 或 offsets
  • CPU→GPU 数据传输也是 16 次独立 tensor

TorchRec 的 KJT + EmbeddingBagCollection 做法

# 所有特征打包成一个 KJT
kjt = KeyedJaggedTensor(
    keys=["user_id", "brand", "cate_id", ...],  # 16 个特征名
    values=torch.tensor([...]),                    # 所有 ID 拼接
    lengths=torch.tensor([...]),                   # 每个样本每个特征的长度
)
# 一次调用完成所有 Embedding 查找
result = ebc(kjt)  # 单次 fused kernel

embedding.py 可以看到:

self.ebc = EmbeddingBagCollection(list(emb_bag_configs.values()), device=device)

加速对比

维度

标准 PyTorch

TorchRec KJT + EBC

Kernel launch

N 次(N 个特征)

1 次(fused)

数据传输

N 个 tensor 独立传输

1 个 KJT 批量传输

内存布局

分散,cache 不友好

连续紧凑,cache 友好

变长处理

手动 padding/offsets

内置 lengths 支持


二、分布式 Embedding 分片(Sharding)— 突破单卡显存瓶颈

推荐模型的 Embedding 表极大,以 config 中为例:

user_id:     1,141,730 × 16 = ~17MB
brand:         461,498 × 16 = ~7MB
adgroup_id:    846,812 × 16 = ~13MB

这只是 demo 数据。生产环境 Embedding 表可达数十 GB 甚至 TB 级别,远超单卡显存。

TorchRec 的自动分片

plan_util.py 可以看到 TorchEasyRec 使用 EmbeddingShardingPlanner 自动规划:

planner = EmbeddingShardingPlanner(
    topology=topology,                    # 集群拓扑(GPU 数量、显存、带宽)
    enumerator=EmbeddingEnumerator(...),  # 枚举所有可能的分片方案
    proposer=[DynamicProgrammingProposer(), UniformProposer()],  # DP 求最优
)
plan = planner.collective_plan(model, sharders, ...)
model = DistributedModelParallel(model, plan=plan, ...)

分片策略

  • table_wise:整张表放在一个 GPU 上
  • row_wise:大表按行切分到多个 GPU
  • column_wise:按 embeddingdim 切分
  • data_parallel:每个 GPU 全量复制

Planner 用动态规划算法自动找到最优分片方案,使得:

  • 显存均衡分布在各卡上
  • 通信量最小化
  • 各卡计算负载均衡

三、TrainPipelineSparseDist — 流水线并行加速

这是 TorchRec 的杀手级特性。从 dist_util.py 可以看到 TorchEasyRec 直接使用了这个流水线。

无流水线(标准 PyTorch DDP)

Batch 1: [数据加载] → [Embedding 分发] → [前向] → [反向] → [梯度同步]
Batch 2:                                                    → [数据加载] → ...

每个阶段串行等待,GPU 大量空闲。

TrainPipelineSparseDist(3 阶段流水线)

时间 →  T1        T2        T3        T4        T5
Batch1: [数据加载] [Emb分发]  [前向+反向]
Batch2:           [数据加载] [Emb分发]  [前向+反向]
Batch3:                     [数据加载] [Emb分发]  [前向+反向]
  • Batch N 的 Embedding 分发 与 Batch N-1 的前向/反向 同时进行
  • Batch N+1 的数据加载 与 Batch N 的 Embedding 分发 同时进行
  • 通信和计算完全重叠,GPU 利用率接近 100%

四、fbgemm-gpu Fused Optimizer — 优化器层加速

optimizer.py 可以看到:

from fbgemm_gpu import split_table_batched_embeddings_ops_training

标准 PyTorch 的 Adagrad/Adam 优化器对 Embedding 的更新流程:

前向 → 反向得到梯度 → 优化器读取梯度 → 更新权重(3 次显存读写)

fbgemm-gpu 的 fused optimizer 将梯度计算和权重更新融合在一个 kernel 中:

前向 → 反向直接在 kernel 内更新权重(1 次显存读写)

结合 apply_optimizer_in_backward,实现了 backward 和 optimizer.step 的融合,避免梯度的中间存储。


总结:端到端加速对比


优化层

加速来源

典型提升

数据表示

KJT 紧凑存储 + fused kernel

减少 kernel launch 开销

Embedding 查找

EBC 批量 fused 查找

10x+ vs 逐特征查找

分布式分片

自动 DP 最优分片

突破单卡显存,线性扩展

流水线

通信/计算重叠

30-50% 吞吐提升

优化器

fbgemm fused backward

减少 2/3 显存带宽

这就是为什么 TorchEasyRec 的核心数据流全部围绕 TorchRec 构建——它不只是一个 Embedding 库,而是一套从数据表示到分布式训练到优化器的完整加速栈。

相关文章
|
1月前
|
分布式计算 MaxCompute iOS开发
TorchEasyRec 在 macOS 上的功能限制总结
本文总结tzrec在macOS上的功能限制:核心依赖(如torchrec、fbgemm-gpu、graphlearn等)无法安装;分布式训练、原生数据管线、Embedding模块、Triton/CUDA算子、TDM树模型等功能完全不可用;优化器与模型导出部分失效;单元测试大多因强依赖而失败。
167 15
|
1月前
|
人工智能 安全 JavaScript
OpenClaw(小龙虾)Windows 11 一键部署教程 2026 最新版
OpenClaw(小龙虾)是GitHub星标28W+的开源本地AI智能体,支持Win11全版本,零代码、免配置、解压即用!自动操控电脑、整理文件、浏览器/办公自动化,数据全程本地运行,隐私安全。一键部署包v2.6.0专为Win11优化,10分钟搞定!
911 129
|
1月前
|
安全 Java 索引
java工具:《对Collections.sort排序后我想制定查询几条,比如list有10条,我只想获取前4条》
java工具:《对Collections.sort排序后我想制定查询几条,比如list有10条,我只想获取前4条》
108 12
|
1月前
|
人工智能
【钉钉会议 | 日程 Skill】让 Agent 真正帮你「把时间排进钉钉」
钉钉日程助手技能,打通“找人→约时→订室→发邀→跟进”全链路。支持查空闲、抢会议室、一键建会(含视频)、签到链接推送、周期例会自动排期,让AI真正驱动协作闭环。(239字)
305 15
|
15天前
|
人工智能 缓存 安全
阿里云百炼Token Plan 标准坐席25,000 Credits 能用多少token或者调用次数?
阿里百炼Token Plan标准坐席198元/月,提供25,000 Credits额度(非固定Token数或调用次数)。支持多模型、全模态(文本/视觉/图像生成),动态计费,兼顾灵活与安全,适合轻度AI辅助团队。
|
1月前
|
机器学习/深度学习 JSON 自然语言处理
PAI-Rec 特征工程全解析:统计特征、实时特征、序列特征与 FG 特征算子
PAI-Rec是阿里云智能推荐的特征工程解决方案,支持离线统计、实时及序列特征自动衍生,并通过Feature Generator(17种内置算子)保障离线/在线特征一致性,大幅降低开发与维护成本。
413 9
|
16天前
|
JSON API PHP
印度股票实时数据 NSE和BSE的实时行情、K 线及指数数据
StockTV全面支持印度股市,覆盖NSE(ID 46)与BSE(ID 74)实时行情、指数及K线数据。对接需设`countryId=14`,通过API Key调用统一接口,支持股票列表、实时报价、Nifty/Sensex指数及多周期K线查询,PHP示例开箱即用。(239字)
|
1月前
|
Kubernetes 网络协议 文件存储
Docker镜像拉了一下午还没完?我受够了,花了一周找替代方案
上周拉镜像卡在47%两小时?试遍阿里云、高校源、GitHub清单全失效。直到发现「毫秒镜像」——宝塔、爱快、绿联NAS已原生集成,金融级客户背书。一行命令安装,3秒拉完nginx,全仓库加速(Docker Hub/gcr/ghcr/k8s等),含DNS自诊。免费版够用,稳定不跑路。
746 18
|
1月前
|
人工智能 安全 数据可视化
Windows 全版本 OpenClaw 搭建教程 零代码可视化一键部署
OpenClaw(小龙虾)是2026年热门开源AI自动化工具,支持Win10/11本地离线运行。零代码、全图形化、内置依赖、多模型切换、大Token额度,5–10分钟一键部署。数据不出设备,安全可控,适配办公全场景。(239字)
208 1
|
1月前
|
机器学习/深度学习 分布式计算 搜索推荐
PAI-Rec 召回引擎:构建高性能推荐系统的核心引擎
PAI-Rec是阿里云智能推荐平台的核心召回引擎,经阿里大规模场景验证。支持多路召回融合(U2I/I2I/向量/随机)、召回即过滤、毫秒级实时更新与分布式弹性架构,开箱即用,助力企业构建毫秒级、高精度、强实时的推荐系统。
272 9