U.S. Patent Phrase-to-Phrase Matching Competition with PaddleNLP


I. U.S. Patent Phrase-to-Phrase Matching Competition with PaddleNLP


1. Project Overview


Competition page: www.kaggle.com/competition…

The U.S. Patent and Trademark Office (USPTO) offers one of the world's largest repositories of scientific, technical, and commercial information through its open data portal. Patents are a form of intellectual property granted in exchange for the public disclosure of new and useful inventions. Because patents undergo an intensive vetting process before they are granted, and because U.S. innovation spans more than two centuries and 11 million patents, the U.S. patent archive is a rare combination of data volume, quality, and diversity.

"The USPTO serves an American innovation machine that never stops, granting patents, registering trademarks, and promoting intellectual property around the globe. The USPTO has shared more than 200 years of human ingenuity with the world, from the lightbulb to the quantum computer. Combined with the creativity of the data science community, USPTO datasets carry boundless potential to strengthen AI and ML models that will benefit the progress of science and society as a whole."
— Jamie Holcombe, Chief Information Officer, USPTO

In this competition, you will train your model on a novel semantic similarity dataset to extract relevant information by matching key phrases in patent documents. Determining the semantic similarity between phrases is critical during patent search and examination, because it helps decide whether an invention has been described before. For example, if one invention claims a "television set" and a prior publication describes a "TV set", a model would ideally recognize that these are the same and assist a patent attorney or examiner in retrieving the relevant documents. This goes beyond paraphrase identification: if one invention claims a "strong material" and another uses "steel", that may also be a match. What counts as a "strong material" varies by domain (it may be steel in one domain and ripstop fabric in another, but you would not want a parachute made of steel). The Cooperative Patent Classification is included as the technical-domain context, an additional feature to help disambiguate such cases.


Can you build a model that matches phrases and extracts contextual information, helping the patent community connect the dots between millions of patent documents?


2. Dataset


The dataset is in English. In this project the task is treated as 5-way classification of how semantically similar two retrieval phrases are.


In this dataset, you are presented with pairs of phrases (an anchor and a target phrase) and asked to rate how similar they are on a scale from 0 (not at all similar) to 1 (identical in meaning). This challenge differs from a standard semantic-similarity task in that similarity has been scored within the context of a patent, specifically its CPC classification (version 2021.05), which indicates the subject the patent relates to. For example, while the phrases "bird" and "Cape Cod" may have low semantic similarity in ordinary language, their meanings are much closer when considered in the context of "house".


Scores fall in the range 0-1 in increments of 0.25, with the following meanings:


  • 1.0 - Very close match. This is typically an exact match except for possible differences in conjugation, number (e.g. singular vs. plural), and the addition or removal of stopwords (e.g. "the", "and", "or").
  • 0.75 - Close synonym, e.g. "mobile phone" vs. "cellphone". This also includes abbreviations, e.g. "TCP" -> "transmission control protocol".
  • 0.5 - Synonyms that do not have the same meaning (same function, same properties). This includes broad-narrow (hyponym) and narrow-broad (hypernym) matches.
  • 0.25 - Somewhat related, e.g. the two phrases are in the same high-level domain but are not synonyms. This also includes antonyms.
  • 0.0 - Unrelated.


Files


  • train.csv - the training set, containing the phrases, the contexts, and their similarity scores
  • test.csv - the test set, identical in structure to the training set but without the score
  • sample_submission.csv - a sample submission file in the correct format



Columns in train.csv:

  • id - a unique identifier for a pair of phrases
  • anchor - the first phrase
  • target - the second phrase
  • context - the CPC classification (version 2021.05), indicating the subject within which the similarity is to be scored
  • score - the similarity, sourced from a combination of one or more manual expert ratings


II. Background


1. Text Semantic Matching


Text semantic matching is a fundamental problem in natural language processing, and many NLP tasks can be cast as text-matching tasks. For example, information retrieval can be reduced to matching a query against documents, question answering to matching a question against candidate answers, and dialogue systems to matching an utterance against a reply. Semantic matching is widely used in search optimization, recommender systems, fast retrieval and ranking, and intelligent customer service. Improving the accuracy of text matching is an important challenge in NLP.


  • Information retrieval: in many retrieval applications, other texts similar to a given text need to be retrieved; this usage scenario is extremely common.
  • News recommendation: based on the headline a user has just read, automatically retrieve similar news and recommend it, increasing user stickiness and improving the product experience.
  • Intelligent customer service: when a user submits a question, automatically retrieve similar questions and their answers, saving the cost of human agents and improving efficiency.


Let's look at a simple example and decide which candidate sentence is semantically closest to the original:

Original sentence: "车头如何放置车牌" (how to mount the license plate at the front of a car)


  • Candidate 1: "前牌照怎么装" (how to install the front plate)
  • Candidate 2: "如何办理北京车牌" (how to apply for a Beijing license plate)
  • Candidate 3: "后牌照怎么装" (how to install the rear plate)


(1) Candidate 1 vs. the original: although the sentence pattern and word order differ considerably, they express almost the same meaning.

(2) Candidate 2 vs. the original: although they share co-occurring words such as "如何" (how) and "车牌" (license plate), the meanings are completely different.

(3) Candidate 3 vs. the original: both discuss how to mount a license plate, only one is about the front plate and the other about the rear plate, so there is some semantic relatedness between them.

So in terms of semantic relatedness, candidate 1 ranks above candidate 3, and candidate 3 above candidate 2; this is what semantic matching is about.
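To make this concrete, here is a rough, illustrative sketch (not part of the original project) that scores the three candidates against the original sentence by the cosine similarity of ERNIE pooled outputs. The model name "ernie-1.0-base-zh" and the use of the pooled [CLS] vector are assumptions chosen only for demonstration; an off-the-shelf encoder will not necessarily reproduce the ranking above without fine-tuning.

# Illustrative sketch: rank candidates by cosine similarity of ERNIE pooled outputs.
# "ernie-1.0-base-zh" and the pooled [CLS] vector are demonstration-only choices.
import paddle
import paddle.nn.functional as F
import paddlenlp as ppnlp

encoder = ppnlp.transformers.ErnieModel.from_pretrained("ernie-1.0-base-zh")
zh_tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained("ernie-1.0-base-zh")
encoder.eval()

@paddle.no_grad()
def sentence_vector(text):
    # Encode one sentence and return its pooled representation, shape [1, hidden].
    encoded = zh_tokenizer(text)
    input_ids = paddle.to_tensor([encoded["input_ids"]])
    token_type_ids = paddle.to_tensor([encoded["token_type_ids"]])
    _, pooled = encoder(input_ids, token_type_ids=token_type_ids)
    return pooled

query_vec = sentence_vector("车头如何放置车牌")
for candidate in ["前牌照怎么装", "如何办理北京车牌", "后牌照怎么装"]:
    sim = F.cosine_similarity(query_vec, sentence_vector(candidate))
    print(candidate, float(sim))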


2. Short-Text Semantic Matching Network (SimNet)


SimilarityNet (SimNet) is a framework for computing the similarity of short texts: given two texts supplied by the user, it produces a similarity score. Its core network structures include BOW, CNN, RNN, and MMDNN, and it provides a training and prediction framework for semantic similarity computation. It is applicable to scenarios such as information retrieval, news recommendation, and intelligent customer service, helping enterprises solve semantic-matching problems.

As shown in the figure, the SimNet model consists of an input layer, a representation layer, and a matching layer.

[Figure] SimilarityNet framework


3. SimilarityNet Model Architecture


The model architecture is shown in the figure below. query and title are the two texts to be matched after the dataset has been preprocessed; they are tokenized, encoded into ids, and passed through SimilarityNet to produce the output. The training loss is cross-entropy. A minimal illustrative sketch of such a matching network follows the figure.

[Figure] SimilarityNet model architecture
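The official SimNet implementation is not reproduced in this post. Purely as an illustration of the three-layer idea (input layer, representation layer, matching layer), a minimal BOW-style matching network in Paddle might look like the sketch below; all names, sizes, and the toy usage are made up for demonstration.

# Minimal BOW-style matching sketch (illustrative, not the official SimNet code):
# embed query and title, mean-pool each (representation layer), concatenate the
# two vectors (matching layer), and classify with cross-entropy.
import paddle
import paddle.nn as nn

class BowMatching(nn.Layer):
    def __init__(self, vocab_size, emb_dim=128, hidden_size=128, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.fc = nn.Linear(emb_dim * 2, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def encode(self, ids):
        # Bag-of-words representation: mean of the token embeddings.
        return self.embedding(ids).mean(axis=1)

    def forward(self, query_ids, title_ids):
        matched = paddle.concat([self.encode(query_ids), self.encode(title_ids)], axis=-1)
        return self.classifier(paddle.tanh(self.fc(matched)))

# Toy usage with random ids: a batch of 2 pairs, sequence length 5.
toy_model = BowMatching(vocab_size=1000)
toy_logits = toy_model(paddle.randint(0, 1000, shape=[2, 5]),
                       paddle.randint(0, 1000, shape=[2, 5]))
toy_loss = nn.CrossEntropyLoss()(toy_logits, paddle.to_tensor([0, 1]))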


III. Data Overview


1. Unzip the Data


!unzip -oaq /home/aistudio/data/data143274/us-patent-phrase-to-phrase-matching.zip


2. Inspect the Data


Preview the data and count the number of rows.

  • 36474 train.csv
  • 37 test.csv
  • 37 sample_submission.csv

There are essentially five columns: the id, the two phrases being compared, the context in which they are used, and the score.

!head train.csv
id,anchor,target,context,score
37d61fd2272659b1,abatement,abatement of pollution,A47,0.5
7b9652b17b68b7a4,abatement,act of abating,A47,0.75
36d72442aefd8232,abatement,active catalyst,A47,0.25
5296b0c19e1ce60e,abatement,eliminating process,A47,0.5
54c1e3b9184cb5b6,abatement,forest region,A47,0
067203128142739c,abatement,greenhouse gases,A47,0.25
061d17f04be2d1cf,abatement,increased rate,A47,0.25
e1f44e48399a2027,abatement,measurement level,A47,0.25
0a425937a3e86d10,abatement,minimising sounds,A47,0.5
!wc -l train.csv
36474 train.csv
!head test.csv
id,anchor,target,context
4112d61851461f60,opc drum,inorganic photoconductor drum,G02
09e418c93a776564,adjust gas flow,altering gas flow,F23
36baf228038e314b,lower trunnion,lower locating,B60
1f37ead645e7f0c8,cap component,upper portion,D06
71a5b6ad068d531f,neural stimulation,artificial neural network,H04
474c874d0c07bd21,dry corn,dry corn starch,C12
442c114ed5c4e3c9,tunneling capacitor,capacitor housing,G11
b8ae62ea5e1d8bdb,angular contact bearing,contact therapy radiation,B23
faaddaf8fcba8a3f,produce liquid hydrocarbons,produce a treated stream,C10
!wc -l test.csv
37 test.csv
!head sample_submission.csv
id,score
4112d61851461f60,0
09e418c93a776564,0
36baf228038e314b,0
1f37ead645e7f0c8,0
71a5b6ad068d531f,0
474c874d0c07bd21,0
442c114ed5c4e3c9,0
b8ae62ea5e1d8bdb,0
faaddaf8fcba8a3f,0
!wc -l sample_submission.csv
37 sample_submission.csv


IV. Environment Setup


This project uses PaddleNLP for fast processing. PaddleNLP needs to be upgraded (to version 2.3.3 here), after which the notebook must be restarted.


1. Upgrade PaddleNLP


  • Upgrade pip
  • Upgrade paddlenlp
!python -m pip install  --upgrade pip --user>log.log
# The PaddleNLP version preinstalled on AI Studio is too old, so upgrade it first
!pip install paddlenlp --upgrade --user>log.log


2. Restart and Import the Required Libraries


Once paddlenlp has been upgraded, restart the notebook so the new version is loaded.

# Import paddle packages
import paddle
import paddle.nn as nn
# Import PaddleNLP packages
import paddlenlp as ppnlp
from paddlenlp.data import JiebaTokenizer, Pad, Stack, Tuple, Vocab
# from utils import convert_example
from paddlenlp.datasets import MapDataset
from paddle.dataset.common import md5file
from paddlenlp.datasets import DatasetBuilder
from paddlenlp.datasets import load_dataset
from paddle.io import Dataset, Subset
import numpy as np
import paddle.nn.functional as F
import time
import os
print("Paddle version used in this project: " + paddle.__version__)
print("PaddleNLP version used in this project: " + ppnlp.__version__)
Paddle version used in this project: 2.3.0
PaddleNLP version used in this project: 2.3.3


V. Load the Pretrained Model


1. Summary of ERNIE Models


The table below summarizes the pretrained weights of the ERNIE models currently supported by PaddleNLP. For details of each model, see the corresponding link.

Pretrained Weight | Language | Details of the model
ernie-1.0-base-zh | Chinese | 12-layer, 768-hidden, 12-heads, 108M parameters. Trained on Chinese text.
ernie-tiny | Chinese | 3-layer, 1024-hidden, 16-heads, _M parameters. Trained on Chinese text.
ernie-2.0-base-en | English | 12-layer, 768-hidden, 12-heads, 103M parameters. Trained on lower-cased English text.
ernie-2.0-base-en-finetuned-squad | English | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on finetuned squad text.
ernie-2.0-large-en | English | 24-layer, 1024-hidden, 16-heads, 336M parameters. Trained on lower-cased English text.
zhui/ernie-1.0-cluecorpussmall | Chinese | Please refer to: zhui/ernie-1.0-cluecorpussmall


2. Load the Model


# If this cell fails, restart the project and run it again
MODEL_NAME = "ernie-2.0-base-en"
ernie_model = ppnlp.transformers.ErnieModel.from_pretrained(MODEL_NAME)
model = ppnlp.transformers.ErnieForSequenceClassification.from_pretrained(MODEL_NAME, num_classes=5)
[2022-06-12 15:23:06,757] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-2.0-base-en/ernie_v2_eng_base.pdparams
W0612 15:23:06.761130 13195 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0612 15:23:06.765818 13195 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.
[2022-06-12 15:23:12,415] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-2.0-base-en/ernie_v2_eng_base.pdparams
# Define the tokenizer matching the ERNIE model and check how it behaves
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained(MODEL_NAME)
[2022-06-12 15:23:13,762] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-2.0-base-en/vocab.txt
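As a quick illustrative check (not in the original notebook), the tokenizer can be applied to one anchor/target pair from the training data to see the ids and tokens it produces:

# Illustrative check: encode an anchor/target pair and inspect the ids and tokens.
sample = tokenizer(text="abatement", text_pair="abatement of pollution", max_seq_len=32)
print(sample["input_ids"])
print(sample["token_type_ids"])
print(tokenizer.convert_ids_to_tokens(sample["input_ids"]))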


VI. Load the Dataset


1. View the Data


import pandas as pd
train_data = "./train.csv"
train_data = pd.read_csv(train_data,  sep=',')
test_data = "./test.csv"
test_data = pd.read_csv(test_data,  sep=',')
train_data.head(5)

id anchor target context score
0 37d61fd2272659b1 abatement abatement of pollution A47 0.50
1 7b9652b17b68b7a4 abatement act of abating A47 0.75
2 36d72442aefd8232 abatement active catalyst A47 0.25
3 5296b0c19e1ce60e abatement eliminating process A47 0.50
4 54c1e3b9184cb5b6 abatement forest region A47 0.00
print(max(train_data['anchor'].str.len()))
print(max(train_data['target'].str.len()))
38
98
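These are the character lengths of the longest anchor and target. Since max_seq_length will be set to 256 during preprocessing, an optional sanity check (illustrative only, not in the original notebook) can confirm that the tokenized pair lengths stay well below that limit:

# Optional sanity check: length in tokens of the longest tokenized anchor/target pair.
pair_lengths = [
    len(tokenizer(text=a, text_pair=t)["input_ids"])
    for a, t in zip(train_data["anchor"], train_data["target"])
]
print(max(pair_lengths))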


2. Score Distribution


%matplotlib inline
train_data.hist(column="score",        # column to plot
              figsize=(16,16),         # figure size
              color="red")
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7fc3ee0d8710>]],
      dtype=object)

[Figure] Histogram of the score distribution


3. Convert Scores to Labels


def score2label(x):
    if x==0.0:
        y=0
    elif x==0.25:
        y=1
    elif x==0.5:
        y=2
    elif x==0.75:
        y=3
    elif x==1.0:
        y=4   
    return y         
train_data['label']=train_data.score.map(score2label)
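Because the scores are exact multiples of 0.25, the same mapping can also be written in one line (an equivalent alternative; the rest of the notebook uses the function above):

# Equivalent one-liner: 0.0 / 0.25 / 0.5 / 0.75 / 1.0 -> 0 / 1 / 2 / 3 / 4.
train_data['label'] = (train_data['score'] * 4).round().astype('int64')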


4. Custom read Functions


def read_train(pd_data):   
    for index, item in pd_data.iterrows():       
        yield {"query": item['anchor'], "title": item['target'], "label": item['label']}       
def read_test(pd_data):   
    for index, item in pd_data.iterrows():
        yield {"query": item['anchor'], "title": item['target'], "label": 0}                   


5. Load and Split the Dataset


# keyword arguments (here pd_data) are forwarded to the custom read function
train_ds = load_dataset(read_train, pd_data=train_data,lazy=False)
for i in range(5):
    print(train_ds[i])
{'query': 'abatement', 'title': 'abatement of pollution', 'label': 2}
{'query': 'abatement', 'title': 'act of abating', 'label': 3}
{'query': 'abatement', 'title': 'active catalyst', 'label': 1}
{'query': 'abatement', 'title': 'eliminating process', 'label': 2}
{'query': 'abatement', 'title': 'forest region', 'label': 0}
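Despite the section title, the notebook keeps all 36,473 examples for training and later evaluates on the same training dataloader. If a held-out dev set is wanted, one simple option (a sketch; the 90/10 split and random_state are arbitrary choices) is to split the DataFrame before building the datasets:

# Optional sketch: hold out 10% of the DataFrame as a dev set before load_dataset.
dev_pd = train_data.sample(frac=0.1, random_state=42)
train_pd = train_data.drop(dev_pd.index)
train_split_ds = load_dataset(read_train, pd_data=train_pd, lazy=False)
dev_split_ds = load_dataset(read_train, pd_data=dev_pd, lazy=False)
print(len(train_split_ds), len(dev_split_ds))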


6. Data Preprocessing


The dataset loaded through paddlenlp is still raw text. In this part we implement the preprocessing logic (batching, tokenization, and so on) that turns the raw text into the inputs the network is trained on. Note that convert_example and create_dataloader used below are imported from a local utils module that is not reproduced in this post; a sketch of typical implementations is given after the code.


6.1 Define the Example Conversion Function


from functools import partial
from paddlenlp.data import Stack, Tuple, Pad
from utils import  convert_example, create_dataloader
batch_size = 64
max_seq_length = 256
trans_func = partial(
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length)
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),       # input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
    Stack(dtype="int64")                               # label
): [data for data in fn(samples)]
train_data_loader = create_dataloader(
    train_ds,
    mode='train',
    batch_size=batch_size,
    batchify_fn=batchify_fn,
    trans_fn=trans_func)
from paddlenlp.transformers import LinearDecayWithWarmup
# peak learning rate during training
learning_rate = 5e-5
# number of training epochs
epochs = 3
# proportion of steps used for learning-rate warmup
warmup_proportion = 0.1
# weight decay coefficient, a regularization strategy to reduce overfitting
weight_decay = 0.01
num_training_steps = len(train_data_loader) * epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ])
criterion = paddle.nn.loss.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
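convert_example and create_dataloader are imported from a local utils module that is not included in this post. Modeled on the standard PaddleNLP sequence-classification examples, typical implementations look roughly like the following sketch; the actual utils.py used in this project may differ in details.

# Sketch of the helpers imported from utils (assumed, following the standard
# PaddleNLP examples); the project's actual utils.py may differ.
import numpy as np
import paddle

def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
    # Join query and title into one sequence: [CLS] query [SEP] title [SEP].
    encoded = tokenizer(text=example["query"], text_pair=example["title"],
                        max_seq_len=max_seq_length)
    input_ids = encoded["input_ids"]
    token_type_ids = encoded["token_type_ids"]
    if is_test:
        return input_ids, token_type_ids
    label = np.array([example["label"]], dtype="int64")
    return input_ids, token_type_ids, label

def create_dataloader(dataset, mode='train', batch_size=1,
                      batchify_fn=None, trans_fn=None):
    # Apply the per-example transform, then batch (with shuffling for training).
    if trans_fn:
        dataset = dataset.map(trans_fn)
    shuffle = (mode == 'train')
    if mode == 'train':
        sampler = paddle.io.DistributedBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        sampler = paddle.io.BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)
    return paddle.io.DataLoader(dataset=dataset, batch_sampler=sampler,
                                collate_fn=batchify_fn, return_list=True)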


VII. Model Training


import paddle.nn.functional as F
from utils import evaluate
global_step = 0
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        input_ids, segment_ids, labels = batch
        logits = model(input_ids, segment_ids)
        loss = criterion(logits, labels)
        probs = F.softmax(logits, axis=1)
        correct = metric.compute(probs, labels)
        metric.update(correct)
        acc = metric.accumulate()
        global_step += 1
        if global_step % 10 == 0 :
            print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (global_step, epoch, step, loss, acc))
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()
    evaluate(model, criterion, metric, train_data_loader)
global step 10, epoch: 1, batch: 10, loss: 1.56753, acc: 0.19219
global step 20, epoch: 1, batch: 20, loss: 1.46965, acc: 0.25234
global step 30, epoch: 1, batch: 30, loss: 1.36288, acc: 0.30000
global step 40, epoch: 1, batch: 40, loss: 1.51080, acc: 0.31602
global step 50, epoch: 1, batch: 50, loss: 1.35407, acc: 0.32500
global step 60, epoch: 1, batch: 60, loss: 1.29739, acc: 0.34036
global step 70, epoch: 1, batch: 70, loss: 1.46518, acc: 0.34375
global step 80, epoch: 1, batch: 80, loss: 1.38125, acc: 0.34590
global step 90, epoch: 1, batch: 90, loss: 1.19846, acc: 0.35365
global step 100, epoch: 1, batch: 100, loss: 1.27309, acc: 0.36250
global step 110, epoch: 1, batch: 110, loss: 1.25976, acc: 0.37173
global step 120, epoch: 1, batch: 120, loss: 1.31990, acc: 0.37799
global step 130, epoch: 1, batch: 130, loss: 1.28407, acc: 0.38762
global step 140, epoch: 1, batch: 140, loss: 1.33328, acc: 0.39308
global step 150, epoch: 1, batch: 150, loss: 1.16941, acc: 0.39427
global step 160, epoch: 1, batch: 160, loss: 1.16217, acc: 0.39951
global step 170, epoch: 1, batch: 170, loss: 1.30883, acc: 0.40450
global step 180, epoch: 1, batch: 180, loss: 1.40474, acc: 0.40608
global step 190, epoch: 1, batch: 190, loss: 1.23098, acc: 0.41020
global step 200, epoch: 1, batch: 200, loss: 1.09877, acc: 0.41156
global step 210, epoch: 1, batch: 210, loss: 1.27568, acc: 0.41235
global step 220, epoch: 1, batch: 220, loss: 1.25363, acc: 0.41506
global step 230, epoch: 1, batch: 230, loss: 1.35899, acc: 0.41692
global step 240, epoch: 1, batch: 240, loss: 1.32752, acc: 0.41823
global step 250, epoch: 1, batch: 250, loss: 1.10830, acc: 0.42094
global step 260, epoch: 1, batch: 260, loss: 1.19847, acc: 0.42284
global step 270, epoch: 1, batch: 270, loss: 1.31777, acc: 0.42541
global step 280, epoch: 1, batch: 280, loss: 1.13283, acc: 0.42840
global step 290, epoch: 1, batch: 290, loss: 1.03019, acc: 0.43141
global step 300, epoch: 1, batch: 300, loss: 1.28243, acc: 0.43297
global step 310, epoch: 1, batch: 310, loss: 1.02955, acc: 0.43533
global step 320, epoch: 1, batch: 320, loss: 1.38921, acc: 0.43706
global step 330, epoch: 1, batch: 330, loss: 1.21835, acc: 0.43954
global step 340, epoch: 1, batch: 340, loss: 1.12567, acc: 0.44040
global step 350, epoch: 1, batch: 350, loss: 1.28655, acc: 0.44214
global step 360, epoch: 1, batch: 360, loss: 1.21317, acc: 0.44462
global step 370, epoch: 1, batch: 370, loss: 1.15794, acc: 0.44561
global step 380, epoch: 1, batch: 380, loss: 1.05953, acc: 0.44704
global step 390, epoch: 1, batch: 390, loss: 1.02538, acc: 0.44804
global step 400, epoch: 1, batch: 400, loss: 1.23620, acc: 0.44949
global step 410, epoch: 1, batch: 410, loss: 1.03123, acc: 0.45107
global step 420, epoch: 1, batch: 420, loss: 1.06850, acc: 0.45279
global step 430, epoch: 1, batch: 430, loss: 1.03816, acc: 0.45396
global step 440, epoch: 1, batch: 440, loss: 1.17245, acc: 0.45568
global step 450, epoch: 1, batch: 450, loss: 1.09026, acc: 0.45684
global step 460, epoch: 1, batch: 460, loss: 1.12116, acc: 0.45822
global step 470, epoch: 1, batch: 470, loss: 1.12357, acc: 0.46051
global step 480, epoch: 1, batch: 480, loss: 1.14521, acc: 0.46260
global step 490, epoch: 1, batch: 490, loss: 1.15795, acc: 0.46403
global step 500, epoch: 1, batch: 500, loss: 1.05499, acc: 0.46506
global step 510, epoch: 1, batch: 510, loss: 1.17101, acc: 0.46572
global step 520, epoch: 1, batch: 520, loss: 1.26547, acc: 0.46596
global step 530, epoch: 1, batch: 530, loss: 1.14175, acc: 0.46713
global step 540, epoch: 1, batch: 540, loss: 1.13119, acc: 0.46881
global step 550, epoch: 1, batch: 550, loss: 1.21626, acc: 0.47000
global step 560, epoch: 1, batch: 560, loss: 1.24214, acc: 0.47146
global step 570, epoch: 1, batch: 570, loss: 1.04112, acc: 0.47265
eval loss: 0.97536, accu: 0.60310
global step 580, epoch: 2, batch: 10, loss: 0.93487, acc: 0.60781
global step 590, epoch: 2, batch: 20, loss: 0.90623, acc: 0.60625
global step 600, epoch: 2, batch: 30, loss: 0.97992, acc: 0.59271
global step 610, epoch: 2, batch: 40, loss: 1.11511, acc: 0.58555
global step 620, epoch: 2, batch: 50, loss: 1.02230, acc: 0.58969
global step 630, epoch: 2, batch: 60, loss: 1.06591, acc: 0.58698
global step 640, epoch: 2, batch: 70, loss: 0.76665, acc: 0.58527
global step 650, epoch: 2, batch: 80, loss: 0.96098, acc: 0.58926
global step 660, epoch: 2, batch: 90, loss: 0.93122, acc: 0.58490
global step 670, epoch: 2, batch: 100, loss: 0.91359, acc: 0.58313
global step 680, epoch: 2, batch: 110, loss: 0.94778, acc: 0.58466
global step 690, epoch: 2, batch: 120, loss: 1.03503, acc: 0.58672
global step 700, epoch: 2, batch: 130, loss: 1.22159, acc: 0.58437
global step 710, epoch: 2, batch: 140, loss: 1.15416, acc: 0.58460
global step 720, epoch: 2, batch: 150, loss: 1.01074, acc: 0.58365
global step 730, epoch: 2, batch: 160, loss: 0.92790, acc: 0.58389
global step 740, epoch: 2, batch: 170, loss: 0.80486, acc: 0.58529
global step 750, epoch: 2, batch: 180, loss: 0.82019, acc: 0.58672
global step 760, epoch: 2, batch: 190, loss: 0.95490, acc: 0.58824
global step 770, epoch: 2, batch: 200, loss: 0.91211, acc: 0.58820
global step 780, epoch: 2, batch: 210, loss: 0.91922, acc: 0.58824
global step 790, epoch: 2, batch: 220, loss: 0.93980, acc: 0.58899
global step 800, epoch: 2, batch: 230, loss: 1.18243, acc: 0.58927
global step 810, epoch: 2, batch: 240, loss: 1.00865, acc: 0.58939
global step 820, epoch: 2, batch: 250, loss: 0.91286, acc: 0.59031
global step 830, epoch: 2, batch: 260, loss: 0.94963, acc: 0.59069
global step 840, epoch: 2, batch: 270, loss: 1.07120, acc: 0.59080
global step 850, epoch: 2, batch: 280, loss: 0.94161, acc: 0.59096
global step 860, epoch: 2, batch: 290, loss: 1.03282, acc: 0.59154
global step 870, epoch: 2, batch: 300, loss: 1.03144, acc: 0.59260
global step 880, epoch: 2, batch: 310, loss: 0.92614, acc: 0.59350
global step 890, epoch: 2, batch: 320, loss: 0.97692, acc: 0.59390
global step 900, epoch: 2, batch: 330, loss: 0.99512, acc: 0.59332
global step 910, epoch: 2, batch: 340, loss: 1.02203, acc: 0.59283
global step 920, epoch: 2, batch: 350, loss: 0.98110, acc: 0.59308
global step 930, epoch: 2, batch: 360, loss: 1.11808, acc: 0.59388
global step 940, epoch: 2, batch: 370, loss: 0.96240, acc: 0.59468
global step 950, epoch: 2, batch: 380, loss: 0.78424, acc: 0.59552
global step 960, epoch: 2, batch: 390, loss: 0.96898, acc: 0.59583
global step 970, epoch: 2, batch: 400, loss: 1.03831, acc: 0.59574
global step 980, epoch: 2, batch: 410, loss: 0.92671, acc: 0.59661
global step 990, epoch: 2, batch: 420, loss: 0.91024, acc: 0.59702
global step 1000, epoch: 2, batch: 430, loss: 1.06812, acc: 0.59695
global step 1010, epoch: 2, batch: 440, loss: 0.93563, acc: 0.59716
global step 1020, epoch: 2, batch: 450, loss: 0.92760, acc: 0.59708
global step 1030, epoch: 2, batch: 460, loss: 0.79893, acc: 0.59789
global step 1040, epoch: 2, batch: 470, loss: 0.89773, acc: 0.59850
global step 1050, epoch: 2, batch: 480, loss: 0.82244, acc: 0.59847
global step 1060, epoch: 2, batch: 490, loss: 0.93188, acc: 0.59892
global step 1070, epoch: 2, batch: 500, loss: 0.88342, acc: 0.59953
global step 1080, epoch: 2, batch: 510, loss: 0.73247, acc: 0.60034
global step 1090, epoch: 2, batch: 520, loss: 0.83115, acc: 0.60072
global step 1100, epoch: 2, batch: 530, loss: 0.89471, acc: 0.60085
global step 1110, epoch: 2, batch: 540, loss: 0.83178, acc: 0.60084
global step 1120, epoch: 2, batch: 550, loss: 0.74519, acc: 0.60176
global step 1130, epoch: 2, batch: 560, loss: 0.88650, acc: 0.60240
global step 1140, epoch: 2, batch: 570, loss: 0.83957, acc: 0.60258
eval loss: 0.69819, accu: 0.72596
global step 1150, epoch: 3, batch: 10, loss: 0.88984, acc: 0.67500
global step 1160, epoch: 3, batch: 20, loss: 0.88784, acc: 0.68672
global step 1170, epoch: 3, batch: 30, loss: 0.59687, acc: 0.69375
global step 1180, epoch: 3, batch: 40, loss: 0.66985, acc: 0.69453
global step 1190, epoch: 3, batch: 50, loss: 0.83693, acc: 0.68875
global step 1200, epoch: 3, batch: 60, loss: 0.70200, acc: 0.68672
global step 1210, epoch: 3, batch: 70, loss: 0.64738, acc: 0.69018
global step 1220, epoch: 3, batch: 80, loss: 0.60940, acc: 0.69141
global step 1230, epoch: 3, batch: 90, loss: 0.68041, acc: 0.69219
global step 1240, epoch: 3, batch: 100, loss: 0.73690, acc: 0.69250
global step 1250, epoch: 3, batch: 110, loss: 0.69368, acc: 0.69304
global step 1260, epoch: 3, batch: 120, loss: 0.71949, acc: 0.69128
global step 1270, epoch: 3, batch: 130, loss: 0.96168, acc: 0.69291
global step 1280, epoch: 3, batch: 140, loss: 0.69623, acc: 0.69375
global step 1290, epoch: 3, batch: 150, loss: 0.69540, acc: 0.69594
global step 1300, epoch: 3, batch: 160, loss: 0.92395, acc: 0.69502
global step 1310, epoch: 3, batch: 170, loss: 0.67467, acc: 0.69568
global step 1320, epoch: 3, batch: 180, loss: 0.71054, acc: 0.69696
global step 1330, epoch: 3, batch: 190, loss: 0.77184, acc: 0.69646
global step 1340, epoch: 3, batch: 200, loss: 0.72983, acc: 0.69609
global step 1350, epoch: 3, batch: 210, loss: 0.61290, acc: 0.69695
global step 1360, epoch: 3, batch: 220, loss: 0.73045, acc: 0.69730
global step 1370, epoch: 3, batch: 230, loss: 0.68346, acc: 0.69755
global step 1380, epoch: 3, batch: 240, loss: 0.72073, acc: 0.69772
global step 1390, epoch: 3, batch: 250, loss: 0.69043, acc: 0.69769
global step 1400, epoch: 3, batch: 260, loss: 0.76760, acc: 0.69730
global step 1410, epoch: 3, batch: 270, loss: 0.81253, acc: 0.69751
global step 1420, epoch: 3, batch: 280, loss: 0.92509, acc: 0.69665
global step 1430, epoch: 3, batch: 290, loss: 0.61820, acc: 0.69698
global step 1440, epoch: 3, batch: 300, loss: 0.60527, acc: 0.69766
global step 1450, epoch: 3, batch: 310, loss: 0.78595, acc: 0.69783
global step 1460, epoch: 3, batch: 320, loss: 0.52455, acc: 0.69834
global step 1470, epoch: 3, batch: 330, loss: 0.52329, acc: 0.69938
global step 1480, epoch: 3, batch: 340, loss: 0.64003, acc: 0.69995
global step 1490, epoch: 3, batch: 350, loss: 0.84223, acc: 0.69996
global step 1500, epoch: 3, batch: 360, loss: 0.73410, acc: 0.69987
global step 1510, epoch: 3, batch: 370, loss: 0.62209, acc: 0.69996
global step 1520, epoch: 3, batch: 380, loss: 0.77763, acc: 0.69959
global step 1530, epoch: 3, batch: 390, loss: 0.78981, acc: 0.70004
global step 1540, epoch: 3, batch: 400, loss: 0.78278, acc: 0.69969
global step 1550, epoch: 3, batch: 410, loss: 0.47082, acc: 0.69954
global step 1560, epoch: 3, batch: 420, loss: 0.54026, acc: 0.69985
global step 1570, epoch: 3, batch: 430, loss: 0.62268, acc: 0.70087
global step 1580, epoch: 3, batch: 440, loss: 0.70356, acc: 0.70131
global step 1590, epoch: 3, batch: 450, loss: 0.67467, acc: 0.70149
global step 1600, epoch: 3, batch: 460, loss: 0.91220, acc: 0.70139
global step 1610, epoch: 3, batch: 470, loss: 0.72232, acc: 0.70193
global step 1620, epoch: 3, batch: 480, loss: 0.77109, acc: 0.70195
global step 1630, epoch: 3, batch: 490, loss: 0.82069, acc: 0.70185
global step 1640, epoch: 3, batch: 500, loss: 0.78051, acc: 0.70194
global step 1650, epoch: 3, batch: 510, loss: 0.72643, acc: 0.70169
global step 1660, epoch: 3, batch: 520, loss: 0.73463, acc: 0.70189
global step 1670, epoch: 3, batch: 530, loss: 0.66037, acc: 0.70200
global step 1680, epoch: 3, batch: 540, loss: 0.77316, acc: 0.70217
global step 1690, epoch: 3, batch: 550, loss: 0.56472, acc: 0.70205
global step 1700, epoch: 3, batch: 560, loss: 0.60610, acc: 0.70240
global step 1710, epoch: 3, batch: 570, loss: 0.55158, acc: 0.70277
eval loss: 0.55542, accu: 0.78784

The baseline ran for 3 epochs and took about 45 minutes.
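The evaluate helper called at the end of each epoch is also imported from utils and not shown above; note that it is run on train_data_loader itself, since no dev split was created. A typical implementation, again following the standard PaddleNLP examples (an assumption about the file's contents), is roughly:

# Sketch of utils.evaluate (assumed, based on the standard PaddleNLP examples).
import numpy as np
import paddle
import paddle.nn.functional as F

@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader):
    model.eval()
    metric.reset()
    losses = []
    for batch in data_loader:
        input_ids, token_type_ids, labels = batch
        logits = model(input_ids, token_type_ids)
        losses.append(criterion(logits, labels).numpy())
        correct = metric.compute(F.softmax(logits, axis=1), labels)
        metric.update(correct)
    print("eval loss: %.5f, accu: %.5f" % (np.mean(losses), metric.accumulate()))
    model.train()
    metric.reset()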



VIII. Save the Model


model.save_pretrained('model')
tokenizer.save_pretrained('model')
[2022-06-12 15:32:15,519] [    INFO] - tokenizer config file saved in model/tokenizer_config.json
[2022-06-12 15:32:15,521] [    INFO] - Special tokens file saved in model/special_tokens_map.json
('model/tokenizer_config.json',
 'model/special_tokens_map.json',
 'model/added_tokens.json')
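To use the saved checkpoint later (for example in a separate inference session), both the model and the tokenizer can be restored from the directory written above. Loading from a local save_pretrained directory is standard PaddleNLP behaviour; shown here only as a brief illustration:

# Illustrative: reload the fine-tuned model and tokenizer from the 'model' directory.
model = ppnlp.transformers.ErnieForSequenceClassification.from_pretrained('model')
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('model')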


IX. Prediction


from utils import predict
import pandas as pd
label_map = {0:'0', 1:'0.25', 2:'0.5', 3:'0.75', 4:'1'}
def preprocess_prediction_data(pd_data):
    examples = []
    for index, item in pd_data.iterrows():       
        examples.append({"query": item['anchor'], "title": item['target']})   
    return examples
test_file = 'test.csv'
pd_data = pd.read_csv(test_file, sep=',')
examples = preprocess_prediction_data(pd_data)
results = predict(
        model, examples, tokenizer, label_map, batch_size=batch_size)
for idx, text in enumerate(examples):
    print('Data: {} \t Label: {}'.format(text, results[idx]))
Data: {'query': 'opc drum', 'title': 'inorganic photoconductor drum'}    Label: 0.5
Data: {'query': 'adjust gas flow', 'title': 'altering gas flow'}   Label: 0.5
Data: {'query': 'lower trunnion', 'title': 'lower locating'}   Label: 0.5
Data: {'query': 'cap component', 'title': 'upper portion'}   Label: 0.25
Data: {'query': 'neural stimulation', 'title': 'artificial neural network'}    Label: 0
Data: {'query': 'dry corn', 'title': 'dry corn starch'}    Label: 0.5
Data: {'query': 'tunneling capacitor', 'title': 'capacitor housing'}   Label: 0.5
Data: {'query': 'angular contact bearing', 'title': 'contact therapy radiation'}   Label: 0
Data: {'query': 'produce liquid hydrocarbons', 'title': 'produce a treated stream'}    Label: 0.25
Data: {'query': 'diesel fuel tank', 'title': 'diesel fuel tanks'}    Label: 1
Data: {'query': 'chemical activity', 'title': 'dielectric characteristics'}    Label: 0.25
Data: {'query': 'transmit to platform', 'title': 'direct receiving'}   Label: 0.25
Data: {'query': 'oil tankers', 'title': 'oil carriers'}    Label: 0.5
Data: {'query': 'generate in layer', 'title': 'generate by layer'}   Label: 0.75
Data: {'query': 'slip segment', 'title': 'slip portion'}   Label: 0.75
Data: {'query': 'el display', 'title': 'illumination'}   Label: 0.5
Data: {'query': 'overflow device', 'title': 'oil filler'}    Label: 0.25
Data: {'query': 'beam traveling direction', 'title': 'concrete beam'}    Label: 0
Data: {'query': 'el display', 'title': 'electroluminescent'}   Label: 0.5
Data: {'query': 'equipment unit', 'title': 'power detection'}    Label: 0.5
Data: {'query': 'halocarbyl', 'title': 'halogen addition reaction'}    Label: 0.25
Data: {'query': 'perfluoroalkyl group', 'title': 'hydroxy'}    Label: 0.25
Data: {'query': 'speed control means', 'title': 'control loop'}    Label: 0.25
Data: {'query': 'arm design', 'title': 'steel plate'}    Label: 0.25
Data: {'query': 'hybrid bearing', 'title': 'bearing system'}   Label: 0.5
Data: {'query': 'end pins', 'title': 'end days'}   Label: 0
Data: {'query': 'organic starting', 'title': 'organic farming'}    Label: 0
Data: {'query': 'make of slabs', 'title': 'making cake'}   Label: 0
Data: {'query': 'seal teeth', 'title': 'teeth whitening'}    Label: 0
Data: {'query': 'carry by platform', 'title': 'carry on platform'}   Label: 0.75
Data: {'query': 'polls', 'title': 'pooling device'}    Label: 0
Data: {'query': 'upper clamp arm', 'title': 'end visual'}    Label: 0
Data: {'query': 'clocked storage', 'title': 'clocked storage device'}    Label: 0.75
Data: {'query': 'coupling factor', 'title': 'turns impedance'}   Label: 0.5
Data: {'query': 'different conductivity', 'title': 'carrier polarity'}   Label: 0
Data: {'query': 'hybrid bearing', 'title': 'corrosion resistant'}    Label: 0.25
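Like the other helpers, predict comes from the local utils module and is not shown in this post. A typical implementation (an assumption, modeled on the standard PaddleNLP examples) tokenizes and pads the examples, runs the classifier batch by batch, and maps the argmax class ids back through label_map:

# Sketch of utils.predict (assumed, based on the standard PaddleNLP examples).
import paddle
import paddle.nn.functional as F
from paddlenlp.data import Pad, Tuple

def predict(model, data, tokenizer, label_map, batch_size=1, max_seq_length=256):
    encoded_examples = []
    for item in data:
        encoded = tokenizer(text=item["query"], text_pair=item["title"],
                            max_seq_len=max_seq_length)
        encoded_examples.append((encoded["input_ids"], encoded["token_type_ids"]))

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),       # input ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment ids
    ): fn(samples)

    results = []
    model.eval()
    for i in range(0, len(encoded_examples), batch_size):
        input_ids, token_type_ids = batchify_fn(encoded_examples[i:i + batch_size])
        logits = model(paddle.to_tensor(input_ids), paddle.to_tensor(token_type_ids))
        probs = F.softmax(logits, axis=1)
        idx = paddle.argmax(probs, axis=1).numpy().tolist()
        results.extend(label_map[j] for j in idx)
    return results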
# Write the submission csv
sub=pd.read_csv("sample_submission.csv",sep=',')
for idx, text in enumerate(examples):    
    sub.loc[idx, 'score'] = results[idx]
# index=False keeps the file in the required "id,score" format for submission
sub.to_csv('submission.csv', sep=',', index=False)
!head submission.csv
id,score
4112d61851461f60,0.5
09e418c93a776564,0.5
36baf228038e314b,0.5
1f37ead645e7f0c8,0.25
71a5b6ad068d531f,0
474c874d0c07bd21,0.5
442c114ed5c4e3c9,0.5
b8ae62ea5e1d8bdb,0
faaddaf8fcba8a3f,0.25

