Direct use
Open the notebook BERT English Machine Reading Comprehension with EasyNLP and click "Open in DSW" in the upper-right corner.
BERT English Machine Reading Comprehension with EasyNLP
EasyNLP is an easy-to-use, feature-rich NLP algorithm framework developed on top of PyTorch by the Alibaba Cloud PAI algorithm team ( https://github.com/alibaba/EasyNLP ). It supports commonly used Chinese pretrained models and techniques for putting large models into production, and offers a one-stop NLP development experience from training to deployment. EasyNLP provides training and prediction for a wide range of models, aiming to help NLP developers build models and apply them in production quickly and conveniently.
Using machine reading comprehension as the example task, this tutorial shows how to quickly build, train, evaluate, and run predictions with a BERT-based English machine reading comprehension model in PAI-DSW using EasyNLP.
About BERT
BERT (Bidirectional Encoder Representations from Transformers) is a pretrained language representation model proposed by Google AI in October 2018. Like many models now widely used in NLP, BERT is built on the Transformer encoder. Unlike earlier approaches that pretrain a single unidirectional language model, or shallowly concatenate two unidirectional language models, BERT pretrains with a masked language model (MLM) objective to produce deep bidirectional language representations. One of the most widely used NLP models today, BERT achieved state-of-the-art results on 11 different NLP tasks when it was released, making it a milestone in the history of NLP.
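As a quick illustration of the MLM objective, the sketch below masks a token and lets a pretrained BERT fill it in from the context on both sides. This optional example uses the Hugging Face transformers library, which is not otherwise required by this tutorial.

from transformers import pipeline

# BERT was pretrained to recover randomly masked tokens from the
# surrounding context on both sides, which is what makes its
# representations bidirectional.
unmasker = pipeline("fill-mask", model="bert-base-uncased")
for pred in unmasker("The goal of machine reading [MASK] is to answer questions."):
    print(pred["token_str"], round(pred["score"], 3))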
About machine reading comprehension
Machine reading comprehension (MRC) gives a model both a question and a context passage, and requires the model to produce an answer based on the given context. By task format, MRC is usually divided into four categories: cloze tests, multiple choice, span extraction, and free-form answering. In span extraction, the model predicts the start and end positions of the answer text directly within the context, given the question, and extracts that span as the answer. Because it is close to real-world scenarios, moderately difficult, easy to evaluate, and supported by high-quality datasets such as SQuAD, span extraction is currently the mainstream reading comprehension task.
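To make span extraction concrete, here is a tiny sketch of what the model ultimately produces. The offsets below are hypothetical stand-ins for model predictions; real models predict token-level positions, which are then mapped back to character offsets.

# Span extraction: the model predicts where the answer starts and ends
# inside the context; the answer is simply that slice of the text.
context = "Super Bowl 50 was played at Levi's Stadium in Santa Clara, California."
start_position, end_position = 46, 69  # hypothetical model predictions
print(context[start_position:end_position])  # -> Santa Clara, California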
Runtime environment requirements
Recommended environment: Python 3.6, the PyTorch 1.8 image, a GPU instance (P100 or V100), and at least 32 GB of memory.
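You can sanity-check the runtime with a few lines of Python before proceeding (a quick sketch; it assumes only that PyTorch is installed):

import sys
import torch

print(sys.version)                 # expect Python 3.6.x
print(torch.__version__)           # expect 1.8.x
print(torch.cuda.is_available())   # expect True on a GPU instance
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. a P100 or V100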
Installing EasyNLP
We recommend downloading the EasyNLP source code from GitHub and installing it as follows:
! git clone https://github.com/alibaba/EasyNLP.git
! pip install -r EasyNLP/requirements.txt
! cd EasyNLP && python setup.py install
You can verify that the installation succeeded with:
! which easynlp
/home/pai/bin/easynlp
If the easynlp CLI tool is present on your system, the EasyNLP code base has been installed.
Data preparation
First, enter the designated model directory and download the training and development sets used in this example:
! cd examples/appzoo_tutorials/machine_reading_comprehension
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/machine_reading_comprehension/train_squad.tsv
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/machine_reading_comprehension/dev_squad.tsv
/bin/bash: line 0: cd: examples/appzoo_tutorials/machine_reading_comprehension: No such file or directory
--2022-09-07 19:51:13--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/machine_reading_comprehension/train_squad.tsv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 77308556 (74M) [text/tab-separated-values]
Saving to: ‘train_squad.tsv’

train_squad.tsv     100%[===================>]  73.73M  24.6MB/s    in 3.0s

2022-09-07 19:51:17 (24.6 MB/s) - ‘train_squad.tsv’ saved [77308556/77308556]

--2022-09-07 19:51:17--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/machine_reading_comprehension/dev_squad.tsv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9586894 (9.1M) [text/tab-separated-values]
Saving to: ‘dev_squad.tsv’

dev_squad.tsv       100%[===================>]   9.14M  4.96MB/s    in 1.8s

2022-09-07 19:51:19 (4.96 MB/s) - ‘dev_squad.tsv’ saved [9586894/9586894]
The training and development data are .tsv files with tab-separated fields. Once the download finishes, you can view the first five records with the commands below. Each line is one example, and each column is one field: the example ID, the context passage, the question, the answer text, the answer's start position within the context, and the article title.
print('Training data sample:')
! head -n 5 train_squad.tsv
print('Development set data sample:')
! head -n 5 dev_squad.tsv
Training data sample:
5733be284776f41900661182   Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.   To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?   Saint Bernadette Soubirous   515   University_of_Notre_Dame
5733be284776f4190066117f   (same context as above)   What is in front of the Notre Dame Main Building?   a copper statue of Christ   188   University_of_Notre_Dame
5733be284776f41900661180   (same context as above)   The Basilica of the Sacred heart at Notre Dame is beside to which structure?   the Main Building   279   University_of_Notre_Dame
5733be284776f41900661181   (same context as above)   What is the Grotto at Notre Dame?   a Marian place of prayer and reflection   381   University_of_Notre_Dame
5733be284776f4190066117e   (same context as above)   What sits on top of the Main Building at Notre Dame?   a golden statue of the Virgin Mary   92   University_of_Notre_Dame
Development set data sample:
56be4db0acb8001400a502ec   Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.   Which NFL team represented the AFC at Super Bowl 50?   Denver Broncos   177   Super_Bowl_50
56be4db0acb8001400a502ed   (same context as above)   Which NFL team represented the NFC at Super Bowl 50?   Carolina Panthers   249   Super_Bowl_50
56be4db0acb8001400a502ee   (same context as above)   Where did Super Bowl 50 take place?   Santa Clara, California   403   Super_Bowl_50
56be4db0acb8001400a502ef   (same context as above)   Which NFL team won Super Bowl 50?   Denver Broncos   177   Super_Bowl_50
56be4db0acb8001400a502f0   (same context as above)   What color was used to emphasize the 50th anniversary of the Super Bowl?   gold   488   Super_Bowl_50
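To inspect individual fields programmatically, a record can be split on tabs. A minimal sketch based on the six-column schema described above:

# Parse the first training record into its six tab-separated fields.
with open('train_squad.tsv') as f:
    first_line = f.readline().rstrip('\n')
qas_id, context, question, answer, start_pos, title = first_line.split('\t')
print(qas_id, question, answer, start_pos, sep=' | ')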
Initialization
Under Python 3.6, we first import the libraries we need from the freshly installed EasyNLP and perform some initialization. This tutorial uses bert-base-uncased. EasyNLP integrates a rich library of pretrained models; to try another pretrained model such as RoBERTa or ALBERT, modify user_defined_parameters accordingly (the exact model names are listed in the model list). EasyNLP currently supports reading comprehension in both Chinese and English; setting language=en in user_defined_parameters selects the English text-processing configuration. In addition, the reading comprehension task has several task-specific parameters, such as doc_stride, that must be declared in user_defined_parameters up front; their meanings are explained in detail in the "Loading the data" section below.
# Manually set sys.argv to avoid conflicts between EasyNLP's args and
# Jupyter's own command-line arguments; without this, initialization fails.
# You can skip this step when running the code from the command line or a .py file.
import sys
sys.argv = ['main.py']
import torch.cuda
from easynlp.appzoo import MachineReadingComprehensionDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator, get_application_model_for_evaluation
from easynlp.core import Trainer, PredictorManager
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parameters

initialize_easynlp()
args = get_args()

user_defined_parameters = parse_user_defined_parameters(
    'pretrain_model_name_or_path=bert-base-uncased language=en qas_id=qas_id answer_name=answer_text start_position_name=start_position_character max_query_length=64 max_answer_length=30 doc_stride=128 n_best_size=10 output_answer_file=dev.ans.csv'
)
args.checkpoint_dir = "./squad_model_dir/"
[2022-09-07 19:59:10,412.412 dsw33701-85768bd75b-s7k52:185257 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
Please ignore the following import error if you are using tunnel table io. No module named '_common_io'
No module named 'easy_predict'
------------------------ arguments ------------------------
  answer_name ..................................... None
  app_name ........................................ text_classify
  append_cols ..................................... None
  buckets ......................................... None
  checkpoint_dir .................................. None
  chief_hosts ..................................... 
  data_threads .................................... 10
  distributed_backend ............................. nccl
  do_lower_case ................................... False
  doc_stride ...................................... 128
  epoch_num ....................................... 3.0
  export_tf_checkpoint_type ....................... easytransfer
  first_sequence .................................. None
  gradient_accumulation_steps ..................... 1
  input_schema .................................... None
  is_chief ........................................ 
  is_master_node .................................. True
  job_name ........................................ None
  label_enumerate_values .......................... None
  label_name ...................................... None
  learning_rate ................................... 5e-05
  local_rank ...................................... None
  logging_steps ................................... 100
  master_port ..................................... 23456
  max_answer_length ............................... 30
  max_grad_norm ................................... 1.0
  max_query_length ................................ 64
  micro_batch_size ................................ 2
  mode ............................................ train
  modelzoo_base_dir ............................... 
  n_best_size ..................................... 10
  n_cpu ........................................... 1
  n_gpu ........................................... 1
  odps_config ..................................... None
  optimizer_type .................................. AdamW
  output_answer_file .............................. None
  output_schema ................................... 
  outputs ......................................... None
  predict_queue_size .............................. 1024
  predict_slice_size .............................. 4096
  predict_table_read_thread_num ................... 16
  predict_thread_num .............................. 2
  ps_hosts ........................................ 
  qas_id .......................................... None
  random_seed ..................................... 1234
  rank ............................................ 0
  read_odps ....................................... False
  restore_works_dir ............................... ./.easynlp_predict_restore_works_dir
  resume_from_checkpoint .......................... None
  save_all_checkpoints ............................ False
  save_checkpoint_steps ........................... None
  second_sequence ................................. None
  sequence_length ................................. 16
  skip_first_line ................................. False
  start_position_name ............................. None
  tables .......................................... None
  task_count ...................................... 1
  task_index ...................................... 0
  use_amp ......................................... False
  use_torchacc .................................... False
  user_defined_parameters ......................... None
  user_entry_file ................................. None
  user_script ..................................... None
  warmup_proportion ............................... 0.1
  weight_decay .................................... 0.0001
  worker_count .................................... 1
  worker_cpu ...................................... -1
  worker_gpu ...................................... -1
  worker_hosts .................................... None
  world_size ...................................... 1
-------------------- end of arguments ---------------------
> initializing torch distributed ...
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
  from cryptography import x509
[2022-09-07 19:59:12,471.471 dsw33701-85768bd75b-s7k52:185257 INFO distributed_c10d.py:195] Added key: store_based_barrier_key:1 to store for rank: 0
Init dist done. World size: 1, rank 0, l_rank 0
> setting random seeds to 1234 ...
If initialization hangs or reports that the port is already occupied, you can locate and kill the occupying process in a terminal:

netstat -tunlp|grep 6000
kill -9 PID   (replace PID with the process ID shown in the netstat output above)
Loading the data
We load the training and evaluation data with MachineReadingComprehensionDataset, which ships with EasyNLP. Its main parameters are:
- pretrained_model_name_or_path: the pretrained model's name or path. We use the helper get_pretrain_model_path to resolve the model name "bert-base-uncased" to a local path, downloading the model automatically if necessary.
- max_seq_length: maximum length of the full input; longer inputs are truncated, shorter ones padded.
- input_schema: format of the input tsv data. Each comma-separated item corresponds to one tab-separated field in each line of the data file, and each item starts with the field's identifier, e.g. question_text, answer_text, context_text.
- first_sequence, second_sequence: which input_schema fields serve as the question and the context. The remaining field names (qas_id, answer_name, start_position_name) are declared in user_defined_parameters.
- user_defined_parameters: task-specific parameters declared up front. Machine reading comprehension requires the following:
  - pretrain_model_name_or_path: as above.
  - language: the language, Chinese (zh) or English (en). Preprocessing and postprocessing differ between languages; if unset, Chinese processing is used by default.
  - qas_id, answer_name, start_position_name: which input_schema fields hold the example ID, the answer, and the answer's start position within the context.
  - max_query_length: maximum length of the question; longer questions are truncated, shorter ones padded.
  - max_answer_length: maximum length of the answer; longer answers are truncated, shorter ones padded.
  - doc_stride: the context may be so long that the answer span falls in the part truncated beyond max_seq_length. Reading comprehension therefore typically uses a "sliding window": an over-long context is split into several overlapping segments, turning one input example into several cases that are processed separately. doc_stride is the stride of this sliding window (see the sketch after this list).
  - output_answer_file: at prediction time, besides the best answer, the model writes the top-n answer candidates after beam search to output_answer_file for analysis and comparison during model tuning.
  - n_best_size: the n above, i.e. how many of the top-ranked answers after beam search are written to output_answer_file.
- is_training: whether the dataset is used for training; True for train_dataset, False for valid_dataset.
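To make doc_stride concrete, the sketch below mimics the sliding-window splitting at the token level. It is deliberately simplified: the real implementation also budgets max_seq_length for the question and special tokens, and maps answer positions into each window.

def sliding_windows(tokens, max_seq_length=384, doc_stride=128):
    """Split an over-long token sequence into overlapping windows."""
    windows, start = [], 0
    while True:
        windows.append(tokens[start:start + max_seq_length])
        if start + max_seq_length >= len(tokens):
            break
        start += doc_stride
    return windows

# A 600-token context with max_seq_length=384 and doc_stride=128 becomes
# three cases, starting at tokens 0, 128, and 256.
print([len(w) for w in sliding_windows(list(range(600)))])  # [384, 384, 344]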
train_dataset = MachineReadingComprehensionDataset(
    pretrained_model_name_or_path=get_pretrain_model_path("bert-base-uncased"),
    data_file="train_squad.tsv",
    max_seq_length=384,
    input_schema="qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1",
    first_sequence="question_text",
    second_sequence="context_text",
    user_defined_parameters=user_defined_parameters,
    is_training=True
)

valid_dataset = MachineReadingComprehensionDataset(
    pretrained_model_name_or_path=get_pretrain_model_path("bert-base-uncased"),
    data_file="dev_squad.tsv",
    max_seq_length=384,
    input_schema="qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1",
    first_sequence="question_text",
    second_sequence="context_text",
    user_defined_parameters=user_defined_parameters,
    is_training=False
)
`/root/.easynlp/modelzoo/public/bert-base-uncased.tgz` already exists
`/root/.easynlp/modelzoo/public/bert-base-uncased.tgz` already exists
Because we selected bert-base-uncased earlier, the pretrained model is downloaded automatically here (or, as the log shows, reused from the local cache) and then loaded.
Model training
With the data prepared and the model path resolved, we can start training. We use EasyNLP's get_application_model function to construct the model for training. Its parameters are:
- app_name: the task name; here we choose machine reading comprehension, "machine_reading_comprehension".
- pretrained_model_name_or_path: the pretrained model's name or path. As before, we use the get_pretrain_model_path helper to resolve the name "bert-base-uncased" to a local path, downloading the model automatically if necessary.
- user_defined_parameters: user-defined parameters; pass in the user_defined_parameters prepared above.
model = get_application_model(
    app_name="machine_reading_comprehension",
    pretrained_model_name_or_path=get_pretrain_model_path("bert-base-uncased"),
    user_defined_parameters=user_defined_parameters
)
`/root/.easynlp/modelzoo/public/bert-base-uncased.tgz` already exists
model ============ from pretrained model ===============
Loaded weights of the model: [bert.embeddings.word_embeddings.weight, bert.embeddings.position_embeddings.weight, bert.embeddings.token_type_embeddings.weight, bert.encoder.layer.0.attention.self.query.weight, ..., cls.predictions.transform.LayerNorm.weight, cls.predictions.transform.LayerNorm.bias] (full list of BERT parameter names truncated for readability).
Unloaded weights of the model: [cls.predictions.decoder.weight, cls.seq_relationship.weight, cls.predictions.bias, cls.seq_relationship.bias, cls.predictions.transform.LayerNorm.weight, cls.predictions.transform.LayerNorm.bias, cls.predictions.transform.dense.bias, cls.predictions.transform.dense.weight].
This IS expected if you initialize A model from B. This IS NOT expected if you initialize A model from A.
The log above shows that the pretrained weights were loaded (the pretraining heads are discarded, which is expected when initializing a downstream model from a pretrained checkpoint). Next, we create a training instance with EasyNLP's Trainer class and start training.

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    evaluator=get_application_evaluator(
        app_name="machine_reading_comprehension",
        valid_dataset=valid_dataset,
        eval_batch_size=32,
        pretrained_model_name_or_path=get_pretrain_model_path("bert-base-uncased"),
        user_defined_parameters=user_defined_parameters
    )
)
trainer.train()
Note that because of the sliding-window mechanism, one long example expands into several input cases, so the training batch_size should not be set too large; otherwise you can easily run out of memory.
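As a rough estimate (under the same simplifications as the sliding-window sketch earlier), the number of cases a single example expands into grows with context length, which multiplies the effective batch size:

import math

def num_windows(context_len, max_seq_length=384, doc_stride=128):
    """Approximate number of sliding-window cases per example."""
    if context_len <= max_seq_length:
        return 1
    return 1 + math.ceil((context_len - max_seq_length) / doc_stride)

print(num_windows(900))  # 6: one nominal example costs roughly 6 cases of memory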
Model evaluation
When training finishes, the trained model is saved in the checkpoint_dir specified at the beginning, the local path "./squad_model_dir/". We can now evaluate the trained model's performance. Again, we first build the evaluation model using EasyNLP's get_application_model_for_evaluation method.
model = get_application_model_for_evaluation(
    app_name="machine_reading_comprehension",
    pretrained_model_name_or_path="./squad_model_dir/",
    user_defined_parameters=user_defined_parameters
)
model ============ from config ===============
[2022-09-07 20:02:29,052 INFO] Loading model...
[2022-09-07 20:02:29,097 INFO] Load finished!
Inited keys of the model: [_model.encoder.layer.1.attention.output.dense.weight, _model.encoder.layer.6.attention.output.LayerNorm.bias, _model.encoder.layer.6.attention.self.key.weight, ..., _model.encoder.layer.7.attention.output.dense.bias] (full list of parameter names truncated for readability).
All keys are initialized.
We then use EasyNLP's get_application_evaluator to initialize the evaluator, move the model to the current device, and run the evaluation.
evaluator = get_application_evaluator(
    app_name="machine_reading_comprehension",
    valid_dataset=valid_dataset,
    eval_batch_size=32,
    user_defined_parameters=user_defined_parameters
)
model.to(torch.cuda.current_device())
evaluator.evaluate(model=model)
Model prediction
We can likewise use the trained model to run reading comprehension predictions. We first create a predictor and instantiate a PredictorManager with it. Predictions are written to "dev.pred.csv" in the output format "unique_id, best_answer, query, context". In addition, the top-n answer candidates after beam search are written to "dev.ans.csv" for analysis and comparison during model tuning.
predictor = get_application_predictor(
    app_name="machine_reading_comprehension",
    model_dir="./squad_model_dir/",
    first_sequence="question_text",
    second_sequence="context_text",
    max_seq_length=384,
    output_file="dev.pred.csv",
    user_defined_parameters=user_defined_parameters
)
predictor_manager = PredictorManager(
    predictor=predictor,
    input_file="dev_squad.tsv",
    input_schema="qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1",
    output_file="dev.pred.csv",
    output_schema="unique_id,best_answer,query,context",
    append_cols=args.append_cols,
    batch_size=256
)
predictor_manager.run()
exit()
The following commands display five records from the test set: each line shows the id, the best answer predicted by the trained model, the query, and the context.
print('Predicted results: id best_answer query context')
! head -n 5 dev.pred.csv
Predicted results: id best_answer query context
56be4db0acb8001400a502ec   Denver Broncos   Which NFL team represented the AFC at Super Bowl 50?   Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
56be4db0acb8001400a502ed   Carolina Panthers   Which NFL team represented the NFC at Super Bowl 50?   (same context as above)
56be4db0acb8001400a502ee   Santa Clara, California.   Where did Super Bowl 50 take place?   (same context as above)
56be4db0acb8001400a502ef   Denver Broncos   Which NFL team won Super Bowl 50?   (same context as above)
56be4db0acb8001400a502f0   gold-themed   What color was used to emphasize the 50th anniversary of the Super Bowl?   (same context as above)
Each output line above contains, from left to right: id, best_answer, query, context. As you can see, the model predicts an accurate, concise answer for each question.
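Beyond eyeballing a few examples, you can get a rough accuracy number by joining the predictions to the gold answers on the example ID. A naive exact-match sketch (it assumes both files are tab-separated, as above, and applies no SQuAD-style answer normalization):

# Build a qas_id -> gold answer lookup from the development set.
gold = {}
with open('dev_squad.tsv') as f:
    for line in f:
        qas_id, _context, _question, answer, _start, _title = line.rstrip('\n').split('\t')
        gold[qas_id] = answer

# Compare each predicted best_answer against the gold answer.
matched = total = 0
with open('dev.pred.csv') as f:
    for line in f:
        unique_id, best_answer = line.rstrip('\n').split('\t')[:2]
        if unique_id in gold:
            total += 1
            matched += best_answer.strip().lower() == gold[unique_id].strip().lower()

print('Naive exact match: {}/{}'.format(matched, total))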
One-step execution
Note that all of the training/evaluation/prediction code above is already integrated in EasyNLP/examples/appzoo_tutorials/machine_reading_comprehension/main.py, and several ready-to-run scripts are provided as well. You can either run main.py with command-line arguments, or execute one of the bash scripts, to perform all of the training/evaluation/prediction steps above in one go.
One-step execution with main.py
Running main.py with the arguments below performs training, evaluation, or prediction directly.
The training command is as follows; the parameters are described in the "Loading the data" section above. In this example the pretrained model is bert-base-uncased.
! python main.py \
    --mode train \
    --app_name=machine_reading_comprehension \
    --worker_gpu=1 \
    --tables=train_squad.tsv,dev_squad.tsv \
    --input_schema=qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1 \
    --first_sequence=question_text \
    --second_sequence=context_text \
    --sequence_length=384 \
    --checkpoint_dir=./squad_model_dir \
    --learning_rate=3.5e-5 \
    --epoch_num=5 \
    --random_seed=42 \
    --save_checkpoint_steps=2000 \
    --train_batch_size=32 \
    --user_defined_parameters='pretrain_model_name_or_path=bert-base-uncased language=en answer_name=answer_text qas_id=qas_id start_position_name=start_position_character doc_stride=128 max_query_length=64'
The evaluation command is as follows; the parameters have the same meanings as in training.
! python main.py \
    --mode evaluate \
    --app_name=machine_reading_comprehension \
    --worker_gpu=1 \
    --tables=dev_squad.tsv \
    --input_schema=qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1 \
    --first_sequence=question_text \
    --second_sequence=context_text \
    --sequence_length=384 \
    --checkpoint_dir=./squad_model_dir \
    --micro_batch_size=32 \
    --user_defined_parameters='pretrain_model_name_or_path=bert-base-uncased language=en qas_id=qas_id answer_name=answer_text start_position_name=start_position_character doc_stride=128 max_query_length=64'
The prediction command is as follows; the parameters again match the above. The results can be inspected in dev.pred.csv.
! python main.py \
    --mode predict \
    --app_name=machine_reading_comprehension \
    --worker_gpu=1 \
    --tables=dev_squad.tsv \
    --outputs=dev.pred.csv \
    --input_schema=qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1 \
    --output_schema=unique_id,best_answer,query,context \
    --first_sequence=question_text \
    --second_sequence=context_text \
    --sequence_length=384 \
    --checkpoint_dir=./squad_model_dir \
    --micro_batch_size=256 \
    --user_defined_parameters='pretrain_model_name_or_path=bert-base-uncased language=en qas_id=qas_id answer_name=answer_text start_position_name=start_position_character max_query_length=64 max_answer_length=30 doc_stride=128 n_best_size=10 output_answer_file=dev.ans.csv'
One-step execution with bash scripts
Several ready-to-run bash scripts are provided under EasyNLP/examples/appzoo_tutorials/machine_reading_comprehension/, so you can likewise complete training/evaluation/prediction in one step from the command line. Take run_train_eval_predict_user_defined_local_en.sh as an example. The script takes two arguments: the first is the index of the GPU to run on (usually 0); the second selects the mode: train, evaluate, or predict.
Model training:
! bash run_train_eval_predict_user_defined_local_en.sh 0 train
Model evaluation:
! bash run_train_eval_predict_user_defined_local_en.sh 0 evaluate
Model prediction:
! bash run_train_eval_predict_user_defined_local_en.sh 0 predict