多任务学习模型之DBMTL介绍与实现

本文涉及的产品
交互式建模 PAI-DSW,每月250计算时 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,100CU*H 3个月
简介: 本文介绍的是阿里在2019年发表的多任务学习算法。该模型显示地建模目标间的贝叶斯网络因果关系,整合建模了特征和多个目标之间的复杂因果关系网络,省去了一般MTL模型中较强的独立假设。由于不对目标分布做任何特定假设,使得它能够比较自然地推广到任意形式的目标上。

多任务学习背景

目前工业中使用的推荐算法已不只局限在单目标(ctr)任务上,还需要关注后续的转化链路,如是否评论、收藏、加购、购买、观看时长等目标。

常见的多目标优化模型是从每个优化目标单独的模型网络出发,通过让这些网络在底层共享参数,实现各目标相关模型的适当程度的独立性和相关性。这类的模型框架可以用上图的结构来概括。不论底层如何共享参数,这些网络在最后几层都要伸出一些独立分支来预测各个目标的最终值。此类网络的概率模型可以用下述公式描述:

其中l,m 为目标,x为样本特征,H为模型。这里做了各目标独立的假设。


DBMTL介绍

DBMTL(Deep Bayesian Multi-Target Learning)的一个出发点就是解决上述问题。事实上套用简单的贝叶斯公式,概率模型可以写成:

如下图所示,DBMTL与传统MTL结构(认为各目标独立)最主要差别在于构建了target node之间的贝叶斯网络,显式建模了目标间可能存在的因果关系。因为在实际业务中,用户的很多行为往往存在明显的序列先后依赖关系,例如在信息流场景,用户要先点进图文详情页,才会进行后续的浏览/评论/转发/收藏 等操作。DBMTL在模型结构中体现了这些关系,因此,往往能学到更好的结果。


下图是DBMTL模型的具体实现。网络包含输入层、共享embedding层、共享层,区别层和贝叶斯层。

  • 共享embedding层是一个共享的lookup table,为各个target训练所共享。
  • 共享层和分离层是一般的multilayer perceptron (MLP),分别建模各目标的共享/区别表示。
  • Bayesian层是DBMTL中最重要的部分。它实现了如下的概率模型:

其对应的log-likelihood损失函数为:

实际应用中,对不同目标调权仍有着较大的现实作用。当对目标赋予不同权重时,相当于把损失函数重新表达为:

在网络的贝叶斯层中,函数f1, f2, f3 被实现为全连接的MLP,以学习目标间的隐含因果关系。他们把函数输入变量的embedding级联作为输入,并输入一个表示函数输出变量的embedding。每一个目标的embedding最后再经过一层MLP以输出最终目标的概率。



代码实现

基于EasyRec推荐算法框架,我们实现了DBMTL算法,具体实现可移步至github:EasyRec-DBMTL

EasyRec介绍:EasyRec是阿里云计算平台机器学习PAI团队开源的大规模分布式推荐算法框架,EasyRec 正如其名字一样,简单易用,集成了诸多优秀前沿的推荐系统论文思想,并且在实际工业落地中取得优良效果的特征工程方法,集成训练、评估、部署,与阿里云产品无缝衔接,可以借助 EasyRec 在短时间内搭建起一套前沿的推荐系统。作为阿里云的拳头产品,现已稳定服务于数百个企业客户。


模型前馈网络

def build_predict_graph(self):

   """Forward function.

   Returns:

     self._prediction_dict: Prediction result of two tasks.

   """

   # 此处从共享embedding层后的tensor(self._features)开始,省略其生成逻辑

 

   # shared layer

   if self._model_config.HasField('bottom_dnn'):

       bottom_dnn = dnn.DNN(

           self._model_config.bottom_dnn,

           self._l2_reg,

           name='bottom_dnn',

           is_training=self._is_training)

       bottom_fea = bottom_dnn(self._features)

   else:

       bottom_fea = self._features

   # MMOE block

   if self._model_config.HasField('expert_dnn'):

       mmoe_layer = mmoe.MMOE(

           self._model_config.expert_dnn,

           l2_reg=self._l2_reg,

           num_task=self._task_num,

           num_expert=self._model_config.num_expert)

       task_input_list = mmoe_layer(bottom_fea)

   else:

       task_input_list = [bottom_fea] * self._task_num

   tower_features = {}

   # specific layer

   for i, task_tower_cfg in enumerate(self._model_config.task_towers):

       tower_name = task_tower_cfg.tower_name

       if task_tower_cfg.HasField('dnn'):

           tower_dnn = dnn.DNN(

               task_tower_cfg.dnn,

               self._l2_reg,

               name=tower_name + '/dnn',

               is_training=self._is_training)

           tower_fea = tower_dnn(task_input_list[i])

           tower_features[tower_name] = tower_fea

       else:

           tower_features[tower_name] = task_input_list[i]

   tower_outputs = {}

   relation_features = {}

   # bayesian network

   for task_tower_cfg in self._model_config.task_towers:

       tower_name = task_tower_cfg.tower_name

       relation_dnn = dnn.DNN(

           task_tower_cfg.relation_dnn,

           self._l2_reg,

           name=tower_name + '/relation_dnn',

           is_training=self._is_training)

       tower_inputs = [tower_features[tower_name]]

       for relation_tower_name in task_tower_cfg.relation_tower_names:

           tower_inputs.append(relation_features[relation_tower_name])

       relation_input = tf.concat(

           tower_inputs, axis=-1, name=tower_name + '/relation_input')

       relation_fea = relation_dnn(relation_input)

       relation_features[tower_name] = relation_fea

       output_logits = tf.layers.dense(

           relation_fea,

           task_tower_cfg.num_class,

           kernel_regularizer=self._l2_reg,

           name=tower_name + '/output')

       tower_outputs[tower_name] = output_logits

       self._add_to_prediction_dict(tower_outputs)

Loss计算

def build(loss_type, label, pred, loss_weight=1.0, num_class=1, **kwargs):

   if loss_type == LossType.CLASSIFICATION:

       if num_class == 1:

           return tf.losses.sigmoid_cross_entropy(

             label, logits=pred, weights=loss_weight, **kwargs)

       else:

           return tf.losses.sparse_softmax_cross_entropy(

             labels=label, logits=pred, weights=loss_weight, **kwargs)

   elif loss_type == LossType.CROSS_ENTROPY_LOSS:

       return tf.losses.log_loss(label, pred, weights=loss_weight, **kwargs)

   elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:

       logging.info('%s is used' % LossType.Name(loss_type))

       return tf.losses.mean_squared_error(

           labels=label, predictions=pred, weights=loss_weight, **kwargs)

   elif loss_type == LossType.PAIR_WISE_LOSS:

       return pairwise_loss(pred, label)

   else:

       raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type))


def _build_loss_impl(self,

                    loss_type,

                    label_name,

                    loss_weight=1.0,

                    num_class=1,

                    suffix=''):

   loss_dict = {}

   if loss_type == LossType.CLASSIFICATION:

       loss_name = 'cross_entropy_loss' + suffix

       pred = self._prediction_dict['logits' + suffix]

   elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:

       loss_name = 'l2_loss' + suffix

       pred = self._prediction_dict['y' + suffix]

   else:

       raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))

       loss_dict[loss_name] = build(loss_type,

                                    self._labels[label_name],

                                    pred,

                                    loss_weight, num_class)

   return loss_dict


def build_loss_graph(self):

   """Build loss graph for multi task model."""

   for task_tower_cfg in self._task_towers:

       tower_name = task_tower_cfg.tower_name

       loss_weight = task_tower_cfg.weight * self._sample_weight

       if hasattr(task_tower_cfg, 'task_space_indicator_label') and \

       task_tower_cfg.HasField('task_space_indicator_label'):

           in_task_space = tf.to_float(

               self._labels[task_tower_cfg.task_space_indicator_label] > 0)

           loss_weight = loss_weight * (

               task_tower_cfg.in_task_space_weight * in_task_space +

               task_tower_cfg.out_task_space_weight * (1 - in_task_space))

           # EasyRec框架会自动对self._loss_dict中的loss进行加和。

           self._loss_dict.update(

               self._build_loss_impl(

                   task_tower_cfg.loss_type,

                   label_name=self._label_name_dict[tower_name],

                   loss_weight=loss_weight,

                   num_class=task_tower_cfg.num_class,

                   suffix='_%s' % tower_name))


   return self._loss_dict


应用

由于其卓越的算法效果,DBMTL在PAI上被大量使用。

以某直播推荐业务为例,该场景有is_click, is_view, view_costtime, is_on_mic, on_mic_duration多个目标,其中is_click, is_view, is_on_mic为二分类任务,view_costtime, on_mic_duration为预测时长的回归任务。用户行为的依赖关系为:

  • is_click=> is_view
  • is_click+is_view=> view_costtime
  • is_click=> is_on_mic
  • is_click+is_on_mic => on_mic_duration

因此配置如下:

dbmtl {

 bottom_dnn {

 hidden_units: [512, 256]

}

task_towers {

 tower_name: "is_click"

 label_name: "is_click"

 loss_type: CLASSIFICATION

 metrics_set: {

 auc {}

}

dnn {

 hidden_units: [128, 96, 64]

}

relation_dnn {

 hidden_units: [32]

}

weight: 1.0

}

task_towers {

 tower_name: "is_view"

 label_name: "is_view"

 loss_type: CLASSIFICATION

 metrics_set: {

 auc {}

}

dnn {

 hidden_units: [128, 96, 64]

}

relation_tower_names: ["is_click"]

relation_dnn {

 hidden_units: [32]

}

weight: 1.0

}

task_towers {

 tower_name: "view_costtime"

 label_name: "view_costtime"

 loss_type: L2_LOSS

 metrics_set: {

 mean_squared_error {}

}

dnn {

 hidden_units: [128, 96, 64]

}

relation_tower_names: ["is_click", "is_view"]

relation_dnn {

 hidden_units: [32]

}

weight: 1.0

}    

task_towers {

 tower_name: "is_on_mic"

 label_name: "is_on_mic"

 loss_type: CLASSIFICATION

 metrics_set: {

 auc {}

}

dnn {

 hidden_units: [128, 96, 64]

}

relation_tower_names: ["is_click"]

relation_dnn {

 hidden_units: [32]

}

weight: 1.0

}

task_towers {

 tower_name: "on_mic_duration"

 label_name: "on_mic_duration"

 loss_type: L2_LOSS

 metrics_set: {

 mean_squared_error {}

}

dnn {

 hidden_units: [128, 96, 64]

}

relation_tower_names: ["is_click", "is_on_mic"]

relation_dnn {

 hidden_units: [32]

}

weight: 1.0

}

l2_regularization: 1e-6

}

embedding_regularization: 5e-6

}


值得一提的是,DBMTL模型上线后,相比GBDT+FM(围观单目标)线上围观率提升18%,上麦率提升14%。


参考文献

EasyRec-DBMTL模型介绍

EasyRec-DBMTL模型源码

注:本文图片及公式均引用自论文:DBMTL论文

若有收获,就点个赞吧

相关实践学习
使用PAI+LLaMA Factory微调Qwen2-VL模型,搭建文旅领域知识问答机器人
使用PAI和LLaMA Factory框架,基于全参方法微调 Qwen2-VL模型,使其能够进行文旅领域知识问答,同时通过人工测试验证了微调的效果。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
机器学习/深度学习 分布式计算 DataWorks
EasyRec 使用介绍|学习笔记
快速学习 EasyRec 使用介绍。
1735 0
|
机器学习/深度学习 自然语言处理 达摩院
长文本口语语义理解技术系列①:段落分割实践
长文本口语语义理解技术系列①:段落分割实践
1389 0
长文本口语语义理解技术系列①:段落分割实践
|
机器学习/深度学习 搜索推荐 算法
多任务学习之mmoe理论详解与实践
多任务学习之mmoe理论详解与实践
多任务学习之mmoe理论详解与实践
|
机器学习/深度学习 算法 流计算
深度预测平台RTP介绍
前言 RTP平台是阿里内部一个通用的在线预测平台,不仅支持淘系搜索、推荐、聚划算、淘金币等业务,也支持国际化相关icbu、lazada等搜索推荐业务,同时还支持着淘客,优酷、飞猪等大文娱的搜索推荐场景。
9408 0
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
阿里云人工智能平台 PAI 开源 EasyDistill 框架助力大语言模型轻松瘦身
本文介绍了阿里云人工智能平台 PAI 推出的开源工具包 EasyDistill。随着大语言模型的复杂性和规模增长,它们面临计算需求和训练成本的障碍。知识蒸馏旨在不显著降低性能的前提下,将大模型转化为更小、更高效的版本以降低训练和推理成本。EasyDistill 框架简化了知识蒸馏过程,其具备多种功能模块,包括数据合成、基础和进阶蒸馏训练。通过数据合成,丰富训练集的多样性;基础和进阶蒸馏训练则涵盖黑盒和白盒知识转移策略、强化学习及偏好优化,从而提升小模型的性能。
|
9月前
|
JSON 文字识别 数据可视化
Qwen2-VL微调实战:LaTex公式OCR识别任务(完整代码)
《SwanLab机器学习实战教程》推出了一项基于Qwen2-VL大语言模型的LaTeX OCR任务,通过指令微调实现多模态LLM的应用。本教程详述了环境配置、数据集准备、模型加载、SwanLab集成及微调训练等步骤,旨在帮助开发者轻松上手视觉大模型的微调实践。
|
10月前
|
人工智能 边缘计算 自然语言处理
DistilQwen2:通义千问大模型的知识蒸馏实践
DistilQwen2 是基于 Qwen2大模型,通过知识蒸馏进行指令遵循效果增强的、参数较小的语言模型。本文将介绍DistilQwen2 的技术原理、效果评测,以及DistilQwen2 在阿里云人工智能平台 PAI 上的使用方法,和在各开源社区的下载使用教程。
|
搜索推荐 测试技术
淘宝粗排问题之在粗排模型中引入交叉特征如何解决
淘宝粗排问题之在粗排模型中引入交叉特征如何解决
|
机器学习/深度学习 存储 搜索推荐
连续迁移学习跨域推荐排序模型在淘宝推荐系统的应用
本文探讨了如何在工业界的连续学习的框架下实现跨域推荐模型,提出了连续迁移学习这一新的跨域推荐范式,利用连续预训练的源域模型的中间层表征结果作为目标域模型的额外知识,设计了一个轻量级的Adapter模块实现跨域知识的迁移,并在有好货推荐排序上取得了显著业务效果。
1195 0
连续迁移学习跨域推荐排序模型在淘宝推荐系统的应用
|
SQL 存储 Java
Hive 特殊的数据类型 Array、Map、Struct
在Hive中,`Array`、`Map`和`Struct`是三种特殊的数据类型。`Array`用于存储相同类型的列表,如`select array(1, "1", 2, 3, 4, 5)`会产生一个整数数组。`Map`是键值对集合,键值类型需一致,如`select map(1, 2, 3, "4")`会产生一个整数到整数的映射。`Struct`表示结构体,有固定数量和类型的字段,如`select struct(1, 2, 3, 4)`创建一个无名结构体。这些类型支持嵌套使用,允许更复杂的结构数据存储。例如,可以创建一个包含用户结构体的数组来存储多用户信息
1933 0