文本分类还停留在BERT?对偶对比学习框架也太强了(二)

简介: 文本分类还停留在BERT?对偶对比学习框架也太强了(二)

实验结果



从结果中可以看出,除了使用RoBERTa的TREC数据集外,同时使用BERT和RoBERTa编码器在几乎所有设置中都取得了最好的分类性能。与具有完整训练数据的CE+CL相比,DualCL对BERT和RoBERTa的平均改善率分别为0.46%和0.39%。此外,我们观察到,在10%的训练数据下,DualCL的性能明显大于CE+CL方法,在BERT和RoBERTa上分别高出0.74%和0.51%。同时,CE 和 CE+SCL 的性能无法超越 DualCL。这是因为CE方法忽略了样本之间的关系,CE+SCL方法不能直接学习分类任务的分类器。


此外论文发现双重对比损失项有助于模型在所有五个数据集上实现更好的性能。它表明利用样本之间的关系有助于模型在对比学习中学习更好的表示。


4a4e25afaa3db1d64195e829d15ffb76.png


案例分析



为了验证DualCL是否能够捕获信息特征,作者还计算了[CLS]标记的特征与句子中每个单词之间的注意得分。首先在整个训练集上微调RoBERTa编码器。然后我们计算特征之间的距离,并可视化图4中的注意图。结果表明,在对情绪进行分类时,所捕获的特征是不同的。上面的例子来自SST-2数据集,我们可以看到我们的模型更关注表达“积极”情绪的句子“predictably heart warming”。下面的例子来自CR数据集,我们可以看到我们的模型对表达“消极”情绪的句子更关注“small”。相反,CE方法没有集中于这些鉴别特征。结果表明DualCL能够成功地处理句子中的信息性关键词。


2193e7b9a097fa87e0a1efc1e1311f8a.png


5、论文总结



在本研究中,从文本分类任务的角度,提出了一种对偶对比学习方法DualCL,来解决监督学习任务。


在DualCL中,作者使用PLMs同时学习两种表示形式。一个是输入示例的鉴别特征,另一个是该示例的分类器。我们引入了具有标签感知的数据增强功能来生成输入样本的不同视图,其中包含特征和分类器。然后设计了一个对偶对比损失,使分类器对输入特征有效。


对偶对比损失利用训练样本之间的监督信号来学习更好的表示,通过大量的实验验证了对偶对比学习的有效性。


6、核心代码



关于Dual-Contrastive-Learning实现,大家可以查看开源代码:


https://github.com/hiyouga/Dual-Contrastive-Learning/blob/main/main_polarity.py
 def _contrast_loss(self, cls_feature, label_feature, labels):
        normed_cls_feature = F.normalize(cls_feature, dim=-1)
        normed_label_feature = F.normalize(label_feature, dim=-1)
        list_con_loss = []
        BS, LABEL_CLASS, HS = normed_label_feature.shape
        normed_positive_label_feature = torch.gather(normed_label_feature, dim=1,
                                                     index=labels.reshape(-1, 1, 1).expand(-1, 1, HS)).squeeze(1)  # (bs, 768)
        if "1" in self.opt.contrast_mode:
            loss1 = self._calculate_contrast_loss(normed_positive_label_feature, normed_cls_feature, labels)
            list_con_loss.append(loss1)
        if "2" in self.opt.contrast_mode:
            loss2 = self._calculate_contrast_loss(normed_cls_feature, normed_positive_label_feature, labels)
            list_con_loss.append(loss2)
        if "3" in self.opt.contrast_mode:
            loss3 = self._calculate_contrast_loss(normed_positive_label_feature, normed_positive_label_feature, labels)
            list_con_loss.append(loss3)
        if "4" in self.opt.contrast_mode:
            loss4 = self._calculate_contrast_loss(normed_cls_feature, normed_cls_feature, labels)
            list_con_loss.append(loss4)
        return list_con_loss
    def _calculate_contrast_loss(self, anchor, target, labels, mu=1.0):
        BS = len(labels)
        with torch.no_grad():
            labels = labels.reshape(-1, 1)
            mask = torch.eq(labels, labels.T)  # (bs, bs)
            # compute temperature using mask
            temperature_matrix = torch.where(mask == True, mu * torch.ones_like(mask),
                                             1 / self.opt.temperature * torch.ones_like(mask)).to(self.opt.device)
            # # mask-out self-contrast cases, 即自身对自身不考虑在内
            # logits_mask = torch.scatter(
            #     torch.ones_like(mask),
            #     1,
            #     torch.arange(BS).view(-1, 1).to(self.opt.device),
            #     0
            # )
            # mask = mask * logits_mask
        # compute logits
        anchor_dot_target = torch.multiply(torch.matmul(anchor, target.T), temperature_matrix)  # (bs, bs)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_target, dim=1, keepdim=True)
        logits = anchor_dot_target - logits_max.detach()  # (bs, bs)
        # compute log_prob
        exp_logits = torch.exp(logits)  # (bs, bs)
        exp_logits = exp_logits - torch.diag_embed(torch.diag(exp_logits))  # 减去对角线元素,对自身不可以
        log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)  # (bs, bs)
        # in case that mask.sum(1) has no zero
        mask_sum = mask.sum(dim=1)
        mask_sum = torch.where(mask_sum == 0, torch.ones_like(mask_sum), mask_sum)
        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(dim=1) / mask_sum.detach()
        loss = - mean_log_prob_pos.mean()
        return loss


7、参考资料



ICML 2020: 从Alignment 和 Uniformity的角度理解对比表征学习


https://blog.csdn.net/c2a2o2/article/details/117898108


相关文章
|
11月前
|
机器学习/深度学习 数据采集 自然语言处理
【Deep Learning A情感文本分类实战】2023 Pytorch+Bert、Roberta+TextCNN、BiLstm、Lstm等实现IMDB情感文本分类完整项目(项目已开源)
亮点:代码开源+结构清晰+准确率高+保姆级解析 🍊本项目使用Pytorch框架,使用上游语言模型+下游网络模型的结构实现IMDB情感分析 🍊语言模型可选择Bert、Roberta 🍊神经网络模型可选择BiLstm、LSTM、TextCNN、Rnn、Gru、Fnn共6种 🍊语言模型和网络模型扩展性较好,方便读者自己对模型进行修改
408 0
|
10月前
|
数据采集
基于Bert文本分类进行行业识别
基于Bert文本分类进行行业识别
162 0
|
12月前
|
机器学习/深度学习 缓存 人工智能
深度学习进阶篇-预训练模型[3]:XLNet、BERT、GPT,ELMO的区别优缺点,模型框架、一些Trick、Transformer Encoder等原理详解
深度学习进阶篇-预训练模型[3]:XLNet、BERT、GPT,ELMO的区别优缺点,模型框架、一些Trick、Transformer Encoder等原理详解
深度学习进阶篇-预训练模型[3]:XLNet、BERT、GPT,ELMO的区别优缺点,模型框架、一些Trick、Transformer Encoder等原理详解
|
12月前
|
机器学习/深度学习 XML 人工智能
ELMo、GPT、BERT、X-Transformer…你都掌握了吗?一文总结文本分类必备经典模型(五)
ELMo、GPT、BERT、X-Transformer…你都掌握了吗?一文总结文本分类必备经典模型
362 0
|
12月前
|
机器学习/深度学习 自然语言处理 数据可视化
ELMo、GPT、BERT、X-Transformer…你都掌握了吗?一文总结文本分类必备经典模型(四)
ELMo、GPT、BERT、X-Transformer…你都掌握了吗?一文总结文本分类必备经典模型
239 0
|
12月前
|
机器学习/深度学习 自然语言处理 算法
ELMo、GPT、BERT、X-Transformer…你都掌握了吗?一文总结文本分类必备经典模型(三)
ELMo、GPT、BERT、X-Transformer…你都掌握了吗?一文总结文本分类必备经典模型(三)
234 0
|
机器学习/深度学习 数据可视化 PyTorch
【BERT-多标签文本分类实战】之七——训练-评估-测试与运行主程序
【BERT-多标签文本分类实战】之七——训练-评估-测试与运行主程序
385 0
|
机器学习/深度学习 存储
【BERT-多标签文本分类实战】之六——数据加载与模型代码
【BERT-多标签文本分类实战】之六——数据加载与模型代码
282 0
【BERT-多标签文本分类实战】之六——数据加载与模型代码
|
自然语言处理 PyTorch TensorFlow
【BERT-多标签文本分类实战】之五——BERT模型库的挑选与Transformers
【BERT-多标签文本分类实战】之五——BERT模型库的挑选与Transformers
728 0
【BERT-多标签文本分类实战】之五——BERT模型库的挑选与Transformers
|
存储 数据采集 自然语言处理
【BERT-多标签文本分类实战】之四——数据集预处理
【BERT-多标签文本分类实战】之四——数据集预处理
620 1
【BERT-多标签文本分类实战】之四——数据集预处理

热门文章

最新文章