【多任务学习】Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics

简介: 【多任务学习】Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics

·阅读摘要:

 本文提出针对CV领域的多任务模型,设置一个可以学习损失权重的损失层,可以提高模型精度。

·参考文献:

 [1] Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics

  个人理解:我们使用传统的多任务时,损失函数一般都是各个任务的损失相加,最多会为每个任务的损失前添加权重系数。但是这样的超参数是很难去调参的,代价大,而且很难去调到一个最好的状态。最好的方式应该是交给深度学习。

  论文最重要的部分在损失函数的设置与推导。 这对我们优化自己的多任务学习模型有指导意义。

[1] Homoscedastic uncertainty as task-dependent uncertainty (同方差不确定性)


 作者的数学模型通过贝叶斯模型建立。作者首先提出贝叶斯建模中存在两类不确定性:

 · 认知不确定性(Epistemic uncertainty):由于缺少训练数据而引起的不确定性

 · 偶然不确定性(Aleatoric uncertainty):由于训练数据无法解释信息而引起的不确定性

 而对于偶然不确定性,又分为如下两个子类:

 · 数据依赖地(Data-dependant)或异方差(Heteroscedastic)不确定性

 · 任务依赖地(Task-dependant)或同方差(Homoscedastic)不确定性

 多任务中,任务不确定性捕获任务间相关置信度,反应回归或分类任务的内在不确定性。

【注】本篇论文的假设,是基于同方差不确定性的。关于同方差不确定性和异方差不确定性的通俗解释,可以参考知乎问题:https://www.zhihu.com/question/278182454/answer/398539763

[2] Multi-task likelihoods (多任务似然)


image.png

  对于分类任务有:

image.png

  多任务的概率:

image.png

  例如对于回归任务来说,极大似然估计转化为最小化负对数:

image.png

image.png

各个公式的证明


040ea3618417442cae10b3dd2f6fdc50.png

image.png

pytorch版代码实现


image.png

 代码如下:

import math
import pylab
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
def gen_data(N):
    X = np.random.randn(N, 1)
    w1 = 2.
    b1 = 8.
    sigma1 = 1e1  # ground truth
    Y1 = X.dot(w1) + b1 + sigma1 * np.random.randn(N, 1)
    w2 = 3
    b2 = 3.
    sigma2 = 1e0  # ground truth
    Y2 = X.dot(w2) + b2 + sigma2 * np.random.randn(N, 1)
    return X, Y1, Y2
class TrainData(Dataset):
    def __init__(self, feature_num, X, Y1, Y2):
        self.feature_num = feature_num
        self.X = torch.tensor(X, dtype=torch.float32)
        self.Y1 = torch.tensor(Y1, dtype=torch.float32)
        self.Y2 = torch.tensor(Y2, dtype=torch.float32)
    def __len__(self):
        return self.feature_num
    def __getitem__(self, idx):
        return self.X[idx,:], self.Y1[idx,:], self.Y2[idx,:]
class MultiTaskLossWrapper(nn.Module):
    def __init__(self, task_num, model):
        super(MultiTaskLossWrapper, self).__init__()
        self.model = model
        self.task_num = task_num
        self.log_vars = nn.Parameter(torch.zeros((task_num)))
    def forward(self, input, targets):
        outputs = self.model(input)
        precision1 = torch.exp(-self.log_vars[0])
        loss = torch.sum(precision1 * (targets[0] - outputs[0]) ** 2. + self.log_vars[0], -1)
        precision2 = torch.exp(-self.log_vars[1])
        loss += torch.sum(precision2 * (targets[1] - outputs[1]) ** 2. + self.log_vars[1], -1)
        loss = torch.mean(loss)
        return loss, self.log_vars.data.tolist()
class MTLModel(torch.nn.Module):
    def __init__(self, n_hidden, n_output):
        super(MTLModel, self).__init__()
        self.net1 = nn.Sequential(nn.Linear(1, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_output))
        self.net2 = nn.Sequential(nn.Linear(1, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_output))
    def forward(self, x):
        return [self.net1(x), self.net2(x)]
np.random.seed(0)
feature_num = 100
nb_epoch = 2000
batch_size = 20
hidden_dim = 1024
X, Y1, Y2 = gen_data(feature_num)
pylab.figure(figsize=(3, 1.5))
pylab.scatter(X[:, 0], Y1[:, 0])
pylab.scatter(X[:, 0], Y2[:, 0])
pylab.show()
train_data = TrainData(feature_num, X, Y1, Y2)
train_data_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
model = MTLModel(hidden_dim, 1)
mtl = MultiTaskLossWrapper(2, model)
mtl
# https://github.com/keras-team/keras/blob/master/keras/optimizers.py
# k.epsilon() = keras.backend.epsilon()
optimizer = torch.optim.Adam(mtl.parameters(), lr=0.001, eps=1e-07)
loss_list = []
for t in range(nb_epoch):
    cumulative_loss = 0
    for X, Y1, Y2 in train_data_loader:
        loss, log_vars = mtl(X, [Y1, Y2])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        cumulative_loss += loss.item()
    loss_list.append(cumulative_loss/batch_size)
pylab.plot(loss_list)
pylab.show()
print(log_vars)
[4.2984442710876465, -0.2037072628736496]
# Found standard deviations (ground truth is 10 and 1):
print([math.exp(log_var) ** 0.5 for log_var in log_vars])
[8.578183137529612, 0.9031617364804738]

image.png

相关文章
|
6月前
|
数据挖掘
【提示学习】Automatic Multi-Label Prompting: Simple and Interpretable Few-Shot Classification
文章提出了一种简单确高效地构建verbalization的方法:
|
8月前
|
数据挖掘
MUSIED: A Benchmark for Event Detection from Multi-Source Heterogeneous Informal Texts 论文解读
事件检测(ED)从非结构化文本中识别和分类事件触发词,作为信息抽取的基本任务。尽管在过去几年中取得了显著进展
43 0
|
8月前
|
机器学习/深度学习 自然语言处理 数据挖掘
UnifiedEAE: A Multi-Format Transfer Learning Model for Event Argument Extraction via Variational论文解读
事件论元抽取(Event argument extraction, EAE)旨在从文本中抽取具有特定角色的论元,在自然语言处理中已被广泛研究。
49 0
|
8月前
|
机器学习/深度学习 自然语言处理 数据可视化
M2E2: Cross-media Structured Common Space for Multimedia Event Extraction 论文解读
我们介绍了一个新的任务,多媒体事件抽取(M2E2),旨在从多媒体文档中抽取事件及其参数。我们开发了第一个基准测试
61 0
|
8月前
|
机器学习/深度学习 自然语言处理 算法
SS-AGA:Multilingual Knowledge Graph Completion with Self-Supervised Adaptive Graph Alignment 论文解读
预测知识图(KG)中缺失的事实是至关重要的,因为现代知识图远未补全。由于劳动密集型的人类标签,当处理以各种语言表示的知识时,这种现象会恶化。
56 0
|
8月前
|
机器学习/深度学习 自然语言处理 数据可视化
EventGraph:Event Extraction as Semantic Graph Parsing 论文解读
事件抽取涉及到事件触发词和相应事件论元的检测和抽取。现有系统经常将事件抽取分解为多个子任务,而不考虑它们之间可能的交互。
52 0
|
8月前
|
机器学习/深度学习 自然语言处理 算法
TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking 论文解读
近年来,从非结构化文本中提取实体和关系引起了越来越多的关注,但由于识别共享实体的重叠关系存在内在困难,因此仍然具有挑战性。先前的研究表明,联合学习可以显著提高性能。然而,它们通常涉及连续的相互关联的步骤,并存在暴露偏差的问题。
89 0
|
8月前
|
机器学习/深度学习 数据采集 自然语言处理
Efficient Zero-shot Event Extraction with Context-Definition Alignment论文解读
事件抽取(EE)是从文本中识别感兴趣的事件提及的任务。传统的工作主要以监督的方式为主。然而,这些监督的模型不能概括为预定义本体之外的事件类型。
55 0
|
9月前
|
机器学习/深度学习 自然语言处理 计算机视觉
【计算机视觉】MDETR - Modulated Detection for End-to-End Multi-Modal Understanding
对于图像模型,MDETR采用的是一个CNN backbone来提取视觉特征,然后加上二维的位置编码;对于语言模态,作者采用了一个预训练好的Transformer语言模型来生成与输入值相同大小的hidden state。然后作者采用了一个模态相关的Linear Projection将图像和文本特征映射到一个共享的embedding空间。 接着,将图像embedding和语言embedding进行concat,生成一个样本的图像和文本特征序列。这个序列特征首先被送入到一个Cross Encoder进行处理,后面的步骤就和DETR一样,设置Object Query用于预测目标框。
|
9月前
|
机器学习/深度学习 人工智能 自然语言处理
【计算机视觉】CORA: Adapting CLIP for Open-Vocabulary Detection with Region Prompting and Anchor Pre-Matching
CORA 在目标检测任务中提出了一种新的 CLIP 预训练模型适配方法,主要包括 Region Prompting 和 Anchor Pre-Matching 两部分。 这种方法能够让 CLIP 模型适应目标检测的任务,能够识别出图像中的对象,并提供准确的分类和定位信息。