[DSW Gallery] English Machine Reading Comprehension with BERT on EasyNLP

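Summary: EasyNLP provides training and prediction for a variety of models, with the aim of helping NLP developers build models quickly and easily and put them into production. Using machine reading comprehension as an example, this article shows how to use BERT in PAI-DSW with EasyNLP to train an English machine reading comprehension model and run inference with it.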

Use directly

Open the "English Machine Reading Comprehension with BERT on EasyNLP" notebook and click "Open in DSW" in the upper-right corner.


English Machine Reading Comprehension with BERT on EasyNLP

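EasyNLP (https://github.com/alibaba/EasyNLP) is an easy-to-use, feature-rich NLP algorithm framework developed by the Alibaba Cloud PAI algorithm team on top of PyTorch. It supports commonly used Chinese pretrained models and techniques for putting large models into production, and provides a one-stop NLP development experience from training to deployment. EasyNLP provides training and prediction for a variety of models, with the aim of helping NLP developers build models quickly and easily and put them into production.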

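Using machine reading comprehension as an example, this article shows how to use BERT in PAI-DSW with EasyNLP to build, train, evaluate, and run predictions with an English machine reading comprehension model.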

About BERT

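BERT (Bidirectional Encoder Representations from Transformers) is a pretrained language representation model proposed by Google AI in October 2018. Like many models widely used in NLP today, BERT is built on the Transformer encoder. However, whereas earlier approaches were pretrained with a unidirectional language model, or with two unidirectional language models shallowly concatenated, BERT uses a new masked language model (MLM) objective to learn deep bidirectional representations. When it was released, BERT set state-of-the-art results on 11 different NLP tasks; it remains one of the most widely used models in NLP and is a milestone in the field's history.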
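The BERT paper describes the MLM masking recipe as follows: about 15% of token positions are chosen as prediction targets, and a chosen token is replaced by [MASK] 80% of the time, by a random token 10% of the time, and left unchanged 10% of the time. The plain-Python sketch below illustrates that recipe. It selects each position independently with probability mask_prob (a common simplification), and the function name mask_tokens, the toy vocabulary and the fixed seed are illustrative assumptions, not EasyNLP code.

```python
import random
from typing import List, Optional, Tuple

MASK = "[MASK]"

def mask_tokens(tokens: List[str], vocab: List[str], mask_prob: float = 0.15,
                seed: int = 0) -> Tuple[List[str], List[Optional[str]]]:
    """Return (corrupted tokens, labels).

    labels[i] holds the original token at positions selected for prediction
    and None elsewhere; the model is trained to recover the labeled tokens.
    """
    rng = random.Random(seed)
    corrupted = list(tokens)
    labels: List[Optional[str]] = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue                          # position not selected
        labels[i] = tok                       # prediction target
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK               # 80%: replace with [MASK]
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return corrupted, labels
```

Keeping some selected tokens unchanged or randomly replaced means the model cannot rely on seeing [MASK] to know where to predict, which matters because [MASK] never appears during fine-tuning.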

About machine reading comprehension

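Machine reading comprehension (MRC) gives a model a question and a passage (the context) and requires it to answer the question based on that passage. By task format, MRC tasks usually fall into four categories: cloze tests, multiple choice, span extraction, and free answering. Span extraction uses the question to predict the start and end positions of the answer text within the context, and extracts the text between them as the answer. Because it is close to real-world scenarios, moderately difficult, easy to evaluate, and backed by high-quality datasets such as SQuAD, span extraction has become the mainstream reading comprehension task.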
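Span extraction can be made concrete with a small decoding sketch: given a start score and an end score for every token, choose the pair (start, end) with start <= end and a bounded length that maximizes start score + end score, and keep the n best pairs. This mirrors common BERT-style post-processing; the two limits correspond to the max_answer_length and n_best_size parameters described in the "Loading data" section, but the function best_spans is a hypothetical illustration. Real implementations also skip positions inside the question or on special tokens and map token indices back to character offsets, which this sketch omits.

```python
from typing import List, Tuple

def best_spans(start_logits: List[float], end_logits: List[float],
               n_best: int = 10,
               max_answer_length: int = 30) -> List[Tuple[int, int, float]]:
    """Return up to n_best (start, end, score) spans, best first.

    Only spans with start <= end and end - start + 1 <= max_answer_length
    are considered; a span's score is start_logits[start] + end_logits[end].
    """
    candidates = []
    for s, s_score in enumerate(start_logits):
        # End positions from s up to the length limit (and the sequence end).
        for e in range(s, min(s + max_answer_length, len(end_logits))):
            candidates.append((s, e, s_score + end_logits[e]))
    candidates.sort(key=lambda c: c[2], reverse=True)
    return candidates[:n_best]
```

The answer text is then the context tokens from start to end inclusive. Because start and end are scored independently, the length limit is what prevents implausibly long spans from winning.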

Environment requirements

Recommended: Python 3.6, the PyTorch 1.8 image, a P100 or V100 GPU instance, and at least 32 GB of memory.

Installing EasyNLP

We recommend installing EasyNLP from the source code on GitHub:

! git clone https://github.com/alibaba/EasyNLP.git
! pip install -r EasyNLP/requirements.txt
! cd EasyNLP && python setup.py install

You can verify the installation with the following command:

! which easynlp
/home/pai/bin/easynlp

If the easynlp CLI tool is found, as shown above, the EasyNLP library has been installed.

Data preparation

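First, download the training and dev sets used in this example. The cell also tries to change into the example's directory, but in Jupyter each "!" shell command runs in its own subshell, so "! cd" does not change the notebook's working directory; here the relative path does not exist either, as the error in the output shows. The files are therefore saved to the current working directory, which is where the later cells read them from. (The folder for saving the model is not created here; it is specified later through args.checkpoint_dir.)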

! cd examples/appzoo_tutorials/machine_reading_comprehension
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/machine_reading_comprehension/train_squad.tsv
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/machine_reading_comprehension/dev_squad.tsv
/bin/bash: line 0: cd: examples/appzoo_tutorials/machine_reading_comprehension: No such file or directory
--2022-09-07 19:51:13--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/machine_reading_comprehension/train_squad.tsv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 77308556 (74M) [text/tab-separated-values]
Saving to: ‘train_squad.tsv’
train_squad.tsv     100%[===================>]  73.73M  24.6MB/s    in 3.0s    
2022-09-07 19:51:17 (24.6 MB/s) - ‘train_squad.tsv’ saved [77308556/77308556]
--2022-09-07 19:51:17--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/machine_reading_comprehension/dev_squad.tsv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9586894 (9.1M) [text/tab-separated-values]
Saving to: ‘dev_squad.tsv’
dev_squad.tsv       100%[===================>]   9.14M  4.96MB/s    in 1.8s    
2022-09-07 19:51:19 (4.96 MB/s) - ‘dev_squad.tsv’ saved [9586894/9586894]

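The training and dev data are tab-separated .tsv files. After downloading them, you can view the first five rows with the code below. Each row is one example and each column is one field: the question ID, the context, the question, the answer text, the character offset at which the answer starts in the context, and the article title.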

print('Training data sample:')
! head -n 5 train_squad.tsv
print('Development set data sample:')
! head -n 5 dev_squad.tsv
Training data sample:
5733be284776f41900661182  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary. To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? Saint Bernadette Soubirous  515 University_of_Notre_Dame
5733be284776f4190066117f  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary. What is in front of the Notre Dame Main Building? a copper statue of Christ 188 University_of_Notre_Dame
5733be284776f41900661180  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary. The Basilica of the Sacred heart at Notre Dame is beside to which structure?  the Main Building 279 University_of_Notre_Dame
5733be284776f41900661181  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary. What is the Grotto at Notre Dame? a Marian place of prayer and reflection 381 University_of_Notre_Dame
5733be284776f4190066117e  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary. What sits on top of the Main Building at Notre Dame?  a golden statue of the Virgin Mary  92  University_of_Notre_Dame
Development set data sample:
56be4db0acb8001400a502ec  Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50. Which NFL team represented the AFC at Super Bowl 50?  Denver Broncos  177 Super_Bowl_50
56be4db0acb8001400a502ed  Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50. Which NFL team represented the NFC at Super Bowl 50?  Carolina Panthers 249 Super_Bowl_50
56be4db0acb8001400a502ee  Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50. Where did Super Bowl 50 take place? Santa Clara, California 403 Super_Bowl_50
56be4db0acb8001400a502ef  Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50. Which NFL team won Super Bowl 50? Denver Broncos  177 Super_Bowl_50
56be4db0acb8001400a502f0  Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50. What color was used to emphasize the 50th anniversary of the Super Bowl?  gold  488 Super_Bowl_50
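Each row holds six tab-separated columns, in the same order as the input_schema used when loading the data: qas_id, context_text, question_text, answer_text, start_position_character and title. Since start_position_character is a character offset into the context, the answer can be recovered by slicing, which makes a quick sanity check possible. The helpers parse_squad_row and answer_matches below are hypothetical, meant only for inspecting the files by hand; EasyNLP does its own parsing. The check should hold for well-formed rows.

```python
from typing import Dict

FIELDS = ["qas_id", "context_text", "question_text",
          "answer_text", "start_position_character", "title"]

def parse_squad_row(line: str) -> Dict[str, str]:
    """Split one line of train_squad.tsv / dev_squad.tsv into named fields."""
    values = line.rstrip("\r\n").split("\t")
    if len(values) != len(FIELDS):
        raise ValueError(f"expected {len(FIELDS)} columns, got {len(values)}")
    return dict(zip(FIELDS, values))

def answer_matches(row: Dict[str, str]) -> bool:
    """Check that answer_text appears at start_position_character in the context."""
    start = int(row["start_position_character"])
    answer = row["answer_text"]
    return row["context_text"][start:start + len(answer)] == answer
```

For example, with the Notre Dame row above, slicing the context at offset 515 should yield "Saint Bernadette Soubirous".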

Initialization

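In the Python 3.6 environment, we first import the libraries needed to run the model from the EasyNLP package we just installed, and do some initialization. This tutorial uses bert-base-uncased. EasyNLP integrates a rich library of pretrained models; to try another one, such as RoBERTa or ALBERT, change it in user_defined_parameters (see the model list for the exact names). EasyNLP supports reading comprehension in both Chinese and English; setting language=en in user_defined_parameters selects the English text-processing configuration. In addition, MRC-specific parameters such as doc_stride must also be declared in user_defined_parameters up front; their meanings are explained in the "Loading data" section below.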
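The string passed to parse_user_defined_parameters in the cell below is a space-separated list of key=value pairs. The sketch that follows only illustrates that format; parse_kv_string is a hypothetical stand-in, not an EasyNLP function, and the tutorial itself should keep using EasyNLP's parse_user_defined_parameters.

```python
from typing import Dict

def parse_kv_string(spec: str) -> Dict[str, str]:
    """Turn 'a=1 b=x' into {'a': '1', 'b': 'x'}; values are kept as strings."""
    params: Dict[str, str] = {}
    for item in spec.split():
        key, sep, value = item.partition("=")
        if not sep or not key:
            raise ValueError(f"malformed item: {item!r}")
        params[key] = value
    return params
```

Because items are split on whitespace, a value cannot itself contain spaces in this format, and numeric values such as doc_stride=128 still need converting before use.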

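# To avoid a conflict between EasyNLP's args and the arguments of the Jupyter kernel, set sys.argv manually; otherwise initialization fails.
# If you run this code from the command line or a .py file, you can skip the two lines below.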
import sys
sys.argv = ['main.py']
import torch.cuda
from easynlp.appzoo import MachineReadingComprehensionDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator, get_application_model_for_evaluation
from easynlp.core import Trainer, PredictorManager
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parameters
initialize_easynlp()
args = get_args()
user_defined_parameters = parse_user_defined_parameters('pretrain_model_name_or_path=bert-base-uncased language=en qas_id=qas_id answer_name=answer_text start_position_name=start_position_character max_query_length=64 max_answer_length=30 doc_stride=128 n_best_size=10 output_answer_file=dev.ans.csv')
args.checkpoint_dir = "./squad_model_dir/"
[2022-09-07 19:59:10,412.412 dsw33701-85768bd75b-s7k52:185257 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
Please ignore the following import error if you are using tunnel table io.
No module named '_common_io'
No module named 'easy_predict'
------------------------ arguments ------------------------
  answer_name ..................................... None
  app_name ........................................ text_classify
  append_cols ..................................... None
  buckets ......................................... None
  checkpoint_dir .................................. None
  chief_hosts ..................................... 
  data_threads .................................... 10
  distributed_backend ............................. nccl
  do_lower_case ................................... False
  doc_stride ...................................... 128
  epoch_num ....................................... 3.0
  export_tf_checkpoint_type ....................... easytransfer
  first_sequence .................................. None
  gradient_accumulation_steps ..................... 1
  input_schema .................................... None
  is_chief ........................................ 
  is_master_node .................................. True
  job_name ........................................ None
  label_enumerate_values .......................... None
  label_name ...................................... None
  learning_rate ................................... 5e-05
  local_rank ...................................... None
  logging_steps ................................... 100
  master_port ..................................... 23456
  max_answer_length ............................... 30
  max_grad_norm ................................... 1.0
  max_query_length ................................ 64
  micro_batch_size ................................ 2
  mode ............................................ train
  modelzoo_base_dir ............................... 
  n_best_size ..................................... 10
  n_cpu ........................................... 1
  n_gpu ........................................... 1
  odps_config ..................................... None
  optimizer_type .................................. AdamW
  output_answer_file .............................. None
  output_schema ................................... 
  outputs ......................................... None
  predict_queue_size .............................. 1024
  predict_slice_size .............................. 4096
  predict_table_read_thread_num ................... 16
  predict_thread_num .............................. 2
  ps_hosts ........................................ 
  qas_id .......................................... None
  random_seed ..................................... 1234
  rank ............................................ 0
  read_odps ....................................... False
  restore_works_dir ............................... ./.easynlp_predict_restore_works_dir
  resume_from_checkpoint .......................... None
  save_all_checkpoints ............................ False
  save_checkpoint_steps ........................... None
  second_sequence ................................. None
  sequence_length ................................. 16
  skip_first_line ................................. False
  start_position_name ............................. None
  tables .......................................... None
  task_count ...................................... 1
  task_index ...................................... 0
  use_amp ......................................... False
  use_torchacc .................................... False
  user_defined_parameters ......................... None
  user_entry_file ................................. None
  user_script ..................................... None
  warmup_proportion ............................... 0.1
  weight_decay .................................... 0.0001
  worker_count .................................... 1
  worker_cpu ...................................... -1
  worker_gpu ...................................... -1
  worker_hosts .................................... None
  world_size ...................................... 1
-------------------- end of arguments ---------------------
> initializing torch distributed ...
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
  from cryptography import x509
[2022-09-07 19:59:12,471.471 dsw33701-85768bd75b-s7k52:185257 INFO distributed_c10d.py:195] Added key: store_based_barrier_key:1 to store for rank: 0
Init dist done. World size: 1, rank 0, l_rank 0
> setting random seeds to 1234 ...

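If initialization fails because the port is still occupied (for example, by an earlier run of this notebook), find the process holding the port and kill it:

netstat -tunlp|grep 6000

kill -9 PID (replace PID with the process ID shown in the output of the previous command)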

Loading data

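We load the training and dev data with EasyNLP's built-in MachineReadingComprehensionDataset. The main parameters are:

  • pretrained_model_name_or_path: name or path of the pretrained model. Here we pass the model name "bert-base-uncased" through the helper get_pretrain_model_path, which resolves it to a path and downloads the model automatically.
  • max_seq_length: maximum length of the whole input (question plus context); longer inputs are truncated and shorter ones are padded.
  • input_schema: format of the input .tsv data. Each comma-separated item corresponds to one tab-separated column of a row in the data file and begins with that column's field name, such as question_text, answer_text, or context_text.
  • first_sequence, second_sequence: which input_schema fields serve as the question and the context. The remaining field names in input_schema (qas_id, answer_name, start_position_name) are declared in user_defined_parameters.
  • user_defined_parameters: task-specific parameters declared in advance. For machine reading comprehension, declare the following:
    • pretrain_model_name_or_path: same as above.
    • language: the language, zh for Chinese or en for English. Preprocessing and postprocessing differ between languages; if this is not declared, Chinese processing is used by default.
    • qas_id, answer_name, start_position_name: which input_schema fields serve as the example ID, the answer, and the answer's start position in the context.
    • max_query_length: maximum length of the question; longer questions are truncated.
    • max_answer_length: maximum length of a predicted answer; longer candidate spans are discarded when predictions are produced.
    • doc_stride: a context can be so long that the answer falls in the part cut off beyond max_seq_length. MRC therefore usually uses a sliding window: a long context is split into several overlapping chunks, so one input example becomes several features that are processed separately. doc_stride is the step, in tokens, by which the window moves from one chunk to the next; it is not the window size, which is whatever max_seq_length leaves after the question and special tokens.
    • output_answer_file: at prediction time, besides the best answer, the model writes the top-ranked candidate answers to output_answer_file, so they can be analyzed and compared when improving the model.
    • n_best_size: how many top-ranked candidate answers are written to output_answer_file.
  • is_training: whether the dataset is used for training; True for train_dataset and False for valid_dataset.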
train_dataset = MachineReadingComprehensionDataset(pretrained_model_name_or_path=get_pretrain_model_path("bert-base-uncased"),
                                                   data_file="train_squad.tsv",
                                                   max_seq_length=384,
                                                   input_schema="qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1",
                                                   first_sequence="question_text",
                                                   second_sequence="context_text",
                                                   user_defined_parameters=user_defined_parameters,
                                                   is_training=True
                                                   )
valid_dataset = MachineReadingComprehensionDataset(pretrained_model_name_or_path=get_pretrain_model_path("bert-base-uncased"),
                                                   data_file="dev_squad.tsv",
                                                   max_seq_length=384,
                                                   input_schema="qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1",
                                                   first_sequence="question_text",
                                                   second_sequence="context_text",
                                                   user_defined_parameters=user_defined_parameters,
                                                   is_training=False
                                                   )
`/root/.easynlp/modelzoo/public/bert-base-uncased.tgz` already exists
`/root/.easynlp/modelzoo/public/bert-base-uncased.tgz` already exists

Because we chose bert-base-uncased earlier, the pretrained model is downloaded automatically here if it is not already cached, and then loaded.
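To illustrate doc_stride, the sketch below reproduces the window-splitting loop from Google's original BERT run_squad.py. It is an assumption that EasyNLP's preprocessing follows the same idea; its details may differ. Each window holds at most max_tokens_for_doc context tokens, which is max_seq_length minus the question tokens and the three special tokens [CLS], [SEP], [SEP], and consecutive windows start doc_stride tokens apart. The function name doc_windows is hypothetical.

```python
from typing import List, Tuple

def doc_windows(num_doc_tokens: int, max_seq_length: int,
                num_query_tokens: int, doc_stride: int) -> List[Tuple[int, int]]:
    """Split a context of num_doc_tokens tokens into (start, length) windows."""
    # Room left for the context after the question and [CLS], [SEP], [SEP].
    max_tokens_for_doc = max_seq_length - num_query_tokens - 3
    if max_tokens_for_doc <= 0:
        raise ValueError("question too long for max_seq_length")
    spans = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break                          # this window reaches the end
        start += min(length, doc_stride)   # slide forward by doc_stride tokens
    return spans
```

With this tutorial's settings (max_seq_length=384, doc_stride=128) and a question that uses the full max_query_length of 64 tokens, a 600-token context yields four overlapping windows: (0, 317), (128, 317), (256, 317) and (384, 216). Long contexts thus multiply the number of training features per example.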

Model training

With the data processed and loaded, we can start training. We build the model for training with EasyNLP's get_application_model function, whose parameters are:

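  • app_name: task name; here, machine reading comprehension, "machine_reading_comprehension".
  • pretrained_model_name_or_path: name or path of the pretrained model. As before, get_pretrain_model_path resolves the model name "bert-base-uncased" to a path and downloads the model automatically.
  • user_defined_parameters: user-defined parameters; pass in the user_defined_parameters prepared earlier.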
model = get_application_model(app_name="machine_reading_comprehension",
                              pretrained_model_name_or_path=get_pretrain_model_path("bert-base-uncased"),
                              user_defined_parameters=user_defined_parameters
                              )
`/root/.easynlp/modelzoo/public/bert-base-uncased.tgz` already exists
model ============ from pretrained model ===============
 Loaded weights of the model:
 [bert.embeddings.word_embeddings.weight,bert.embeddings.position_embeddings.weight,bert.embeddings.token_type_embeddings.weight,bert.encoder.layer.0.attention.self.query.weight,bert.encoder.layer.0.attention.self.query.bias,bert.encoder.layer.0.attention.self.key.weight,bert.encoder.layer.0.attention.self.key.bias,bert.encoder.layer.0.attention.self.value.weight,bert.encoder.layer.0.attention.self.value.bias,bert.encoder.layer.0.attention.output.dense.weight,bert.encoder.layer.0.attention.output.dense.bias,bert.encoder.layer.0.intermediate.dense.weight,bert.encoder.layer.0.intermediate.dense.bias,bert.encoder.layer.0.output.dense.weight,bert.encoder.layer.0.output.dense.bias,bert.encoder.layer.1.attention.self.query.weight,bert.encoder.layer.1.attention.self.query.bias,bert.encoder.layer.1.attention.self.key.weight,bert.encoder.layer.1.attention.self.key.bias,bert.encoder.layer.1.attention.self.value.weight,bert.encoder.layer.1.attention.self.value.bias,bert.encoder.layer.1.attention.output.dense.weight,bert.encoder.layer.1.attention.output.dense.bias,bert.encoder.layer.1.intermediate.dense.weight,bert.encoder.layer.1.intermediate.dense.bias,bert.encoder.layer.1.output.dense.weight,bert.encoder.layer.1.output.dense.bias,bert.encoder.layer.2.attention.self.query.weight,bert.encoder.layer.2.attention.self.query.bias,bert.encoder.layer.2.attention.self.key.weight,bert.encoder.layer.2.attention.self.key.bias,bert.encoder.layer.2.attention.self.value.weight,bert.encoder.layer.2.attention.self.value.bias,bert.encoder.layer.2.attention.output.dense.weight,bert.encoder.layer.2.attention.output.dense.bias,bert.encoder.layer.2.intermediate.dense.weight,bert.encoder.layer.2.intermediate.dense.bias,bert.encoder.layer.2.output.dense.weight,bert.encoder.layer.2.output.dense.bias,bert.encoder.layer.3.attention.self.query.weight,bert.encoder.layer.3.attention.self.query.bias,bert.encoder.layer.3.attention.self.key.weight,bert.encoder.layer.3.attention.self.key.bias,bert.encoder.la
yer.3.attention.self.value.weight,bert.encoder.layer.3.attention.self.value.bias,bert.encoder.layer.3.attention.output.dense.weight,bert.encoder.layer.3.attention.output.dense.bias,bert.encoder.layer.3.intermediate.dense.weight,bert.encoder.layer.3.intermediate.dense.bias,bert.encoder.layer.3.output.dense.weight,bert.encoder.layer.3.output.dense.bias,bert.encoder.layer.4.attention.self.query.weight,bert.encoder.layer.4.attention.self.query.bias,bert.encoder.layer.4.attention.self.key.weight,bert.encoder.layer.4.attention.self.key.bias,bert.encoder.layer.4.attention.self.value.weight,bert.encoder.layer.4.attention.self.value.bias,bert.encoder.layer.4.attention.output.dense.weight,bert.encoder.layer.4.attention.output.dense.bias,bert.encoder.layer.4.intermediate.dense.weight,bert.encoder.layer.4.intermediate.dense.bias,bert.encoder.layer.4.output.dense.weight,bert.encoder.layer.4.output.dense.bias,bert.encoder.layer.5.attention.self.query.weight,bert.encoder.layer.5.attention.self.query.bias,bert.encoder.layer.5.attention.self.key.weight,bert.encoder.layer.5.attention.self.key.bias,bert.encoder.layer.5.attention.self.value.weight,bert.encoder.layer.5.attention.self.value.bias,bert.encoder.layer.5.attention.output.dense.weight,bert.encoder.layer.5.attention.output.dense.bias,bert.encoder.layer.5.intermediate.dense.weight,bert.encoder.layer.5.intermediate.dense.bias,bert.encoder.layer.5.output.dense.weight,bert.encoder.layer.5.output.dense.bias,bert.encoder.layer.6.attention.self.query.weight,bert.encoder.layer.6.attention.self.query.bias,bert.encoder.layer.6.attention.self.key.weight,bert.encoder.layer.6.attention.self.key.bias,bert.encoder.layer.6.attention.self.value.weight,bert.encoder.layer.6.attention.self.value.bias,bert.encoder.layer.6.attention.output.dense.weight,bert.encoder.layer.6.attention.output.dense.bias,bert.encoder.layer.6.intermediate.dense.weight,bert.encoder.layer.6.intermediate.dense.bias,bert.encoder.layer.6.output.dense.weight,bert.encoder.layer
.6.output.dense.bias,bert.encoder.layer.7.attention.self.query.weight,bert.encoder.layer.7.attention.self.query.bias,bert.encoder.layer.7.attention.self.key.weight,bert.encoder.layer.7.attention.self.key.bias,bert.encoder.layer.7.attention.self.value.weight,bert.encoder.layer.7.attention.self.value.bias,bert.encoder.layer.7.attention.output.dense.weight,bert.encoder.layer.7.attention.output.dense.bias,bert.encoder.layer.7.intermediate.dense.weight,bert.encoder.layer.7.intermediate.dense.bias,bert.encoder.layer.7.output.dense.weight,bert.encoder.layer.7.output.dense.bias,bert.encoder.layer.8.attention.self.query.weight,bert.encoder.layer.8.attention.self.query.bias,bert.encoder.layer.8.attention.self.key.weight,bert.encoder.layer.8.attention.self.key.bias,bert.encoder.layer.8.attention.self.value.weight,bert.encoder.layer.8.attention.self.value.bias,bert.encoder.layer.8.attention.output.dense.weight,bert.encoder.layer.8.attention.output.dense.bias,bert.encoder.layer.8.intermediate.dense.weight,bert.encoder.layer.8.intermediate.dense.bias,bert.encoder.layer.8.output.dense.weight,bert.encoder.layer.8.output.dense.bias,bert.encoder.layer.9.attention.self.query.weight,bert.encoder.layer.9.attention.self.query.bias,bert.encoder.layer.9.attention.self.key.weight,bert.encoder.layer.9.attention.self.key.bias,bert.encoder.layer.9.attention.self.value.weight,bert.encoder.layer.9.attention.self.value.bias,bert.encoder.layer.9.attention.output.dense.weight,bert.encoder.layer.9.attention.output.dense.bias,bert.encoder.layer.9.intermediate.dense.weight,bert.encoder.layer.9.intermediate.dense.bias,bert.encoder.layer.9.output.dense.weight,bert.encoder.layer.9.output.dense.bias,bert.encoder.layer.10.attention.self.query.weight,bert.encoder.layer.10.attention.self.query.bias,bert.encoder.layer.10.attention.self.key.weight,bert.encoder.layer.10.attention.self.key.bias,bert.encoder.layer.10.attention.self.value.weight,bert.encoder.layer.10.attention.self.value.bias,bert.encoder.layer.10
.attention.output.dense.weight,bert.encoder.layer.10.attention.output.dense.bias,bert.encoder.layer.10.intermediate.dense.weight,bert.encoder.layer.10.intermediate.dense.bias,bert.encoder.layer.10.output.dense.weight,bert.encoder.layer.10.output.dense.bias,bert.encoder.layer.11.attention.self.query.weight,bert.encoder.layer.11.attention.self.query.bias,bert.encoder.layer.11.attention.self.key.weight,bert.encoder.layer.11.attention.self.key.bias,bert.encoder.layer.11.attention.self.value.weight,bert.encoder.layer.11.attention.self.value.bias,bert.encoder.layer.11.attention.output.dense.weight,bert.encoder.layer.11.attention.output.dense.bias,bert.encoder.layer.11.intermediate.dense.weight,bert.encoder.layer.11.intermediate.dense.bias,bert.encoder.layer.11.output.dense.weight,bert.encoder.layer.11.output.dense.bias,bert.pooler.dense.weight,bert.pooler.dense.bias,cls.predictions.bias,cls.predictions.transform.dense.weight,cls.predictions.transform.dense.bias,cls.predictions.decoder.weight,cls.seq_relationship.weight,cls.seq_relationship.bias,bert.embeddings.LayerNorm.weight,bert.embeddings.LayerNorm.bias,bert.encoder.layer.0.attention.output.LayerNorm.weight,bert.encoder.layer.0.attention.output.LayerNorm.bias,bert.encoder.layer.0.output.LayerNorm.weight,bert.encoder.layer.0.output.LayerNorm.bias,bert.encoder.layer.1.attention.output.LayerNorm.weight,bert.encoder.layer.1.attention.output.LayerNorm.bias,bert.encoder.layer.1.output.LayerNorm.weight,bert.encoder.layer.1.output.LayerNorm.bias,bert.encoder.layer.2.attention.output.LayerNorm.weight,bert.encoder.layer.2.attention.output.LayerNorm.bias,bert.encoder.layer.2.output.LayerNorm.weight,bert.encoder.layer.2.output.LayerNorm.bias,bert.encoder.layer.3.attention.output.LayerNorm.weight,bert.encoder.layer.3.attention.output.LayerNorm.bias,bert.encoder.layer.3.output.LayerNorm.weight,bert.encoder.layer.3.output.LayerNorm.bias,bert.encoder.layer.4.attention.output.LayerNorm.weight,bert.encoder.layer.4.attention.output.Laye
rNorm.bias,bert.encoder.layer.4.output.LayerNorm.weight,bert.encoder.layer.4.output.LayerNorm.bias,bert.encoder.layer.5.attention.output.LayerNorm.weight,bert.encoder.layer.5.attention.output.LayerNorm.bias,bert.encoder.layer.5.output.LayerNorm.weight,bert.encoder.layer.5.output.LayerNorm.bias,bert.encoder.layer.6.attention.output.LayerNorm.weight,bert.encoder.layer.6.attention.output.LayerNorm.bias,bert.encoder.layer.6.output.LayerNorm.weight,bert.encoder.layer.6.output.LayerNorm.bias,bert.encoder.layer.7.attention.output.LayerNorm.weight,bert.encoder.layer.7.attention.output.LayerNorm.bias,bert.encoder.layer.7.output.LayerNorm.weight,bert.encoder.layer.7.output.LayerNorm.bias,bert.encoder.layer.8.attention.output.LayerNorm.weight,bert.encoder.layer.8.attention.output.LayerNorm.bias,bert.encoder.layer.8.output.LayerNorm.weight,bert.encoder.layer.8.output.LayerNorm.bias,bert.encoder.layer.9.attention.output.LayerNorm.weight,bert.encoder.layer.9.attention.output.LayerNorm.bias,bert.encoder.layer.9.output.LayerNorm.weight,bert.encoder.layer.9.output.LayerNorm.bias,bert.encoder.layer.10.attention.output.LayerNorm.weight,bert.encoder.layer.10.attention.output.LayerNorm.bias,bert.encoder.layer.10.output.LayerNorm.weight,bert.encoder.layer.10.output.LayerNorm.bias,bert.encoder.layer.11.attention.output.LayerNorm.weight,bert.encoder.layer.11.attention.output.LayerNorm.bias,bert.encoder.layer.11.output.LayerNorm.weight,bert.encoder.layer.11.output.LayerNorm.bias,cls.predictions.transform.LayerNorm.weight,cls.predictions.transform.LayerNorm.bias].
 Unloaded weights of the model:
 [cls.predictions.decoder.weight,cls.seq_relationship.weight,cls.predictions.bias,cls.seq_relationship.bias,cls.predictions.transform.LayerNorm.weight,cls.predictions.transform.LayerNorm.bias,cls.predictions.transform.dense.bias,cls.predictions.transform.dense.weight]. 
 This IS expected if you initialize A model from B.
 This IS NOT expected if you initialize A model from A.
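# Create a Trainer with the training set and an MRC evaluator on the validation set
trainer = Trainer(model=model, 
                  train_dataset=train_dataset,
                  evaluator=get_application_evaluator(app_name="machine_reading_comprehension", 
                                                      valid_dataset=valid_dataset,
                                                      eval_batch_size=32,
                                                      pretrained_model_name_or_path=get_pretrain_model_path("bert-base-uncased"),
                                                      user_defined_parameters=user_defined_parameters
                                                      ))
# Start training; the trained model is saved to the checkpoint_dir specified earlier
trainer.train()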

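The log shows that the pretrained parameters were loaded. Some weights are reported as unloaded (cls.predictions.* and cls.seq_relationship.*). These belong to BERT's pretraining heads, which this task does not use, so the warning is expected. The training instance is then created with EasyNLP's Trainer class, and trainer.train() starts training.

Note the sliding-window mechanism: a long context is split into several overlapping windows, and each window becomes its own input sequence. Because of this, the training batch_size should not be set too large, or memory is easily exhausted.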
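The sketch below makes the sliding-window note concrete. It splits a tokenized context into overlapping windows in the style of the original BERT SQuAD preprocessing, using this tutorial's values: sequence length 384, max_query_length=64 and doc_stride=128. It is a simplified illustration, not EasyNLP's implementation. The function name doc_windows is hypothetical, and it works on token counts rather than real tokens.

```python
def doc_windows(num_context_tokens, num_query_tokens,
                max_seq_length=384, max_query_length=64, doc_stride=128):
    """Return (start, length) spans of context tokens, one per window.

    Simplified sketch of SQuAD-style sliding windows (hypothetical helper).
    """
    # The question is truncated to at most max_query_length tokens
    query_len = min(num_query_tokens, max_query_length)
    # Reserve room for [CLS] question [SEP] ... [SEP]
    max_context_len = max_seq_length - query_len - 3
    spans = []
    start = 0
    while start < num_context_tokens:
        length = min(max_context_len, num_context_tokens - start)
        spans.append((start, length))
        if start + length == num_context_tokens:
            break
        # Advance by at most doc_stride tokens so consecutive windows overlap
        start += min(length, doc_stride)
    return spans


# An 800-token context with a 20-token question becomes several windows
windows = doc_windows(num_context_tokens=800, num_query_tokens=20)
print(len(windows), windows)
# 5 [(0, 361), (128, 361), (256, 361), (384, 361), (512, 288)]
```

In this example, one question produces five 384-token input sequences instead of one. The memory used by a training batch therefore grows with the number of windows per example.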

Model Evaluation

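After training finishes, the trained model is saved in the checkpoint_dir specified at the beginning, whose local path is "./squad_model_dir/". We can now evaluate it. As before, we first build the evaluation model with EasyNLP's get_application_model_for_evaluation method.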

model = get_application_model_for_evaluation(app_name="machine_reading_comprehension",
                                             pretrained_model_name_or_path="./squad_model_dir/", 
                                             user_defined_parameters=user_defined_parameters
                                             )
model ============ from config ===============
[2022-09-07 20:02:29,052 INFO] Loading model...
[2022-09-07 20:02:29,097 INFO] Load finished!
 Inited keys of the model:
 [_model.encoder.layer.1.attention.output.dense.weight,_model.encoder.layer.6.attention.output.LayerNorm.bias,_model.encoder.layer.6.attention.self.key.weight,_model.encoder.layer.4.intermediate.dense.bias,_model.encoder.layer.10.intermediate.dense.weight,_model.encoder.layer.6.attention.output.LayerNorm.weight,_model.encoder.layer.11.attention.self.key.bias,_model.encoder.layer.4.attention.self.key.bias,_model.encoder.layer.8.attention.self.key.bias,_model.encoder.layer.5.attention.self.key.bias,_model.encoder.layer.10.attention.self.key.weight,_model.encoder.layer.5.intermediate.dense.bias,_model.encoder.layer.4.attention.self.value.weight,_model.encoder.layer.7.attention.output.dense.weight,_model.encoder.layer.5.attention.self.query.bias,_model.encoder.layer.0.attention.self.value.bias,_model.encoder.layer.1.attention.output.LayerNorm.bias,_model.encoder.layer.1.attention.self.query.weight,_model.encoder.layer.3.output.LayerNorm.weight,_model.encoder.layer.11.attention.self.query.bias,_model.pooler.dense.weight,_model.encoder.layer.7.attention.self.key.bias,_model.encoder.layer.9.attention.self.value.weight,_model.encoder.layer.0.intermediate.dense.bias,classifier.weight,_model.encoder.layer.4.attention.output.LayerNorm.weight,_model.encoder.layer.1.output.dense.weight,_model.encoder.layer.4.attention.self.query.bias,_model.encoder.layer.2.intermediate.dense.weight,_model.encoder.layer.8.attention.output.LayerNorm.weight,_model.encoder.layer.5.output.LayerNorm.weight,_model.encoder.layer.2.attention.self.key.weight,_model.encoder.layer.2.intermediate.dense.bias,_model.encoder.layer.10.output.dense.bias,_model.encoder.layer.6.output.LayerNorm.bias,_model.encoder.layer.3.attention.self.value.bias,_model.encoder.layer.7.output.dense.bias,_model.encoder.layer.5.output.dense.weight,_model.encoder.layer.10.attention.self.key.bias,_model.encoder.layer.10.attention.self.value.weight,_model.encoder.layer.1.attention.output.dense.bias,_model.encoder.layer.6.attention.outp
ut.dense.bias,_model.encoder.layer.9.intermediate.dense.bias,_model.encoder.layer.3.output.LayerNorm.bias,_model.encoder.layer.4.attention.self.value.bias,_model.encoder.layer.2.output.dense.weight,_model.encoder.layer.8.attention.self.key.weight,_model.encoder.layer.6.output.LayerNorm.weight,_model.encoder.layer.3.output.dense.weight,_model.encoder.layer.9.attention.output.dense.bias,_model.encoder.layer.9.intermediate.dense.weight,_model.encoder.layer.0.attention.output.dense.weight,_model.encoder.layer.11.attention.self.query.weight,_model.encoder.layer.1.attention.self.key.bias,_model.encoder.layer.10.attention.output.dense.weight,_model.encoder.layer.5.output.LayerNorm.bias,_model.encoder.layer.7.attention.output.LayerNorm.bias,_model.encoder.layer.10.attention.self.value.bias,_model.encoder.layer.2.output.dense.bias,_model.encoder.layer.11.intermediate.dense.bias,_model.encoder.layer.7.attention.self.key.weight,_model.encoder.layer.6.attention.self.query.bias,_model.encoder.layer.5.attention.self.value.weight,_model.encoder.layer.1.attention.output.LayerNorm.weight,_model.encoder.layer.5.attention.self.query.weight,_model.encoder.layer.3.attention.self.key.bias,_model.encoder.layer.3.attention.output.dense.weight,_model.encoder.layer.10.attention.output.dense.bias,_model.encoder.layer.2.attention.output.LayerNorm.bias,_model.encoder.layer.7.attention.self.value.bias,_model.encoder.layer.3.attention.self.value.weight,_model.encoder.layer.11.output.dense.weight,_model.encoder.layer.3.attention.output.dense.bias,_model.encoder.layer.8.attention.output.dense.weight,_model.encoder.layer.5.attention.output.LayerNorm.bias,_model.encoder.layer.3.intermediate.dense.weight,_model.encoder.layer.4.attention.output.dense.weight,_model.encoder.layer.6.intermediate.dense.weight,_model.encoder.layer.3.attention.self.query.bias,_model.embeddings.LayerNorm.weight,_model.encoder.layer.8.intermediate.dense.weight,_model.encoder.layer.9.attention.output.LayerNorm.bias,_model.encod
er.layer.11.output.LayerNorm.weight,_model.encoder.layer.4.attention.output.dense.bias,_model.encoder.layer.4.output.dense.bias,_model.encoder.layer.8.output.LayerNorm.weight,_model.encoder.layer.2.attention.self.query.weight,classifier.bias,_model.encoder.layer.7.intermediate.dense.weight,_model.encoder.layer.3.output.dense.bias,_model.encoder.layer.10.output.LayerNorm.bias,_model.encoder.layer.6.attention.output.dense.weight,_model.encoder.layer.8.attention.self.query.bias,_model.encoder.layer.6.attention.self.value.weight,_model.encoder.layer.11.attention.output.dense.weight,_model.encoder.layer.1.output.LayerNorm.bias,_model.encoder.layer.2.output.LayerNorm.bias,_model.encoder.layer.9.attention.output.dense.weight,_model.encoder.layer.2.attention.output.LayerNorm.weight,_model.encoder.layer.7.attention.self.query.bias,_model.encoder.layer.1.intermediate.dense.weight,_model.encoder.layer.9.output.LayerNorm.bias,_model.encoder.layer.8.attention.self.value.weight,_model.encoder.layer.9.output.LayerNorm.weight,_model.encoder.layer.3.intermediate.dense.bias,_model.encoder.layer.2.attention.self.value.weight,_model.encoder.layer.11.attention.self.value.weight,_model.encoder.layer.5.attention.output.LayerNorm.weight,_model.encoder.layer.11.attention.output.dense.bias,_model.encoder.layer.5.output.dense.bias,_model.encoder.layer.8.intermediate.dense.bias,_model.encoder.layer.10.attention.self.query.weight,_model.encoder.layer.10.attention.output.LayerNorm.bias,_model.encoder.layer.11.attention.output.LayerNorm.weight,_model.encoder.layer.7.attention.output.LayerNorm.weight,_model.encoder.layer.11.attention.self.value.bias,_model.embeddings.position_embeddings.weight,_model.encoder.layer.11.attention.output.LayerNorm.bias,_model.encoder.layer.11.output.dense.bias,_model.encoder.layer.1.attention.self.value.bias,_model.encoder.layer.5.attention.output.dense.weight,_model.encoder.layer.5.attention.self.key.weight,_model.encoder.layer.3.attention.output.LayerNorm.bias,_mode
l.encoder.layer.0.output.dense.weight,_model.encoder.layer.0.attention.self.query.bias,_model.encoder.layer.8.output.dense.bias,_model.encoder.layer.10.intermediate.dense.bias,_model.encoder.layer.6.attention.self.value.bias,_model.encoder.layer.8.output.LayerNorm.bias,_model.pooler.dense.bias,_model.encoder.layer.2.attention.output.dense.bias,_model.encoder.layer.7.attention.self.value.weight,_model.encoder.layer.8.attention.self.value.bias,_model.encoder.layer.1.output.LayerNorm.weight,_model.encoder.layer.4.attention.output.LayerNorm.bias,_model.encoder.layer.11.output.LayerNorm.bias,_model.encoder.layer.0.attention.self.key.weight,_model.encoder.layer.0.attention.output.dense.bias,_model.encoder.layer.0.attention.output.LayerNorm.weight,_model.encoder.layer.0.attention.output.LayerNorm.bias,_model.encoder.layer.1.attention.self.key.weight,_model.encoder.layer.2.attention.self.query.bias,_model.encoder.layer.7.attention.self.query.weight,_model.encoder.layer.2.attention.self.key.bias,_model.encoder.layer.4.attention.self.query.weight,_model.encoder.layer.8.output.dense.weight,_model.encoder.layer.9.output.dense.weight,_model.encoder.layer.2.attention.output.dense.weight,_model.encoder.layer.7.output.dense.weight,_model.encoder.layer.10.output.LayerNorm.weight,_model.encoder.layer.7.output.LayerNorm.bias,_model.encoder.layer.11.attention.self.key.weight,_model.encoder.layer.4.attention.self.key.weight,_model.encoder.layer.0.attention.self.value.weight,_model.encoder.layer.9.attention.self.value.bias,_model.encoder.layer.6.output.dense.weight,_model.encoder.layer.9.output.dense.bias,_model.encoder.layer.0.intermediate.dense.weight,_model.encoder.layer.10.output.dense.weight,_model.encoder.layer.0.attention.self.key.bias,_model.encoder.layer.8.attention.self.query.weight,_model.encoder.layer.2.output.LayerNorm.weight,_model.encoder.layer.6.attention.self.query.weight,_model.encoder.layer.8.attention.output.LayerNorm.bias,_model.encoder.layer.3.attention.output.Layer
Norm.weight,_model.embeddings.position_ids,_model.embeddings.word_embeddings.weight,_model.encoder.layer.4.output.LayerNorm.weight,_model.encoder.layer.6.output.dense.bias,_model.encoder.layer.1.intermediate.dense.bias,_model.encoder.layer.9.attention.self.key.bias,_model.encoder.layer.0.output.LayerNorm.bias,_model.encoder.layer.3.attention.self.key.weight,_model.encoder.layer.2.attention.self.value.bias,_model.encoder.layer.5.intermediate.dense.weight,_model.encoder.layer.1.output.dense.bias,_model.embeddings.LayerNorm.bias,_model.encoder.layer.0.attention.self.query.weight,_model.encoder.layer.1.attention.self.value.weight,_model.encoder.layer.10.attention.output.LayerNorm.weight,_model.encoder.layer.3.attention.self.query.weight,_model.encoder.layer.7.intermediate.dense.bias,_model.encoder.layer.4.output.dense.weight,_model.encoder.layer.6.intermediate.dense.bias,_model.encoder.layer.9.attention.output.LayerNorm.weight,_model.encoder.layer.8.attention.output.dense.bias,_model.encoder.layer.4.intermediate.dense.weight,_model.encoder.layer.9.attention.self.query.weight,_model.encoder.layer.9.attention.self.key.weight,_model.encoder.layer.11.intermediate.dense.weight,_model.encoder.layer.10.attention.self.query.bias,_model.encoder.layer.4.output.LayerNorm.bias,_model.encoder.layer.7.output.LayerNorm.weight,_model.embeddings.token_type_embeddings.weight,_model.encoder.layer.9.attention.self.query.bias,_model.encoder.layer.0.output.dense.bias,_model.encoder.layer.1.attention.self.query.bias,_model.encoder.layer.5.attention.self.value.bias,_model.encoder.layer.5.attention.output.dense.bias,_model.encoder.layer.6.attention.self.key.bias,_model.encoder.layer.0.output.LayerNorm.weight,_model.encoder.layer.7.attention.output.dense.bias].
All keys are initialized.

Next, we initialize an evaluator with EasyNLP's get_application_evaluator, place the model on the current device, and run evaluation.

# Build an MRC evaluator on the validation set
evaluator = get_application_evaluator(app_name="machine_reading_comprehension", 
                                      valid_dataset=valid_dataset,
                                      eval_batch_size=32, 
                                      user_defined_parameters=user_defined_parameters
                                      )
# Move the model to the current CUDA device, then evaluate it
model.to(torch.cuda.current_device())
evaluator.evaluate(model=model)
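The evaluation output is not reproduced here. Extractive reading comprehension on SQuAD is conventionally scored with exact match (EM) and token-level F1. The sketch below follows the answer normalization of the official SQuAD v1.1 evaluation script: lowercase, strip punctuation, drop the articles a/an/the, and collapse whitespace. It is illustrative only. Whether EasyNLP's evaluator reports exactly these metrics is an assumption that is not shown above.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection counts tokens shared by prediction and answer
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("Santa Clara, California.", "Santa Clara, California"))  # True
print(round(f1_score("Carolina Panthers", "the Panthers"), 3))  # 0.667
```

EM is strict. F1 gives partial credit when the predicted span overlaps the reference answer.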

Model Prediction

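We can also use the trained model for machine reading comprehension prediction. We first create a predictor and use it to instantiate a PredictorManager. Predictions are written to "dev.pred.csv" in the format "unique_id, best_answer, query, context". In addition, the top-n candidate answers for each question are written to "dev.ans.csv", where they can be analyzed and compared when optimizing the model.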

# Build an MRC predictor from the trained checkpoint
predictor = get_application_predictor(app_name="machine_reading_comprehension",
                                      model_dir="./squad_model_dir/", 
                                      first_sequence="question_text",
                                      second_sequence="context_text",
                                      max_seq_length=384,
                                      output_file="dev.pred.csv",
                                      user_defined_parameters=user_defined_parameters
                                      )
# Read dev_squad.tsv according to input_schema and write predictions to dev.pred.csv
predictor_manager = PredictorManager(predictor=predictor,
                                     input_file="dev_squad.tsv",
                                     input_schema="qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1",
                                     output_file="dev.pred.csv",
                                     output_schema="unique_id,best_answer,query,context",
                                     append_cols=args.append_cols,
                                     batch_size=256
                                     )
predictor_manager.run()
# exit() ends the current Python process; in a notebook this shuts down the kernel
exit()
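The predictor's best_answer and its top-n list come from scoring candidate answer spans. The original BERT SQuAD code uses a common scheme:

- score each span by start logit + end logit;
- keep only the n_best_size highest start and end positions;
- discard spans that are reversed or longer than max_answer_length.

Both parameter names appear in the predict command's user_defined_parameters. The sketch below is a simplified, hypothetical version of that step working on one window's logits. It omits masking out question and special tokens, merging results across windows, and mapping tokens back to the original text. EasyNLP's own post-processing may differ.

```python
import heapq


def best_spans(start_logits, end_logits, n_best_size=10, max_answer_length=30):
    """Rank candidate (start, end) spans by start_logit + end_logit."""
    # Keep only the n_best_size highest-scoring start and end positions
    top_starts = heapq.nlargest(n_best_size, range(len(start_logits)),
                                key=lambda i: start_logits[i])
    top_ends = heapq.nlargest(n_best_size, range(len(end_logits)),
                              key=lambda i: end_logits[i])
    candidates = []
    for s in top_starts:
        for e in top_ends:
            # Discard spans that end before they start or are too long
            if e < s or e - s + 1 > max_answer_length:
                continue
            candidates.append((start_logits[s] + end_logits[e], s, e))
    candidates.sort(reverse=True)
    return [(s, e, score) for score, s, e in candidates[:n_best_size]]


start_logits = [0.1, 2.0, 0.3, 0.0]
end_logits = [0.0, 0.5, 3.0, 0.2]
print(best_spans(start_logits, end_logits)[0])  # (1, 2, 5.0)
```

The first entry plays the role of best_answer; the remaining entries correspond to the top-n candidates.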

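The following commands print the first five predictions on the dev set. Each row gives the id, the answer predicted by the trained model (best_answer), the query, and the context.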

print('Predicted results:  id  best_answer  query  context')
! head -n 5 dev.pred.csv
Predicted results:  id  best_answer  query  context
56be4db0acb8001400a502ec  Denver Broncos  Which NFL team represented the AFC at Super Bowl 50?  Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
56be4db0acb8001400a502ed  Carolina Panthers Which NFL team represented the NFC at Super Bowl 50?  Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
56be4db0acb8001400a502ee  Santa Clara, California.  Where did Super Bowl 50 take place? Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
56be4db0acb8001400a502ef  Denver Broncos  Which NFL team won Super Bowl 50? Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
56be4db0acb8001400a502f0  gold-themed What color was used to emphasize the 50th anniversary of the Super Bowl?  Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
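Each output row contains, from left to right: id, best_answer, query, context. In these five examples the model extracts the correct answer span for every question, and the answers are short and precise.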
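To inspect the predictions programmatically instead of with head, a small reader can load dev.pred.csv into dictionaries keyed by the output_schema columns. The sketch assumes one prediction per line with tab-separated fields in output_schema order. This delimiter is an assumption that the printed output above does not confirm. load_predictions is a hypothetical helper, not part of EasyNLP.

```python
import csv
import os

COLUMNS = ("unique_id", "best_answer", "query", "context")


def load_predictions(path, columns=COLUMNS):
    """Load one prediction per line, assuming tab-separated fields in output_schema order."""
    with open(path, newline="", encoding="utf-8") as f:
        # QUOTE_NONE keeps quotation marks inside contexts as ordinary characters
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        return [dict(zip(columns, row)) for row in reader if row]


if os.path.exists("dev.pred.csv"):
    for pred in load_predictions("dev.pred.csv")[:5]:
        print(pred["unique_id"], "->", pred["best_answer"])
```

Quote handling is disabled because SQuAD contexts such as the Super Bowl passage contain literal double quotes.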

One-Step Execution

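All of the training, evaluation, and prediction code above is already integrated into EasyNLP/examples/appzoo_tutorials/machine_reading_comprehension/main.py. Several ready-to-run scripts are also provided. You can complete these steps in one go either by running main.py with arguments or by executing a bash script from the command line.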

Running main.py in One Step

Run main.py with the arguments below to train, evaluate, or predict directly.

The training command is as follows. The parameters are described in the "Loading Data" section. In this example, the pretrained model is bert-base-uncased.

! python main.py \
    --mode train \
    --app_name=machine_reading_comprehension \
    --worker_gpu=1 \
    --tables=train_squad.tsv,dev_squad.tsv \
    --input_schema=qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1 \
    --first_sequence=question_text \
    --second_sequence=context_text \
    --sequence_length=384 \
    --checkpoint_dir=./squad_model_dir \
    --learning_rate=3.5e-5 \
    --epoch_num=5 \
    --random_seed=42 \
    --save_checkpoint_steps=2000 \
    --train_batch_size=32 \
    --user_defined_parameters='
        pretrain_model_name_or_path=bert-base-uncased
        language=en
        answer_name=answer_text
        qas_id=qas_id
        start_position_name=start_position_character
        doc_stride=128
        max_query_length=64
    '
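The --user_defined_parameters value is a block of key=value pairs separated by whitespace. The hypothetical helper below shows one way to turn such a block into a dictionary. It is not EasyNLP's parser. Values stay strings here, whereas EasyNLP may convert some of them (for example doc_stride) to integers.

```python
def parse_kv_block(text):
    """Parse whitespace-separated key=value pairs into a dict of strings."""
    params = {}
    for item in text.split():
        key, sep, value = item.partition("=")
        if not sep or not key:
            raise ValueError(f"expected key=value, got {item!r}")
        params[key] = value
    return params


# Same block as in the training command above
udp_text = """
    pretrain_model_name_or_path=bert-base-uncased
    language=en
    answer_name=answer_text
    qas_id=qas_id
    start_position_name=start_position_character
    doc_stride=128
    max_query_length=64
"""
params = parse_kv_block(udp_text)
print(params["doc_stride"], params["language"])  # 128 en
```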

The evaluation command is shown below. The parameters have the same meaning as in training.

! python main.py \
    --mode evaluate \
    --app_name=machine_reading_comprehension \
    --worker_gpu=1 \
    --tables=dev_squad.tsv \
    --input_schema=qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1 \
    --first_sequence=question_text \
    --second_sequence=context_text \
    --sequence_length=384 \
    --checkpoint_dir=./squad_model_dir \
    --micro_batch_size=32 \
    --user_defined_parameters='
        pretrain_model_name_or_path=bert-base-uncased
        language=en
        qas_id=qas_id
        answer_name=answer_text
        start_position_name=start_position_character
        doc_stride=128
        max_query_length=64
    '
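All three commands pass the same --input_schema. The sketch below reads it as comma-separated name:type:count entries that map, in order, onto the tab-separated columns of each line in train_squad.tsv and dev_squad.tsv, and uses that to decode a row. This reading of the schema and the helpers parse_schema and row_to_record are assumptions for illustration, and the sample row is made up. start_position_character is treated as the character offset of the answer within the context, as in SQuAD.

```python
SCHEMA = ("qas_id:str:1,context_text:str:1,question_text:str:1,"
          "answer_text:str:1,start_position_character:str:1,title:str:1")


def parse_schema(schema):
    """Split 'name:type:count,...' into (name, type, count) tuples."""
    fields = []
    for entry in schema.split(","):
        name, dtype, count = entry.split(":")
        fields.append((name, dtype, int(count)))
    return fields


def row_to_record(line, schema=SCHEMA):
    """Map one tab-separated line onto the field names, in schema order."""
    names = [name for name, _, _ in parse_schema(schema)]
    values = line.rstrip("\n").split("\t")
    if len(values) != len(names):
        raise ValueError(f"expected {len(names)} columns, got {len(values)}")
    return dict(zip(names, values))


# Made-up sample row in the assumed dev_squad.tsv column order
line = ("q1\tSuper Bowl 50 was won by the Denver Broncos.\t"
        "Which NFL team won Super Bowl 50?\tDenver Broncos\t29\tSuper_Bowl_50\n")
record = row_to_record(line)
start = int(record["start_position_character"])
# The answer text should begin at the character offset given in the row
print(record["context_text"][start:start + len(record["answer_text"])])  # Denver Broncos
```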

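The prediction command is shown below. The parameters are again the same as above. The predictions can be inspected in dev.pred.csv, and the top-n candidate answers are written to dev.ans.csv.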

! python main.py \
    --mode predict \
    --app_name=machine_reading_comprehension \
    --worker_gpu=1 \
    --tables=dev_squad.tsv \
    --outputs=dev.pred.csv \
    --input_schema=qas_id:str:1,context_text:str:1,question_text:str:1,answer_text:str:1,start_position_character:str:1,title:str:1 \
    --output_schema=unique_id,best_answer,query,context \
    --first_sequence=question_text \
    --second_sequence=context_text \
    --sequence_length=384 \
    --checkpoint_dir=./squad_model_dir \
    --micro_batch_size=256 \
    --user_defined_parameters='
        pretrain_model_name_or_path=bert-base-uncased
        language=en
        qas_id=qas_id
        answer_name=answer_text
        start_position_name=start_position_character
        max_query_length=64
        max_answer_length=30
        doc_stride=128
        n_best_size=10
        output_answer_file=dev.ans.csv
    '

Running via Bash Scripts

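Several ready-to-run bash scripts are provided in the EasyNLP/examples/appzoo_tutorials/machine_reading_comprehension/ folder, so training, evaluation, and prediction can each be completed with a single bash command. The example below uses run_train_eval_predict_user_defined_local_en.sh. The script takes two arguments:

- the ID of the GPU to run on, usually 0;
- the mode: train, evaluate, or predict.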

Model training:

! bash run_train_eval_predict_user_defined_local_en.sh 0 train

Model evaluation:

! bash run_train_eval_predict_user_defined_local_en.sh 0 evaluate

Model prediction:

! bash run_train_eval_predict_user_defined_local_en.sh 0 predict
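In the main.py commands above, evaluation and prediction read the checkpoint that training writes to ./squad_model_dir. The three modes therefore have to run in order, and the bash script is assumed to behave the same way. The sketch below chains them from a notebook cell with Python's subprocess module. It assumes the current directory contains the script and that GPU 0 is available. stage_commands and run_pipeline are hypothetical helpers, and with dry_run=True they only print the commands.

```python
import subprocess

SCRIPT = "run_train_eval_predict_user_defined_local_en.sh"


def stage_commands(gpu_id="0", stages=("train", "evaluate", "predict")):
    """Build one bash command per stage, in the order they must run."""
    return [["bash", SCRIPT, str(gpu_id), stage] for stage in stages]


def run_pipeline(gpu_id="0", dry_run=True):
    for cmd in stage_commands(gpu_id):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the pipeline if a stage fails, e.g. no checkpoint to evaluate
            subprocess.run(cmd, check=True)


run_pipeline(dry_run=True)
```

Set dry_run=False to run the stages for real.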