阅读摘要
如上图,思路很朴实无华。
普通MLM任务使用的损失函数是CrossEntropyLoss
,它适用于单标签,代码如下:
masked_lm_loss = None if labels is not None: loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
文章这里使用的是BCEWithLogitsLoss
,它适用于多标签分类。即:把[MASK]
位置预测到的词表的值进行sigmoid
,取指定阈值以上的标签,然后算损失。
个人觉得这样不可取,效果也不会好。