Facebook出手!自适应梯度打败人工调参

简介: Facebook出手!自适应梯度打败人工调参

大家好,我是对白。


多任务模型中各个任务难以调参、收敛、效果平平,是一件令人头疼的事情。有没有什么可以令人省心省力的“自适应”方法呢?笔者浏览了一些最近的顶会文章,读了一些相关文章,今天挑选一篇分享给大家~


说到多任务学习,想必大家都不陌生。在理想的推荐场景中,通过与辅助任务的联合学习,可以提升目标任务的预测效果。例如,在社交推荐中,用户偏好的学习可以与辅助任务(预测用户之间的联系与信任)联合训练。


但理想归理想,现实却往往“事故多发”。说不定,在加了多个辅助任务,并经历艰难的调参之后,目标任务的效果却令人头秃。直觉分析原因,有两种可能:


当辅助任务产生了比目标任务更大的影响,甚至支配了网络权重时,目标任务的结果会变得更差;


另一个极端,当一个或多个辅助任务的影响太弱时,则无法帮助目标任务提升效果。


更具可能性的情况是,上述两个可能原因在训练的过程中交替出现,并且在同一个网络的不同部分之间变化。


在多任务学习中,模型的训练loss通常由多个损失函数加权得到,而不同任务重要程度往往是需要人为设参的,使得我们可能在调参问题上时间花费较多。


这篇文章所提的方法MetaBalance采用对辅助任务动态梯度调整的方式,取代对不同任务的权重调整,在NDCG@10上针对两数据集取得了8.34%的改进



论文标题:MetaBalance: Improving Multi-Task Recommendations via Adapting Gradient Magnitudes of Auxiliary Tasks


论文链接:https://doi.org/10.1145/3485447.3512093


论文代码:https://github.com/facebookresearch/MetaBalance


一、MetaBalance核心思想



本文由Meta AI发表于WWW2022上。在看文章具体内容之前,我们先回忆一下多任务学习的损失函数以及梯度更新规则:


从梯度的角度再去解释一下辅助任务对目标任务产生负面影响的原因:


多任务网络通常由具有共享参数的bottom layer和几个ask-specific layers组成,如下图:



在训练中,每个任务都有相应的loss,并相对于多任务网络的共享参数具有相应的梯度。这些梯度的综合会影响参数的更新方式,并且梯度越大,对共享参数的影响越大


  • 当辅助loss的梯度远大于目标loss的梯度时,辅助任务相比目标函数会对共享参数产生更大的影响,导致最终目标任务的性能下降;


  • 反之,辅助任务的影响太弱,则无法辅助目标任务。


这种情况其实很常见,如下图阿里巴巴的两个例子,分别对应了上述两种情况。



二、算法细节

如何调整辅助梯度大小呢?针对上面的分析,我们可以看到有两种不适宜的梯度情况,对此MetaBalance提出了三种策略:

  1. 当远大于时,应能自适应减少;


  1. 当远小于时,应能自适应增大;


  1. 如有必要,可同时进行1与2策略。


该策略是根据目标任务在验证数据集上的性能选择的,这是针对特定任务和数据集的经验最佳策略。


为了实现的自适应变化,文章提出可以平衡的*动态权重*。


到目前为止,一个基本的算法流程为:



然而,强制辅助梯度与目标梯度具有完全相同的大小,一定可以实现目标任务的最佳值吗?文章对此提出采用 relax factor 来调节辅助梯度与目标梯度大小的接近度


采用上式来代替算法1中的步骤6。可以看到,当 越接近1,两种梯度的接近度越高。



不仅如此, 实际上会影响每个辅助任务的权重,将算法1中的第六行改写为:



那么,应该如何选择合适的 值呢?


由于 仅用于反向传播,不含任何loss的梯度,故作为超参数,在验证集进行优化。


这里注意,尽管所有辅助任务均采用相同的 值,但并不表示它们具有相同的权重或梯度大小,可参考上图公式(4)。


最后,应用相应梯度的移动平均值,以训练迭代中所有梯度之间的方差


最终算法的伪代码为:



三、实验结果

实验结果总体不错,下面展示了三种梯度调整策略的实验结果,可见在UserBehavior-2017中均显著优于vanilla多任务学习baseline(“vanilla multi”),并且策略C在IJCAI-2015中显著优于基线,这表明了MetaBalance的有效性和鲁棒性。



下面展示了对超参数relax factor 的研究分析。



四、总结

本文从梯度大小的角度提出了MetaBalance来调整辅助任务,以更好地辅助目标任务。该方法可以防止目标任务不受辅助任务的支配,也可以避免一个或多个辅助任务被忽略。此外,辅助梯度针对网络的每个部分,在整个训练过程中实现了动态、自适应的平衡


文末留两个思考题:


  1. 从梯度的视角下去理解多任务学习,并不是一个新鲜的话题,你还知道哪些有效的梯度调整方法?


  1. 类似的思路是否可以用到其他领域呢,例如多模态学习?


相关文章
|
计算机视觉
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(二)
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(二)
200 0
|
26天前
|
机器学习/深度学习 人工智能 算法
谷歌DeepMind:GPT-4高阶心智理论彻底击败人类!第6阶推理讽刺暗示全懂了
【6月更文挑战第10天】谷歌DeepMind团队的最新论文显示,GPT-4在高阶心智理论任务中超越了人类水平,这是AI在理解和推理人类心理状态上的重大突破。研究人员通过MoToMQA测试套件评估了大型语言模型,发现GPT-4在第6阶推理上超过成人表现。这一进展意味着AI能更好地理解用户意图,提升交互体验,但也引发了关于操纵与控制人类以及模型是否真正理解心理状态的担忧。论文链接:https://arxiv.org/pdf/2405.18870
28 3
|
2月前
|
算法 数据可视化 图形学
超越GIoU/DIoU/CIoU/EIoU | MPDIoU让YOLOv7/YOLACT双双涨点,速度不减!
超越GIoU/DIoU/CIoU/EIoU | MPDIoU让YOLOv7/YOLACT双双涨点,速度不减!
82 0
|
机器学习/深度学习 编解码 vr&ar
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(一)
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(一)
191 0
|
机器学习/深度学习 自然语言处理
十年来论文量激增,深度学习如何慢慢推开数学推理的门(2)
十年来论文量激增,深度学习如何慢慢推开数学推理的门
|
机器学习/深度学习 消息中间件 人工智能
十年来论文量激增,深度学习如何慢慢推开数学推理的门(1)
十年来论文量激增,深度学习如何慢慢推开数学推理的门
150 0
谷歌、DeepMind新研究:归纳偏置如何影响模型缩放?
谷歌、DeepMind新研究:归纳偏置如何影响模型缩放?
|
机器学习/深度学习 算法 数据挖掘
图神经网络发Nature子刊,却被爆比普通算法慢104倍,质疑者:灌水新高度?
图神经网络发Nature子刊,却被爆比普通算法慢104倍,质疑者:灌水新高度?
|
自然语言处理 算法 计算机视觉
陈丹琦组掩蔽语言模型研究引争议:15%掩蔽率不是最佳,但40%站得住脚吗?
陈丹琦组掩蔽语言模型研究引争议:15%掩蔽率不是最佳,但40%站得住脚吗?
|
人工智能 自然语言处理 搜索推荐
ChatGPT之后性能怪兽来了?马库斯7大「黑暗」预测:GPT-4带不来AGI
ChatGPT之后性能怪兽来了?马库斯7大「黑暗」预测:GPT-4带不来AGI
115 0

热门文章

最新文章

  • 1
    流量控制系统,用正则表达式提取汉字
    25
  • 2
    Redis09-----List类型,有序,元素可以重复,插入和删除快,查询速度一般,一般保存一些有顺序的数据,如朋友圈点赞列表,评论列表等,LPUSH user 1 2 3可以一个一个推
    26
  • 3
    Redis08命令-Hash类型,也叫散列,其中value是一个无序字典,类似于java的HashMap结构,Hash结构可以将对象中的每个字段独立存储,可以针对每字段做CRUD
    27
  • 4
    Redis07命令-String类型字符串,不管是哪种格式,底层都是字节数组形式存储的,最大空间不超过512m,SET添加,MSET批量添加,INCRBY age 2可以,MSET,INCRSETEX
    27
  • 5
    S外部函数可以访问函数内部的变量的闭包-闭包最简单的用不了,闭包是内层函数+外层函数的变量,简称为函数套函数,外部函数可以访问函数内部的变量,存在函数套函数
    24
  • 6
    Redis06-Redis常用的命令,模糊的搜索查询往往会对服务器产生很大的压力,MSET k1 v1 k2 v2 k3 v3 添加,DEL是删除的意思,EXISTS age 可以用来查询是否有存在1
    31
  • 7
    Redis05数据结构介绍,数据结构介绍,官方网站中看到
    22
  • 8
    JS字符串数据类型转换,字符串如何转成变量,+号只要有一个是字符串,就会把另外一个转成字符串,- * / 都会把数据转成数字类型,数字型控制台是蓝色,字符型控制台是黑色,
    20
  • 9
    JS数组操作---删除,arr.pop()方法从数组中删除最后一个元素,并返回该元素的值,arr.shift() 删除第一个值,arr.splice()方法,删除指定元素,arr.splice,从第一
    20
  • 10
    定义好变量,${age}模版字符串,对象可以放null,检验数据类型console.log(typeof str)
    19