Gradient Normalization在多任务学习中的优化实践

简介: 在每平每屋以及我们团队负责的其他一些场景,MMoE多任务模型是精排阶段常用的模型。在每平每屋场景中,我们使用MMoE模型,对场景中的三个任务同时进行点击率预估,即内容的一跳点击、内容详情页点击以及跳转的商品详情页点击。

1.gif

本系列文章包含每平每屋过去一年在召回、排序和冷启动等模块中的一些探索和实践经验,本文为该专题的第四篇。

第一篇指路:冷启动系统优化与内容潜力预估实践

二篇指路:GNN在轻应用内容推荐中的召回实践

第三篇指路:基于特征全埋点的精排ODL实践总结

前言

在每平每屋以及我们团队负责的其他一些场景,MMoE多任务模型是精排阶段常用的模型。在每平每屋场景中,我们使用MMoE模型,对场景中的三个任务同时进行点击率预估,即内容的一跳点击、内容详情页点击以及跳转的商品详情页点击。MMoE结构有别于简单的Hard Parameter Sharing的方式,可以通过模型去学习不同任务之间哪些参数共享(Experts),共享的程度如何(Gates)

图片.png

建模

MMoE是从Model Structure的角度来对共享参数做精细化控制,并在很多场景都取得了提升。除了Model Structure之外,我们还可以从Model Dynamics的角度去考虑优化。由于不同任务之间天然存在着数据量、样本分布、难易程度的差别,导致在训练过程中,收敛速度、收敛结果都会存在不小的差异。Model Dynamics希望在模型训练过程中去动态调整训练的速度等变量。

建模

图片.png

▐  Dynamic Weight Averaging

图片.png


图片.png

▐  Dynamic Task Prioritization

图片.png

▐  Gradient Normalization

图片.png

图片.pngimage.gif

GradNorm实践

图片.png


▐  离线对比

图片.png


auc结果波动性


我们对同样的模型参数组,连续训练三次,计算每次结果三个任务auc的均值和方差,我们可以发现,在不使用gn的基准组,一跳点击、二跳点击、商详点击的预估auc方差逐次增大,这也表明二跳点击和商详点击的预估任务,在相同训练参数下波动较大,表明任务训练的结果不好,这对于线上结果显然是有害的。而在使用了gn的实验组,auc的方差都保持在相对较低的水平,表明这几个任务的训练结果应当是稳定的。


type click auc detail auc fav auc
baseline 0.6720 ± 0.0013 0.6886 ± 0.0057 0.6915 ± 0.0095
baseline + gn 0.6828 ± 0.0023 0.7078 ± 0.0014 0.7155 ± 0.0012

baseline 平均波动:0.0055

baseline + gn 平均波动:0.0016

整体训练波动被大幅缩小,尤其是在正样本比较稀疏的 detail 任务和 fav 任务上,此外加入gn之后整体auc有明显的提升。


▐  线上AB

线上ab结果,pctcvr +3.49%,avg_ipv +4.04%

ab_name

uctr_improv

pctr_improv

avg_expo_improv

avg_click_improv

ipv_uctr_improv

ipv_pctr_improv

uctcvr_improv

pctcvr_improv

avg_ipv_improv

GradNorm

2.25%

4.59%

0.51%

5.13%

-0.19%

-1.05%

2.06%

3.49%

4.04%

Baseline AA

0.00%

-0.15%

0.43%

0.27%

0.03%

0.02%

0.03%

-0.13%

0.26%

Baseline

0.00%

0.00%

0.00%

0.00%

0.00%

0.00%

0.00%

0.00%

0.00%

总结

在实际推荐、搜索场景中,我们所优化的业务目标往往并不单一,使用多任务、多目标模型在目前业界比较常见。不同于MMoE、PLE等Model Structure优化,我们从Model Dynamic的角度出发,人为定义一系列"好的"Model Dynamic标准,并在建模中应用,取得了在MMoE基础上的进一步提升。两类优化具有互补性,在实际使用中可以同时应用。

团队介绍

*本项工作是本组实习生知同在职时合作完成。

淘系技术部-淘宝智能团队

淘宝智能团队是一支数据和算法一体的团队,服务于淘宝、天猫、聚划算、闲鱼和每平每屋等业务线的二十余个业务场景,提供线上零售、内容社区、3D智能设计和端上智能等数据和算法服务。我们通过机器学习、强化学习、数据挖掘、机器视觉、NLP、运筹学、3D算法、搜索和推荐算法,为千万商家寻找商机,为平台运营提供智能化方案,为用户提高使用体验,为设计师提供自动搭配和布局,从而促进平台和生态的供给繁荣和用户增长,不断拓展商业边界。

这是一支快速成长中的学习型团队。在创造业务价值的同时,我们不断输出学术成果,在KDD、ICCV、Management Science等国际会议和杂志上发表数篇学术论文。团队学习氛围浓厚,每年组织上百场技术分享交流,互相学习和启发。真诚邀请海内外相关方向的优秀人才加入我们。如果您有兴趣可将简历发至bangzhu.gx@alibaba-inc.com,期待您的加入!

参考文献

[1] 多任务学习优化(Optimization in Multi-task learning) https://zhuanlan.zhihu.com/p/269492239[2] Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts, KDD 2018 https://dl.acm.org/doi/10.1145/3219819.3220007[3] End-to-End Multi-Task Learning with Attention, CVPR 2019 https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf[4] Dynamic task prioritization for multitask learning, ECCV 2018 https://openaccess.thecvf.com/content_ECCV_2018/papers/Michelle_Guo_Focus_on_the_ECCV_2018_paper.pdf[5] GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks, ICML 2018 http://proceedings.mlr.press/v80/chen18a/chen18a.pdf


相关文章
|
存储 JSON 自然语言处理
手把手教你使用ModelScope训练一个文本分类模型
手把手教你使用ModelScope训练一个文本分类模型
|
Linux
Linux:ln创建删除软连接
Linux:ln创建删除软连接
2431 0
|
Android开发
教你在Android手机上使用全局代理!
前言:在Android上使用系统自带的代理,限制灰常大,仅支持系统自带的浏览器。这样像QQ、飞信、微博等这些单独的App都不能使用系统的代理。如何让所有软件都能正常代理呢?ProxyDroid这个软件能帮你解决!使用方法及步骤如下: 一、推荐从Google Play下载ProxyDroid,目前最新版本是v2.6.6。
17511 0
|
3月前
|
人工智能 监控 搜索推荐
构建AI智能体:八十三、当AI开始“失忆“:深入理解和预防模型衰老与数据漂移
AI模型会因数据分布变化和时间推移而性能下降,即“模型衰老”与“数据漂移”。如同知识过时,旧模型难以适应新环境,导致预测不准。需通过PSI、KS等指标监测,并定期重训练以保持其有效性。
283 8
|
11月前
|
机器学习/深度学习 PyTorch 编译器
深入解析torch.compile:提升PyTorch模型性能、高效解决常见问题
PyTorch 2.0推出的`torch.compile`功能为深度学习模型带来了显著的性能优化能力。本文从实用角度出发,详细介绍了`torch.compile`的核心技巧与应用场景,涵盖模型复杂度评估、可编译组件分析、系统化调试策略及性能优化高级技巧等内容。通过解决图断裂、重编译频繁等问题,并结合分布式训练和NCCL通信优化,开发者可以有效提升日常开发效率与模型性能。文章为PyTorch用户提供了全面的指导,助力充分挖掘`torch.compile`的潜力。
1210 17
|
9月前
|
机器学习/深度学习 存储 Java
Java 大视界 -- Java 大数据机器学习模型在游戏用户行为分析与游戏平衡优化中的应用(190)
本文探讨了Java大数据与机器学习模型在游戏用户行为分析及游戏平衡优化中的应用。通过数据采集、预处理与聚类分析,开发者可深入洞察玩家行为特征,构建个性化运营策略。同时,利用回归模型优化游戏数值与付费机制,提升游戏公平性与用户体验。
|
11月前
|
机器学习/深度学习 存储 自然语言处理
常用的CTR领域经典机器模型介绍
逻辑回归(Logistic Regression)是经典的统计学习算法,因其简单、高效和可大规模并行化的特点,在早期工业机器学习中占据重要地位。它通过手动设计特征实现非线性学习能力,适用于CTR预估等场景。 梯度提升决策树(Gradient Boosting Decision Tree, GBDT)是一种迭代决策树算法,通过多棵回归树的累加结果进行预测,具有较强的泛化能力。其核心思想是最小化平方误差来优化分枝依据,并利用残差拟合提升模型性能。
|
SQL 测试技术 API
如何编写API接口的自动化测试脚本
本文详细介绍了编写API自动化测试脚本的方法和最佳实践,涵盖确定测试需求、选择测试框架、编写测试脚本(如使用Postman和Python Requests库)、参数化和数据驱动测试、断言和验证、集成CI/CD、生成测试报告及维护更新等内容,旨在帮助开发者构建高效可靠的API测试体系。
|
存储 人工智能 搜索推荐
RAG系统的7个检索指标:信息检索任务准确性评估指南
大型语言模型(LLMs)在生成式AI领域备受关注,但其知识局限性和幻觉问题仍具挑战。检索增强生成(RAG)通过引入外部知识和上下文,有效解决了这些问题,并成为2024年最具影响力的AI技术之一。RAG评估需超越简单的实现方式,建立有效的性能度量标准。本文重点讨论了七个核心检索指标,包括准确率、精确率、召回率、F1分数、平均倒数排名(MRR)、平均精确率均值(MAP)和归一化折损累积增益(nDCG),为评估和优化RAG系统提供了重要依据。这些指标不仅在RAG中发挥作用,还广泛应用于搜索引擎、电子商务、推荐系统等领域。
8216 2
RAG系统的7个检索指标:信息检索任务准确性评估指南