I. Matching Similar Phrases in U.S. Patents with PaddleNLP
1. Project Overview
Competition page: www.kaggle.com/competition…
The United States Patent and Trademark Office (USPTO) offers one of the world's largest repositories of scientific, technical, and commercial information through its Open Data Portal. Patents are a form of intellectual property granted in exchange for the public disclosure of new and useful inventions. Because patents undergo an intensive vetting process before being granted, and because the history of U.S. innovation spans more than two centuries and 11 million patents, the U.S. patent archive is a rare combination of data volume, quality, and diversity.
"The USPTO serves an American innovation machine that never sleeps by granting patents, registering trademarks, and promoting intellectual property around the globe. The USPTO shares over 200 years of human ingenuity with the world, from lightbulbs to quantum computers. Combined with the creativity of the data science community, the USPTO datasets carry unbounded potential to empower AI and ML models that will benefit the progress of science and society at large." — Jamie Holcombe, Chief Information Officer, USPTO
In this competition, you will train your model on a novel semantic-similarity dataset to extract relevant information by matching key phrases in patent documents. Determining the semantic similarity between phrases is critically important during patent search and examination to determine whether an invention has been described before. For example, if one invention claims a "television set" and a prior publication describes a "TV set", a model would ideally recognize these are the same and assist a patent attorney or examiner in retrieving relevant documents. This extends beyond paraphrase identification: if one invention claims a "strong material" and another uses "steel", that may also be a match. What counts as a "strong material" varies by technical domain (it may be steel in one domain and ripstop fabric in another, but you would not want your parachute made of steel). The Cooperative Patent Classification code is included as the technical-domain context, an additional feature to help you disambiguate such cases.
Can you build a model to match phrases and extract contextual information, helping the patent community connect the dots between millions of patent documents?
2. Dataset Description
The dataset is in English. The task is to classify pairs of retrieval-style texts into 5 classes according to how semantically similar they are.
In this dataset, you are presented with pairs of phrases (an anchor and a target phrase) and asked to rate how similar they are on a scale from 0 (not at all similar) to 1 (identical in meaning). This challenge differs from a standard semantic-similarity task in that similarity has been scored here within a patent's context, specifically its CPC classification (version 2021.05), which indicates the subject the patent relates to. For example, while the phrases "bird" and "Cape Cod" may have low semantic similarity in everyday language, their meanings are much closer when considered in the context of "house".
Scores fall in the range 0-1, in increments of 0.25, with the following meanings:
- 1.0 - Very close match. This is typically an exact match except for possible differences in conjugation, quantity (e.g. singular vs. plural), and the addition or removal of stopwords (e.g. "the", "and", "or").
- 0.75 - Close synonym, e.g. "mobile phone" vs. "cellphone". This also includes abbreviations, e.g. "TCP" -> "transmission control protocol".
- 0.5 - Synonyms that do not have the same meaning (same function, same properties). This includes broad-narrow (hyponym) and narrow-broad (hypernym) matches.
- 0.25 - Somewhat related, e.g. the two phrases are in the same high-level domain but are not synonyms. This also includes antonyms.
- 0.0 - Unrelated.
Files
- train.csv - the training set, containing phrases, contexts, and their similarity scores
- test.csv - the test set, identical in structure to the training set but without the score
- sample_submission.csv - a sample submission file in the correct format
Columns
- id - a unique identifier for a pair of phrases
- anchor - the first phrase
- target - the second phrase
- context - the CPC classification (version 2021.05), which indicates the subject within which the similarity is to be scored
- score - the similarity. This is sourced from a combination of one or more manual expert ratings.
II. Prerequisite Background
1. Text Semantic Matching
Text semantic matching is a fundamental problem in natural language processing; many NLP tasks can be cast as text-matching tasks. For example, information retrieval reduces to matching a query against documents, question answering to matching a question against candidate answers, and dialogue systems to matching an utterance against replies. Semantic matching is widely used in search optimization, recommender systems, fast retrieval ranking, and intelligent customer service. Improving text-matching accuracy remains an important challenge in NLP.
- Information retrieval: many retrieval applications need to find other texts similar to a source text; this usage scenario is very common.
- News recommendation: given the headline a user has just read, automatically retrieve similar news and make personalized recommendations, increasing user retention and improving the product experience.
- Intelligent customer service: after a user submits a question, automatically retrieve similar questions and their answers, saving human-agent costs and improving efficiency.
Consider a simple example: which candidate sentence is semantically closest to the original?
Original: "How do I mount a license plate on the front of the car?"
- Candidate 1: "How do I install the front license plate?"
- Candidate 2: "How do I apply for a Beijing license plate?"
- Candidate 3: "How do I install the rear license plate?"
(1) Candidate 1 vs. the original: although the sentence pattern and word order differ considerably, they express almost the same meaning.
(2) Candidate 2 vs. the original: although they share co-occurring words such as "how" and "license plate", their meanings are completely different.
(3) Candidate 3 vs. the original: both discuss mounting a license plate, just the front plate in one case and the rear plate in the other, so there is some semantic relatedness between them.
In terms of semantic relatedness, candidate 1 > candidate 3 > candidate 2. That is semantic matching.
2. Short-Text Semantic Matching Network
SimilarityNet (SimNet) is a framework for computing the similarity of short texts: given two input texts, it produces a similarity score. It offers core network structures such as BOW, CNN, RNN, and MMDNN, and provides training and prediction pipelines for semantic-similarity computation. It applies to scenarios such as information retrieval, news recommendation, and intelligent customer service, helping enterprises solve semantic-matching problems.
As shown in the figure, the SimNet architecture consists of an input layer, a representation layer, and a matching layer.
SimilarityNet framework
3. SimilarityNet Architecture Diagram
The architecture is shown in the diagram below. query and title are the preprocessed texts to be matched; they are tokenized, encoded into ids, and passed through SimilarityNet to produce the output. Training uses cross-entropy as the loss function.
SimilarityNet architecture diagram
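As a toy illustration of the represent-then-match idea, the BOW variant can be approximated in plain Python: represent each text as a bag-of-words count vector and compare the vectors with cosine similarity. This is only a conceptual sketch, not SimNet's actual implementation, and it also shows the limitation that motivates learned representations: surface-word overlap scores candidate 3 from the example above as highly similar to the original despite the front/rear meaning difference.

```python
import math
from collections import Counter

def bow_cosine(text_a, text_b):
    # Represent each text as a bag-of-words count vector
    a, b = Counter(text_a.split()), Counter(text_b.split())
    # Dot product over the words the two texts share
    dot = sum(a[w] * b[w] for w in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

# Front-plate vs. rear-plate: 6 of 7 surface words match,
# so BOW rates them highly similar despite the meaning difference
print(bow_cosine("how to install the front license plate",
                 "how to install the rear license plate"))  # 6/7 ≈ 0.857
```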
III. Data Overview
1. Unzipping the Data
!unzip -oaq /home/aistudio/data/data143274/us-patent-phrase-to-phrase-matching.zip
2. Inspecting the Data
View the data and count the lines:
- 36474 train.csv
- 37 test.csv
- 37 sample_submission.csv
There are five columns: id, the two phrases being compared, the context in which they are used, and the score.
!head train.csv
id,anchor,target,context,score
37d61fd2272659b1,abatement,abatement of pollution,A47,0.5
7b9652b17b68b7a4,abatement,act of abating,A47,0.75
36d72442aefd8232,abatement,active catalyst,A47,0.25
5296b0c19e1ce60e,abatement,eliminating process,A47,0.5
54c1e3b9184cb5b6,abatement,forest region,A47,0
067203128142739c,abatement,greenhouse gases,A47,0.25
061d17f04be2d1cf,abatement,increased rate,A47,0.25
e1f44e48399a2027,abatement,measurement level,A47,0.25
0a425937a3e86d10,abatement,minimising sounds,A47,0.5
!wc -l train.csv
36474 train.csv
!head test.csv
id,anchor,target,context
4112d61851461f60,opc drum,inorganic photoconductor drum,G02
09e418c93a776564,adjust gas flow,altering gas flow,F23
36baf228038e314b,lower trunnion,lower locating,B60
1f37ead645e7f0c8,cap component,upper portion,D06
71a5b6ad068d531f,neural stimulation,artificial neural network,H04
474c874d0c07bd21,dry corn,dry corn starch,C12
442c114ed5c4e3c9,tunneling capacitor,capacitor housing,G11
b8ae62ea5e1d8bdb,angular contact bearing,contact therapy radiation,B23
faaddaf8fcba8a3f,produce liquid hydrocarbons,produce a treated stream,C10
!wc -l test.csv
37 test.csv
!head sample_submission.csv
id,score
4112d61851461f60,0
09e418c93a776564,0
36baf228038e314b,0
1f37ead645e7f0c8,0
71a5b6ad068d531f,0
474c874d0c07bd21,0
442c114ed5c4e3c9,0
b8ae62ea5e1d8bdb,0
faaddaf8fcba8a3f,0
!wc -l sample_submission.csv
37 sample_submission.csv
IV. Environment Setup
This project uses PaddleNLP for rapid processing. PaddleNLP needs to be upgraded to version 2.3.3, after which the notebook must be restarted.
1. Upgrading PaddleNLP
- upgrade pip
- upgrade paddlenlp
# The PaddleNLP version on AI Studio is outdated, so upgrade it first
!python -m pip install --upgrade pip --user > log.log
!pip install paddlenlp --upgrade --user > log.log
2. Restarting and Importing Required Libraries
After paddlenlp has been upgraded, restart the notebook so the new version is loaded.
# Paddle imports
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset, Subset
from paddle.dataset.common import md5file

# PaddleNLP imports
import paddlenlp as ppnlp
from paddlenlp.data import JiebaTokenizer, Pad, Stack, Tuple, Vocab
from paddlenlp.datasets import DatasetBuilder, MapDataset, load_dataset
# from utils import convert_example

import numpy as np
import time
import os

print("Paddle version used in this project: " + paddle.__version__)
print("PaddleNLP version used in this project: " + ppnlp.__version__)
Paddle version used in this project: 2.3.0
PaddleNLP version used in this project: 2.3.3
V. Loading the Pretrained Model
1. ERNIE Model Summary
The table below summarizes the pretrained weights for the ERNIE models currently supported by PaddleNLP. See the corresponding links for model details.
| Pretrained Weight | Language | Details of the model |
| --- | --- | --- |
| ernie-1.0 | Chinese | 12-layer, 768-hidden, 12-heads, 108M parameters. Trained on Chinese text. |
| ernie-tiny | Chinese | 3-layer, 1024-hidden, 16-heads, _M parameters. Trained on Chinese text. |
| ernie-2.0-base-en | English | 12-layer, 768-hidden, 12-heads, 103M parameters. Trained on lower-cased English text. |
| ernie-2.0-base-en-finetuned-squad | English | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on finetuned squad text. |
| ernie-2.0-large-en | English | 24-layer, 1024-hidden, 16-heads, 336M parameters. Trained on lower-cased English text. |
| ernie-1.0-cluecorpussmall | Chinese | Please refer to: zhui/ernie-1.0-cluecorpussmall |
2. Loading the Model
# If this fails, restart the notebook
MODEL_NAME = "ernie-2.0-base-en"
ernie_model = ppnlp.transformers.ErnieModel.from_pretrained(MODEL_NAME)
model = ppnlp.transformers.ErnieForSequenceClassification.from_pretrained(MODEL_NAME, num_classes=5)
[2022-06-12 15:23:06,757] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-2.0-base-en/ernie_v2_eng_base.pdparams
W0612 15:23:06.761130 13195 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0612 15:23:06.765818 13195 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.
[2022-06-12 15:23:12,415] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-2.0-base-en/ernie_v2_eng_base.pdparams
# Build the tokenizer that matches the ERNIE model and check it works
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained(MODEL_NAME)
[2022-06-12 15:23:13,762] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-2.0-base-en/vocab.txt
VI. Loading the Dataset
1. Inspecting the Data
import pandas as pd

train_data = pd.read_csv("./train.csv", sep=',')
test_data = pd.read_csv("./test.csv", sep=',')
train_data.head(5)
|   | id | anchor | target | context | score |
| --- | --- | --- | --- | --- | --- |
| 0 | 37d61fd2272659b1 | abatement | abatement of pollution | A47 | 0.50 |
| 1 | 7b9652b17b68b7a4 | abatement | act of abating | A47 | 0.75 |
| 2 | 36d72442aefd8232 | abatement | active catalyst | A47 | 0.25 |
| 3 | 5296b0c19e1ce60e | abatement | eliminating process | A47 | 0.50 |
| 4 | 54c1e3b9184cb5b6 | abatement | forest region | A47 | 0.00 |
print(max(train_data['anchor'].str.len()))
print(max(train_data['target'].str.len()))
38
98
2. Score Distribution
%matplotlib inline
train_data.hist(column="score",    # column to plot
                figsize=(16, 16),  # figure size
                color="red")
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7fc3ee0d8710>]], dtype=object)
3. Label Conversion
def score2label(x):
    if x == 0.0:
        y = 0
    elif x == 0.25:
        y = 1
    elif x == 0.5:
        y = 2
    elif x == 0.75:
        y = 3
    elif x == 1.0:
        y = 4
    return y
train_data['label']=train_data.score.map(score2label)
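Because the scores are exact multiples of 0.25, the branch ladder above is equivalent to a single arithmetic mapping; the alternative below (named `score2label_alt` here for illustration) is just a sanity check of that equivalence.

```python
def score2label_alt(x):
    # Scores are exact multiples of 0.25, so scaling by 4 gives the class id 0..4
    return int(round(x * 4))

print([score2label_alt(s) for s in (0.0, 0.25, 0.5, 0.75, 1.0)])  # [0, 1, 2, 3, 4]
```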
4. Custom Read Functions
def read_train(pd_data):
    for index, item in pd_data.iterrows():
        yield {"query": item['anchor'], "title": item['target'], "label": item['label']}

def read_test(pd_data):
    for index, item in pd_data.iterrows():
        yield {"query": item['anchor'], "title": item['target'], "label": 0}
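read_train and read_test are ordinary Python generators that load_dataset consumes row by row. Stripped of pandas, the pattern looks like this (`rows` is a stand-in for the DataFrame):

```python
rows = [
    {'anchor': 'abatement', 'target': 'abatement of pollution', 'label': 2},
    {'anchor': 'abatement', 'target': 'act of abating', 'label': 3},
]

def read_rows(rows):
    # Yield one example dict per row, in the query/title/label shape load_dataset expects
    for item in rows:
        yield {"query": item['anchor'], "title": item['target'], "label": item['label']}

print(next(read_rows(rows)))
# {'query': 'abatement', 'title': 'abatement of pollution', 'label': 2}
```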
5. Loading the Dataset
# pd_data is forwarded as the argument of the read() function
train_ds = load_dataset(read_train, pd_data=train_data, lazy=False)
for i in range(5):
    print(train_ds[i])
{'query': 'abatement', 'title': 'abatement of pollution', 'label': 2}
{'query': 'abatement', 'title': 'act of abating', 'label': 3}
{'query': 'abatement', 'title': 'active catalyst', 'label': 1}
{'query': 'abatement', 'title': 'eliminating process', 'label': 2}
{'query': 'abatement', 'title': 'forest region', 'label': 0}
6. Data Preprocessing
The dataset loaded through paddlenlp is raw text. Here we implement the preprocessing logic, batching, tokenization, and so on, that turns the raw text into the input tensors the network trains on.
6.1 Defining the Example Conversion Function
from functools import partial
from paddlenlp.data import Stack, Tuple, Pad
from utils import convert_example, create_dataloader

batch_size = 64
max_seq_length = 256

trans_func = partial(
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length)
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),       # input ids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment ids
    Stack(dtype="int64")                               # labels
): [data for data in fn(samples)]
train_data_loader = create_dataloader(
    train_ds,
    mode='train',
    batch_size=batch_size,
    batchify_fn=batchify_fn,
    trans_fn=trans_func)
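Pad(axis=0, pad_val=...) right-pads every field in a batch to the length of the longest sequence so the batch can be stacked into one tensor. A pure-Python illustration of that padding behavior (not the paddlenlp implementation):

```python
def pad_batch(seqs, pad_val=0):
    # Right-pad every sequence to the length of the longest one in the batch
    max_len = max(len(s) for s in seqs)
    return [s + [pad_val] * (max_len - len(s)) for s in seqs]

print(pad_batch([[1, 2, 3], [4, 5], [6]]))  # [[1, 2, 3], [4, 5, 0], [6, 0, 0]]
```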
from paddlenlp.transformers import LinearDecayWithWarmup

# peak learning rate during training
learning_rate = 5e-5
# number of training epochs
epochs = 3
# proportion of steps used for learning-rate warmup
warmup_proportion = 0.1
# weight-decay coefficient, a regularization strategy to curb overfitting
weight_decay = 0.01

num_training_steps = len(train_data_loader) * epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ])
criterion = paddle.nn.loss.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
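LinearDecayWithWarmup ramps the learning rate linearly from 0 up to the peak over the warmup steps, then decays it linearly back to 0 over the remaining steps. A pure-Python sketch of that schedule (assumed behavior for illustration, not the paddlenlp source):

```python
def lr_at(step, peak=5e-5, total_steps=1000, warmup_proportion=0.1):
    # Linear warmup to the peak, then linear decay to zero
    warmup_steps = int(total_steps * warmup_proportion)
    if step < warmup_steps:
        return peak * step / warmup_steps
    return peak * max(0.0, (total_steps - step) / (total_steps - warmup_steps))

print(lr_at(50))    # mid-warmup: half the peak
print(lr_at(100))   # warmup finished: the peak
print(lr_at(1000))  # end of training: 0.0
```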
VII. Model Training
import paddle.nn.functional as F
from utils import evaluate

global_step = 0
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        input_ids, segment_ids, labels = batch
        logits = model(input_ids, segment_ids)
        loss = criterion(logits, labels)
        probs = F.softmax(logits, axis=1)
        correct = metric.compute(probs, labels)
        metric.update(correct)
        acc = metric.accumulate()

        global_step += 1
        if global_step % 10 == 0:
            print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f"
                  % (global_step, epoch, step, loss, acc))
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()
    evaluate(model, criterion, metric, train_data_loader)
global step 10, epoch: 1, batch: 10, loss: 1.56753, acc: 0.19219
global step 20, epoch: 1, batch: 20, loss: 1.46965, acc: 0.25234
global step 30, epoch: 1, batch: 30, loss: 1.36288, acc: 0.30000
...
global step 560, epoch: 1, batch: 560, loss: 1.24214, acc: 0.47146
global step 570, epoch: 1, batch: 570, loss: 1.04112, acc: 0.47265
eval loss: 0.97536, accu: 0.60310
global step 580, epoch: 2, batch: 10, loss: 0.93487, acc: 0.60781
...
global step 1130, epoch: 2, batch: 560, loss: 0.88650, acc: 0.60240
global step 1140, epoch: 2, batch: 570, loss: 0.83957, acc: 0.60258
eval loss: 0.69819, accu: 0.72596
global step 1150, epoch: 3, batch: 10, loss: 0.88984, acc: 0.67500
...
global step 1700, epoch: 3, batch: 560, loss: 0.60610, acc: 0.70240
global step 1710, epoch: 3, batch: 570, loss: 0.55158, acc: 0.70277
eval loss: 0.55542, accu: 0.78784
The baseline ran for 3 epochs in about 45 minutes. Note that evaluate is called on train_data_loader, so the reported eval accuracy is measured on the training data rather than a held-out validation split.
VIII. Saving the Model
model.save_pretrained('model')
tokenizer.save_pretrained('model')
[2022-06-12 15:32:15,519] [    INFO] - tokenizer config file saved in model/tokenizer_config.json
[2022-06-12 15:32:15,521] [    INFO] - Special tokens file saved in model/special_tokens_map.json
('model/tokenizer_config.json', 'model/special_tokens_map.json', 'model/added_tokens.json')
IX. Prediction
from utils import predict
import pandas as pd

label_map = {0: '0', 1: '0.25', 2: '0.5', 3: '0.75', 4: '1'}

def preprocess_prediction_data(pd_data):
    examples = []
    for index, item in pd_data.iterrows():
        examples.append({"query": item['anchor'], "title": item['target']})
    return examples

test_file = 'test.csv'
pd_data = pd.read_csv(test_file, sep=',')
examples = preprocess_prediction_data(pd_data)
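utils.predict is not shown in this project; conceptually it batches the examples, runs the model, and maps the argmax of the softmaxed logits through label_map. A pure-Python sketch of that final step (the helper name `logits_to_label` is illustrative, not from utils):

```python
import math

label_map = {0: '0', 1: '0.25', 2: '0.5', 3: '0.75', 4: '1'}

def logits_to_label(logits, label_map):
    # Softmax the raw logits into probabilities; softmax is monotonic,
    # so the argmax over probabilities equals the argmax over logits
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return label_map[best]

print(logits_to_label([0.1, 2.3, 0.4, -1.0, 0.2], label_map))  # 0.25
```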
results = predict(
    model, examples, tokenizer, label_map, batch_size=batch_size)
for idx, text in enumerate(examples):
    print('Data: {} \t Label: {}'.format(text, results[idx]))
Data: {'query': 'opc drum', 'title': 'inorganic photoconductor drum'} 	 Label: 0.5
Data: {'query': 'adjust gas flow', 'title': 'altering gas flow'} 	 Label: 0.5
Data: {'query': 'lower trunnion', 'title': 'lower locating'} 	 Label: 0.5
Data: {'query': 'cap component', 'title': 'upper portion'} 	 Label: 0.25
Data: {'query': 'neural stimulation', 'title': 'artificial neural network'} 	 Label: 0
Data: {'query': 'dry corn', 'title': 'dry corn starch'} 	 Label: 0.5
Data: {'query': 'tunneling capacitor', 'title': 'capacitor housing'} 	 Label: 0.5
Data: {'query': 'angular contact bearing', 'title': 'contact therapy radiation'} 	 Label: 0
Data: {'query': 'produce liquid hydrocarbons', 'title': 'produce a treated stream'} 	 Label: 0.25
Data: {'query': 'diesel fuel tank', 'title': 'diesel fuel tanks'} 	 Label: 1
Data: {'query': 'chemical activity', 'title': 'dielectric characteristics'} 	 Label: 0.25
Data: {'query': 'transmit to platform', 'title': 'direct receiving'} 	 Label: 0.25
Data: {'query': 'oil tankers', 'title': 'oil carriers'} 	 Label: 0.5
Data: {'query': 'generate in layer', 'title': 'generate by layer'} 	 Label: 0.75
Data: {'query': 'slip segment', 'title': 'slip portion'} 	 Label: 0.75
Data: {'query': 'el display', 'title': 'illumination'} 	 Label: 0.5
Data: {'query': 'overflow device', 'title': 'oil filler'} 	 Label: 0.25
Data: {'query': 'beam traveling direction', 'title': 'concrete beam'} 	 Label: 0
Data: {'query': 'el display', 'title': 'electroluminescent'} 	 Label: 0.5
Data: {'query': 'equipment unit', 'title': 'power detection'} 	 Label: 0.5
Data: {'query': 'halocarbyl', 'title': 'halogen addition reaction'} 	 Label: 0.25
Data: {'query': 'perfluoroalkyl group', 'title': 'hydroxy'} 	 Label: 0.25
Data: {'query': 'speed control means', 'title': 'control loop'} 	 Label: 0.25
Data: {'query': 'arm design', 'title': 'steel plate'} 	 Label: 0.25
Data: {'query': 'hybrid bearing', 'title': 'bearing system'} 	 Label: 0.5
Data: {'query': 'end pins', 'title': 'end days'} 	 Label: 0
Data: {'query': 'organic starting', 'title': 'organic farming'} 	 Label: 0
Data: {'query': 'make of slabs', 'title': 'making cake'} 	 Label: 0
Data: {'query': 'seal teeth', 'title': 'teeth whitening'} 	 Label: 0
Data: {'query': 'carry by platform', 'title': 'carry on platform'} 	 Label: 0.75
Data: {'query': 'polls', 'title': 'pooling device'} 	 Label: 0
Data: {'query': 'upper clamp arm', 'title': 'end visual'} 	 Label: 0
Data: {'query': 'clocked storage', 'title': 'clocked storage device'} 	 Label: 0.75
Data: {'query': 'coupling factor', 'title': 'turns impedance'} 	 Label: 0.5
Data: {'query': 'different conductivity', 'title': 'carrier polarity'} 	 Label: 0
Data: {'query': 'hybrid bearing', 'title': 'corrosion resistant'} 	 Label: 0.25
# Write the submission file; index=False keeps the pandas index out of the CSV
# so the file matches the id,score format of sample_submission.csv
sub = pd.read_csv("sample_submission.csv", sep=',')
for idx, text in enumerate(examples):
    sub.loc[idx, 'score'] = results[idx]
sub.to_csv('submission.csv', sep=',', index=False)
!head submission.csv
id,score
4112d61851461f60,0.5
09e418c93a776564,0.5
36baf228038e314b,0.5
1f37ead645e7f0c8,0.25
71a5b6ad068d531f,0
474c874d0c07bd21,0.5
442c114ed5c4e3c9,0.5
b8ae62ea5e1d8bdb,0
faaddaf8fcba8a3f,0.25