实验结果
从结果中可以看出,除了使用RoBERTa的TREC数据集外,同时使用BERT和RoBERTa编码器在几乎所有设置中都取得了最好的分类性能。与具有完整训练数据的CE+CL相比,DualCL对BERT和RoBERTa的平均改善率分别为0.46%和0.39%。此外,我们观察到,在10%的训练数据下,DualCL的性能明显大于CE+CL方法,在BERT和RoBERTa上分别高出0.74%和0.51%。同时,CE 和 CE+SCL 的性能无法超越 DualCL。这是因为CE方法忽略了样本之间的关系,CE+SCL方法不能直接学习分类任务的分类器。
此外论文发现双重对比损失项有助于模型在所有五个数据集上实现更好的性能。它表明利用样本之间的关系有助于模型在对比学习中学习更好的表示。
案例分析
为了验证DualCL是否能够捕获信息特征,作者还计算了[CLS]标记的特征与句子中每个单词之间的注意得分。首先在整个训练集上微调RoBERTa编码器。然后我们计算特征之间的距离,并可视化图4中的注意图。结果表明,在对情绪进行分类时,所捕获的特征是不同的。上面的例子来自SST-2数据集,我们可以看到我们的模型更关注表达“积极”情绪的句子“predictably heart warming”。下面的例子来自CR数据集,我们可以看到我们的模型对表达“消极”情绪的句子更关注“small”。相反,CE方法没有集中于这些鉴别特征。结果表明DualCL能够成功地处理句子中的信息性关键词。
5、论文总结
在本研究中,从文本分类任务的角度,提出了一种对偶对比学习方法DualCL,来解决监督学习任务。
在DualCL中,作者使用PLMs同时学习两种表示形式。一个是输入示例的鉴别特征,另一个是该示例的分类器。我们引入了具有标签感知的数据增强功能来生成输入样本的不同视图,其中包含特征和分类器。然后设计了一个对偶对比损失,使分类器对输入特征有效。
对偶对比损失利用训练样本之间的监督信号来学习更好的表示,通过大量的实验验证了对偶对比学习的有效性。
6、核心代码
关于Dual-Contrastive-Learning实现,大家可以查看开源代码:
https://github.com/hiyouga/Dual-Contrastive-Learning/blob/main/main_polarity.py def _contrast_loss(self, cls_feature, label_feature, labels): normed_cls_feature = F.normalize(cls_feature, dim=-1) normed_label_feature = F.normalize(label_feature, dim=-1) list_con_loss = [] BS, LABEL_CLASS, HS = normed_label_feature.shape normed_positive_label_feature = torch.gather(normed_label_feature, dim=1, index=labels.reshape(-1, 1, 1).expand(-1, 1, HS)).squeeze(1) # (bs, 768) if "1" in self.opt.contrast_mode: loss1 = self._calculate_contrast_loss(normed_positive_label_feature, normed_cls_feature, labels) list_con_loss.append(loss1) if "2" in self.opt.contrast_mode: loss2 = self._calculate_contrast_loss(normed_cls_feature, normed_positive_label_feature, labels) list_con_loss.append(loss2) if "3" in self.opt.contrast_mode: loss3 = self._calculate_contrast_loss(normed_positive_label_feature, normed_positive_label_feature, labels) list_con_loss.append(loss3) if "4" in self.opt.contrast_mode: loss4 = self._calculate_contrast_loss(normed_cls_feature, normed_cls_feature, labels) list_con_loss.append(loss4) return list_con_loss def _calculate_contrast_loss(self, anchor, target, labels, mu=1.0): BS = len(labels) with torch.no_grad(): labels = labels.reshape(-1, 1) mask = torch.eq(labels, labels.T) # (bs, bs) # compute temperature using mask temperature_matrix = torch.where(mask == True, mu * torch.ones_like(mask), 1 / self.opt.temperature * torch.ones_like(mask)).to(self.opt.device) # # mask-out self-contrast cases, 即自身对自身不考虑在内 # logits_mask = torch.scatter( # torch.ones_like(mask), # 1, # torch.arange(BS).view(-1, 1).to(self.opt.device), # 0 # ) # mask = mask * logits_mask # compute logits anchor_dot_target = torch.multiply(torch.matmul(anchor, target.T), temperature_matrix) # (bs, bs) # for numerical stability logits_max, _ = torch.max(anchor_dot_target, dim=1, keepdim=True) logits = anchor_dot_target - logits_max.detach() # (bs, bs) # compute log_prob exp_logits = torch.exp(logits) # (bs, bs) exp_logits = exp_logits - torch.diag_embed(torch.diag(exp_logits)) # 减去对角线元素,对自身不可以 log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12) # (bs, bs) # in case that mask.sum(1) has no zero mask_sum = mask.sum(dim=1) mask_sum = torch.where(mask_sum == 0, torch.ones_like(mask_sum), mask_sum) # compute mean of log-likelihood over positive mean_log_prob_pos = (mask * log_prob).sum(dim=1) / mask_sum.detach() loss = - mean_log_prob_pos.mean() return loss
7、参考资料
ICML 2020: 从Alignment 和 Uniformity的角度理解对比表征学习
https://blog.csdn.net/c2a2o2/article/details/117898108