[DSW Gallery] Sequence Labeling (Named Entity Recognition) with EasyNLP

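Summary: EasyNLP provides training and prediction for a wide range of models. It aims to help NLP developers build models quickly and conveniently and put them into production. Using sequence labeling (named entity recognition) as an example, this article shows how to use EasyNLP in PAI-DSW.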

Run it directly

Open "Sequence Labeling (Named Entity Recognition) with EasyNLP" and click "Open in DSW" in the upper-right corner.



Chinese Sequence Labeling with RoBERTa-zh

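Sequence labeling is a fundamental NLP task with a wide range of applications, such as word segmentation, part-of-speech tagging, named entity recognition (NER), keyword extraction and semantic role labeling (SRL). Its goal is to assign a tag to each text unit in a given sequence, which yields additional information about those units. RoBERTa (Robustly Optimized BERT Pretraining Approach) is a stronger pretrained language model that Facebook (now Meta) AI Research built on top of BERT by adding more training data and improving the pretraining strategy. It reported significantly better results than BERT on natural language understanding tasks.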

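EasyNLP provides a Chinese version of RoBERTa so that users can benefit from its strong modeling capability. Taking NER as the example task, this article uses RoBERTa-zh as the backbone of an NER model and shows how to build, train, evaluate and run predictions with EasyNLP.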

Runtime requirements

PAI-PyTorch 1.7/1.8 image, P100 or V100 GPU instance, 16 GB/32 GB of memory

Installing EasyNLP

We recommend installing EasyNLP from its source code on GitHub:

! git clone https://github.com/alibaba/EasyNLP.git
! pip install -r EasyNLP/requirements.txt
! cd EasyNLP && python setup.py install  # each "!" line runs in its own shell, so cd must share the line

You can verify the installation with the following command:

! which easynlp

If the command prints the path of the easynlp CLI tool, the EasyNLP library has been installed.
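If you prefer to check from Python (for example, inside the notebook), a minimal sketch using the standard library's importlib is shown below. It only checks that a module named easynlp can be found on the import path; the module name is taken from the `from easynlp...` imports used later in this tutorial.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found on the import path (without importing it)."""
    return importlib.util.find_spec(module_name) is not None


print("easynlp installed:", is_installed("easynlp"))
```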

Data preparation

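First, download the training set (train.csv) and development set (dev.csv) used in this example. The folder for saving the model, ./seq_labeling/, is set later through args.checkpoint_dir.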

! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling/train.csv
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling/dev.csv
--2022-07-14 15:14:38--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling/train.csv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2565663 (2.4M) [text/csv]
Saving to: ‘train.csv.1’
train.csv.1         100%[===================>]   2.45M  9.30MB/s    in 0.3s    
2022-07-14 15:14:39 (9.30 MB/s) - ‘train.csv.1’ saved [2565663/2565663]
--2022-07-14 15:14:39--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling/dev.csv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1109771 (1.1M) [text/csv]
Saving to: ‘dev.csv.1’
dev.csv.1           100%[===================>]   1.06M  5.96MB/s    in 0.2s    
2022-07-14 15:14:40 (5.96 MB/s) - ‘dev.csv.1’ saved [1109771/1109771]

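After the download finishes, you can view the first 5 lines of each file with the code below. Each line is one example: the sentence to be tagged, followed by one label per character (the two parts are separated by a tab). The character labels can be combined into complete named-entity labels. For example, in the development set dev.csv, '中 共 中 央' is labeled B-ORG I-ORG I-ORG I-ORG. B-ORG marks the first character of an organization name and I-ORG marks a character inside or at the end of one, so together they indicate that '中共中央' is an organization name.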

print('Training data sample:')
! head -n 5 train.csv
print('Development set data sample:')
! head -n 5 dev.csv
Training data sample:
当 希 望 工 程 救 助 的 百 万 儿 童 成 长 起 来 , 科 教 兴 国 蔚 然 成 风 时 , 今 天 有 收 藏 价 值 的 书 你 没 买 , 明 日 就 叫 你 悔 不 当 初 ! O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
藏 书 本 来 就 是 所 有 传 统 收 藏 门 类 中 的 第 一 大 户 , 只 是 我 们 结 束 温 饱 的 时 间 太 短 而 已 。 O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
因 有 关 日 寇 在 京 掠 夺 文 物 详 情 , 藏 界 较 为 重 视 , 也 是 我 们 收 藏 北 京 史 料 中 的 要 件 之 一 。 O O O B-LOC O O B-LOC O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC O O O O O O O O O
我 们 藏 有 一 册 1 9 4 5 年 6 月 油 印 的 《 北 京 文 物 保 存 保 管 状 态 之 调 查 报 告 》 , 调 查 范 围 涉 及 故 宫 、 历 博 、 古 研 所 、 北 大 清 华 图 书 馆 、 北 图 、 日 伪 资 料 库 等 二 十 几 家 , 言 及 文 物 二 十 万 件 以 上 , 洋 洋 三 万 余 言 , 是 珍 贵 的 北 京 史 料 。 O O O O O O O O O O O O O O O O O B-LOC I-LOC O O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC O B-LOC I-LOC O B-ORG I-ORG I-ORG O B-LOC I-LOC I-LOC I-LOC I-LOC I-LOC I-LOC O B-LOC I-LOC O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC O O O
以 家 乡 的 历 史 文 献 、 特 定 历 史 时 期 书 刊 、 某 一 名 家 或 名 著 的 多 种 出 版 物 为 专 题 , 注 意 精 品 、 非 卖 品 、 纪 念 品 , 集 成 系 列 , 那 收 藏 的 过 程 就 已 经 够 您 玩 味 无 穷 了 。 O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
Development set data sample:
中 共 中 央 致 中 国 致 公 党 十 一 大 的 贺 词 B-ORG I-ORG I-ORG I-ORG O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG O O O
各 位 代 表 、 各 位 同 志 : O O O O O O O O O O
在 中 国 致 公 党 第 十 一 次 全 国 代 表 大 会 隆 重 召 开 之 际 , 中 国 共 产 党 中 央 委 员 会 谨 向 大 会 表 示 热 烈 的 祝 贺 , 向 致 公 党 的 同 志 们 O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O O O O O O O B-ORG I-ORG I-ORG O O O O
致 以 亲 切 的 问 候 ! O O O O O O O O
在 过 去 的 五 年 中 , 致 公 党 在 邓 小 平 理 论 指 引 下 , 遵 循 社 会 主 义 初 级 阶 段 的 基 本 路 线 , 努 力 实 践 致 公 党 十 大 提 出 的 发 挥 参 政 党 职 能 、 加 强 自 身 建 设 的 基 本 任 务 。 O O O O O O O O B-ORG I-ORG I-ORG O B-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O O O O O O O O O O O O O O O O O
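To make the BIO scheme concrete, the sketch below turns one data line back into entity spans. It assumes the sentence and its labels are separated by a tab and that both are space-separated, as in the samples above. parse_line and bio_to_entities are hypothetical helpers written for illustration. They are not part of EasyNLP.

```python
from typing import List, Tuple


def parse_line(line: str) -> Tuple[List[str], List[str]]:
    """Split one data line into characters and their BIO labels (tab between the two parts)."""
    content, labels = line.rstrip("\n").split("\t")
    chars, tags = content.split(" "), labels.split(" ")
    if len(chars) != len(tags):
        raise ValueError(f"{len(chars)} characters but {len(tags)} labels")
    return chars, tags


def bio_to_entities(chars: List[str], tags: List[str]) -> List[Tuple[str, str, int, int]]:
    """Group BIO tags into (entity_text, entity_type, start, end) spans; end is exclusive."""
    entities = []
    start, ent_type = None, None
    for i, tag in enumerate(tags + ["O"]):  # a trailing "O" flushes the last open entity
        if tag.startswith("I-") and ent_type == tag[2:]:
            continue  # still inside the current entity
        if ent_type is not None:  # any other tag closes the open entity
            entities.append(("".join(chars[start:i]), ent_type, start, i))
            start, ent_type = None, None
        if tag.startswith("B-") or tag.startswith("I-"):
            # an I- tag without a matching B- is treated as the start of a new entity
            start, ent_type = i, tag[2:]
    return entities


line = ("中 共 中 央 致 中 国 致 公 党 十 一 大 的 贺 词\t"
        "B-ORG I-ORG I-ORG I-ORG O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG O O O")
chars, tags = parse_line(line)
print(bio_to_entities(chars, tags))
# [('中共中央', 'ORG', 0, 4), ('中国致公党十一大', 'ORG', 5, 13)]
```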

Initialization

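In a Python 3.6 environment, we first import the required libraries from the freshly installed EasyNLP and run some initialization. This tutorial uses chinese-roberta-wwm-ext. EasyNLP integrates a rich library of pretrained models. To try another pretrained model such as bert or albert, change user_defined_parameters accordingly; see the model list for the exact model names.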
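The comment in the next cell says EasyNLP's arguments conflict with Jupyter's. A plausible explanation follows; it is an assumption and is not confirmed by this article. Jupyter starts the kernel with its own command-line arguments, such as `-f <connection file>`, and an argparse-based parser that reads sys.argv rejects options it does not know. The sketch below uses a hypothetical stand-in parser, not EasyNLP's, to show the failure and why resetting sys.argv to ['main.py'] avoids it.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for a library's argument parser; it only knows --mode
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("--mode", default="train")
    return parser


def try_parse(argv: List[str]) -> Optional[argparse.Namespace]:
    """Parse argv[1:] the way a script would; return None if argparse rejects it."""
    try:
        return build_parser().parse_args(argv[1:])
    except SystemExit:  # argparse prints an error to stderr and exits on unknown options
        return None


# Roughly what sys.argv looks like inside a Jupyter kernel (path shortened)
jupyter_argv = ["ipykernel_launcher.py", "-f", "/tmp/kernel-1234.json"]
print(try_parse(jupyter_argv))      # None
print(try_parse(["main.py"]).mode)  # train
```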

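# To avoid a conflict between EasyNLP's args and Jupyter's own arguments, sys.argv must be set manually; otherwise initialization fails.
# If you run this code from the command line or as a .py file, you can skip the sys.argv lines below.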
import sys
sys.argv = ['main.py']
import torch.cuda
from easynlp.appzoo import SequenceLabelingDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator
from easynlp.appzoo import get_application_model_for_evaluation
from easynlp.core import PredictorManager
from easynlp.core import Trainer
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parameters
initialize_easynlp()
args = get_args()
user_defined_parameters = "pretrain_model_name_or_path=hfl/chinese-roberta-wwm-ext"
user_defined_parameters = parse_user_defined_parameters(user_defined_parameters)
args.checkpoint_dir = "./seq_labeling/"
[2022-07-14 14:02:21,641.641 dsw34730-66c85d4cdb-qbz2c:30060 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
Please ignore the following import error if you are using tunnel table io.
No module named '_common_io'
No module named 'easy_predict'
------------------------ arguments ------------------------
  app_name ........................................ text_classify
  append_cols ..................................... None
  buckets ......................................... None
  checkpoint_dir .................................. None
  chief_hosts ..................................... 
  data_threads .................................... 10
  distributed_backend ............................. nccl
  do_lower_case ................................... False
  epoch_num ....................................... 3.0
  export_tf_checkpoint_type ....................... easytransfer
  first_sequence .................................. None
  gradient_accumulation_steps ..................... 1
  input_schema .................................... None
  is_chief ........................................ 
  is_master_node .................................. True
  job_name ........................................ None
  label_enumerate_values .......................... None
  label_name ...................................... None
  learning_rate ................................... 5e-05
  local_rank ...................................... None
  logging_steps ................................... 100
  master_port ..................................... 23456
  max_grad_norm ................................... 1.0
  micro_batch_size ................................ 2
  mode ............................................ train
  modelzoo_base_dir ............................... 
  n_cpu ........................................... 1
  n_gpu ........................................... 1
  odps_config ..................................... None
  optimizer_type .................................. AdamW
  output_schema ................................... 
  outputs ......................................... None
  predict_queue_size .............................. 1024
  predict_slice_size .............................. 4096
  predict_table_read_thread_num ................... 16
  predict_thread_num .............................. 2
  ps_hosts ........................................ 
  random_seed ..................................... 1234
  rank ............................................ 0
  read_odps ....................................... False
  restore_works_dir ............................... ./.easynlp_predict_restore_works_dir
  resume_from_checkpoint .......................... None
  save_all_checkpoints ............................ False
  save_checkpoint_steps ........................... None
  second_sequence ................................. None
  sequence_length ................................. 16
  skip_first_line ................................. False
  tables .......................................... None
  task_count ...................................... 1
  task_index ...................................... 0
  use_amp ......................................... False
  use_torchacc .................................... False
  user_defined_parameters ......................... None
  user_entry_file ................................. None
  user_script ..................................... None
  warmup_proportion ............................... 0.1
  weight_decay .................................... 0.0001
  worker_count .................................... 1
  worker_cpu ...................................... -1
  worker_gpu ...................................... -1
  worker_hosts .................................... None
  world_size ...................................... 1
-------------------- end of arguments ---------------------
> initializing torch distributed ...
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
  from cryptography import x509
[2022-07-14 14:02:23,720.720 dsw34730-66c85d4cdb-qbz2c:30060 INFO distributed_c10d.py:195] Added key: store_based_barrier_key:1 to store for rank: 0
Init dist done. World size: 1, rank 0, l_rank 0
> setting random seeds to 1234 ...

Note: if the code above fails with an "Address already in use" error, run the following commands to find and kill the process occupying the port (6000 by default).

netstat -tunlp | grep 6000

kill -9 PID

(Replace PID with the process ID shown in the output of the netstat command.)
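Instead of reading the netstat output, you can also check from Python whether anything is still accepting connections on the port before re-running initialize_easynlp(). This is a minimal sketch that uses only the standard socket module. Port 6000 comes from the note above. The check tells you only that the port is taken, not which process holds it.

```python
import socket


def port_in_use(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if something is accepting TCP connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(1.0)
        # connect_ex returns 0 when the connection succeeds, i.e. a process holds the port
        return s.connect_ex((host, port)) == 0


print("port 6000 in use:", port_in_use(6000))
```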

Loading data

We load the training and development data with EasyNLP's built-in SequenceLabelingDataset. Its main parameters are:

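  • pretrained_model_name_or_path: name or path of the pretrained model. Here we use the helper function get_pretrain_model_path to resolve the model name "hfl/chinese-roberta-wwm-ext" and download the model automatically
  • max_seq_length: maximum text length; longer texts are truncated and shorter ones are padded
  • input_schema: format of the input data. Each comma-separated entry corresponds to one tab-separated column of each line in the data file and begins with that column's field name, such as label or sent1
  • first_sequence, label_name: which fields of input_schema hold the input sentence and the labels
  • label_enumerate_values: the list of all label types
  • is_training: whether the dataset is used for training; True for train_dataset and False for valid_dataset
  • app_name: the task to run, such as text classification, sequence labeling or text matching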
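The sketch below illustrates how an input_schema such as "content:str:1,label:str:1" lines up with the tab-separated columns of a data line, and how label_enumerate_values can be turned into integer ids. It is a hypothetical illustration, not EasyNLP's implementation. It only uses the field name before the first colon, and numbering labels in the order they are listed is an assumption; EasyNLP's internal mapping may differ.

```python
from typing import Dict, List


def parse_schema(input_schema: str) -> List[str]:
    """Return the field names of a schema such as "content:str:1,label:str:1"."""
    return [item.split(":")[0] for item in input_schema.split(",")]


def read_row(line: str, input_schema: str) -> Dict[str, str]:
    """Map the tab-separated columns of one data line onto the schema's field names."""
    names = parse_schema(input_schema)
    values = line.rstrip("\n").split("\t")
    if len(values) != len(names):
        raise ValueError(f"expected {len(names)} columns, got {len(values)}")
    return dict(zip(names, values))


def build_label2id(label_enumerate_values: str) -> Dict[str, int]:
    """Assign each label an integer id in the order it is listed (assumed ordering)."""
    labels = [v.strip() for v in label_enumerate_values.split(",")]
    return {label: idx for idx, label in enumerate(labels)}


row = read_row("各 位 代 表 、 各 位 同 志 :\tO O O O O O O O O O", "content:str:1,label:str:1")
label2id = build_label2id("B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O")
print(row["content"].split(" "))
print([label2id[t] for t in row["label"].split(" ")])  # [6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
```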

Next, we set some parameters manually for the experiment.

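args.tables = "./train.csv,./dev.csv"  # training file first, development file second
args.input_schema = "content:str:1,label:str:1"  # two tab-separated columns per line
args.first_sequence = "content"  # column holding the space-separated characters
args.label_name = "label"  # column holding the space-separated BIO labels
args.label_enumerate_values = "B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O"  # all 7 label types
args.learning_rate = 3e-5
args.epoch_num = 1
args.save_checkpoint_steps = 50  # evaluate (and save the best model) every 50 steps
args.sequence_length = 128
args.micro_batch_size = 32
args.app_name = "sequence_labeling"
# Read the model name from user_defined_parameters, then resolve it to a local path (downloading if needed)
args.pretrained_model_name_or_path = user_defined_parameters.get('pretrain_model_name_or_path', None)
args.pretrained_model_name_or_path = get_pretrain_model_path(args.pretrained_model_name_or_path)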
train_dataset = SequenceLabelingDataset(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        data_file=args.tables.split(",")[0],
        max_seq_length=args.sequence_length,
        input_schema=args.input_schema,
        first_sequence=args.first_sequence,
        label_name=args.label_name,
        label_enumerate_values=args.label_enumerate_values,
        is_training=True)
valid_dataset = SequenceLabelingDataset(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        data_file=args.tables.split(",")[-1],
        max_seq_length=args.sequence_length,
        input_schema=args.input_schema,
        first_sequence=args.first_sequence,
        label_name=args.label_name,
        label_enumerate_values=args.label_enumerate_values,
        is_training=False)
Trying downloading name_mapping.json
Success
Downloading `hfl/chinese-roberta-wwm-ext` to /root/.easynlp/modelzoo/public/hfl/chinese-roberta-wwm-ext.tgz
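Because args.sequence_length = 128 is passed as max_seq_length, every sentence is truncated or padded to 128 positions, and the character labels must stay aligned with the tokens. The sketch below shows one common way to do this. It is a hypothetical illustration, not EasyNLP's implementation. It assumes BERT-style [CLS] and [SEP] tokens, one token per character (BERT's Chinese tokenizer splits CJK characters individually), and -100 as the label id for positions that the loss ignores.

```python
from typing import Dict, List

IGNORE_ID = -100  # assumed label id for [CLS]/[SEP]/padding (PyTorch's default ignore_index)


def align_labels(tags: List[str], label2id: Dict[str, int], max_seq_length: int) -> List[int]:
    """Turn per-character BIO tags into a fixed-length list of label ids.

    Assumes one token per character, [CLS] in front and [SEP] at the end;
    all non-character positions get IGNORE_ID.
    """
    body = [label2id[t] for t in tags[: max_seq_length - 2]]  # truncate, leaving room for [CLS]/[SEP]
    ids = [IGNORE_ID] + body + [IGNORE_ID]
    return ids + [IGNORE_ID] * (max_seq_length - len(ids))  # pad up to max_seq_length


label2id = {label: i for i, label in enumerate("B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O".split(","))}
short_ids = align_labels(["B-PER", "I-PER", "I-PER", "O"], label2id, max_seq_length=8)
print(short_ids)  # [-100, 2, 5, 5, 6, -100, -100, -100]
long_ids = align_labels(["O"] * 200, label2id, max_seq_length=128)
print(len(long_ids), sum(i != IGNORE_ID for i in long_ids))  # 128 126
```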

Model training

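With the data processed and the pretrained model ready, we can start training. We build the model for training with EasyNLP's get_application_model function, which takes the following parameters: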

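  • app_name: task name; here we use sequence labeling, "sequence_labeling"
  • pretrained_model_name_or_path: name or path of the pretrained model, resolved above by get_pretrain_model_path from the model name "hfl/chinese-roberta-wwm-ext"
  • num_labels: number of label types, taken from the dataset's label_enumerate_values (7 in this example)
  • user_defined_parameters: user-defined parameters; pass in the user_defined_parameters parsed earlier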

Build and load the model

model = get_application_model(app_name=args.app_name,
                              pretrained_model_name_or_path=args.pretrained_model_name_or_path,
                              num_labels=len(valid_dataset.label_enumerate_values),
                              user_defined_parameters=user_defined_parameters)
 Loaded weights of the model:
 [bert.embeddings.word_embeddings.weight,bert.embeddings.position_embeddings.weight,bert.embeddings.token_type_embeddings.weight,bert.embeddings.LayerNorm.weight,bert.embeddings.LayerNorm.bias,bert.encoder.layer.0.attention.self.query.weight,bert.encoder.layer.0.attention.self.query.bias,bert.encoder.layer.0.attention.self.key.weight,bert.encoder.layer.0.attention.self.key.bias,bert.encoder.layer.0.attention.self.value.weight,bert.encoder.layer.0.attention.self.value.bias,bert.encoder.layer.0.attention.output.dense.weight,bert.encoder.layer.0.attention.output.dense.bias,bert.encoder.layer.0.attention.output.LayerNorm.weight,bert.encoder.layer.0.attention.output.LayerNorm.bias,bert.encoder.layer.0.intermediate.dense.weight,bert.encoder.layer.0.intermediate.dense.bias,bert.encoder.layer.0.output.dense.weight,bert.encoder.layer.0.output.dense.bias,bert.encoder.layer.0.output.LayerNorm.weight,bert.encoder.layer.0.output.LayerNorm.bias,bert.encoder.layer.1.attention.self.query.weight,bert.encoder.layer.1.attention.self.query.bias,bert.encoder.layer.1.attention.self.key.weight,bert.encoder.layer.1.attention.self.key.bias,bert.encoder.layer.1.attention.self.value.weight,bert.encoder.layer.1.attention.self.value.bias,bert.encoder.layer.1.attention.output.dense.weight,bert.encoder.layer.1.attention.output.dense.bias,bert.encoder.layer.1.attention.output.LayerNorm.weight,bert.encoder.layer.1.attention.output.LayerNorm.bias,bert.encoder.layer.1.intermediate.dense.weight,bert.encoder.layer.1.intermediate.dense.bias,bert.encoder.layer.1.output.dense.weight,bert.encoder.layer.1.output.dense.bias,bert.encoder.layer.1.output.LayerNorm.weight,bert.encoder.layer.1.output.LayerNorm.bias,bert.encoder.layer.2.attention.self.query.weight,bert.encoder.layer.2.attention.self.query.bias,bert.encoder.layer.2.attention.self.key.weight,bert.encoder.layer.2.attention.self.key.bias,bert.encoder.layer.2.attention.self.value.weight,bert.encoder.layer.2.attention.self.value.bias,bert.encoder.layer.2.attention.output.dense.weight,bert.encoder.layer.2.attention.output.dense.bias,bert.encoder.layer.2.attention.output.LayerNorm.weight,bert.encoder.layer.2.attention.output.LayerNorm.bias,bert.encoder.layer.2.intermediate.dense.weight,bert.encoder.layer.2.intermediate.dense.bias,bert.encoder.layer.2.output.dense.weight,bert.encoder.layer.2.output.dense.bias,bert.encoder.layer.2.output.LayerNorm.weight,bert.encoder.layer.2.output.LayerNorm.bias,bert.encoder.layer.3.attention.self.query.weight,bert.encoder.layer.3.attention.self.query.bias,bert.encoder.layer.3.attention.self.key.weight,bert.encoder.layer.3.attention.self.key.bias,bert.encoder.layer.3.attention.self.value.weight,bert.encoder.layer.3.attention.self.value.bias,bert.encoder.layer.3.attention.output.dense.weight,bert.encoder.layer.3.attention.output.dense.bias,bert.encoder.layer.3.attention.output.LayerNorm.weight,bert.encoder.layer.3.attention.output.LayerNorm.bias,bert.encoder.layer.3.intermediate.dense.weight,bert.encoder.layer.3.intermediate.dense.bias,bert.encoder.layer.3.output.dense.weight,bert.encoder.layer.3.output.dense.bias,bert.encoder.layer.3.output.LayerNorm.weight,bert.encoder.layer.3.output.LayerNorm.bias,bert.encoder.layer.4.attention.self.query.weight,bert.encoder.layer.4.attention.self.query.bias,bert.encoder.layer.4.attention.self.key.weight,bert.encoder.layer.4.attention.self.key.bias,bert.encoder.layer.4.attention.self.value.weight,bert.encoder.layer.4.attention.self.value.bias,bert.encoder.layer.4.attention.output.dense.weight,bert.encoder.layer.4.attention.output.dense.bias,bert.encoder.layer.4.attention.output.LayerNorm.weight,bert.encoder.layer.4.attention.output.LayerNorm.bias,bert.encoder.layer.4.intermediate.dense.weight,bert.encoder.layer.4.intermediate.dense.bias,bert.encoder.layer.4.output.dense.weight,bert.encoder.layer.4.output.dense.bias,bert.encoder.layer.4.output.LayerNorm.weight,bert.encoder.layer.4.output.LayerNorm.bias,bert.encoder.layer.5.attention.self.query.weight,bert.encoder.layer.5.attention.self.query.bias,bert.encoder.layer.5.attention.self.key.weight,bert.encoder.layer.5.attention.self.key.bias,bert.encoder.layer.5.attention.self.value.weight,bert.encoder.layer.5.attention.self.value.bias,bert.encoder.layer.5.attention.output.dense.weight,bert.encoder.layer.5.attention.output.dense.bias,bert.encoder.layer.5.attention.output.LayerNorm.weight,bert.encoder.layer.5.attention.output.LayerNorm.bias,bert.encoder.layer.5.intermediate.dense.weight,bert.encoder.layer.5.intermediate.dense.bias,bert.encoder.layer.5.output.dense.weight,bert.encoder.layer.5.output.dense.bias,bert.encoder.layer.5.output.LayerNorm.weight,bert.encoder.layer.5.output.LayerNorm.bias,bert.encoder.layer.6.attention.self.query.weight,bert.encoder.layer.6.attention.self.query.bias,bert.encoder.layer.6.attention.self.key.weight,bert.encoder.layer.6.attention.self.key.bias,bert.encoder.layer.6.attention.self.value.weight,bert.encoder.layer.6.attention.self.value.bias,bert.encoder.layer.6.attention.output.dense.weight,bert.encoder.layer.6.attention.output.dense.bias,bert.encoder.layer.6.attention.output.LayerNorm.weight,bert.encoder.layer.6.attention.output.LayerNorm.bias,bert.encoder.layer.6.intermediate.dense.weight,bert.encoder.layer.6.intermediate.dense.bias,bert.encoder.layer.6.output.dense.weight,bert.encoder.layer.6.output.dense.bias,bert.encoder.layer.6.output.LayerNorm.weight,bert.encoder.layer.6.output.LayerNorm.bias,bert.encoder.layer.7.attention.self.query.weight,bert.encoder.layer.7.attention.self.query.bias,bert.encoder.layer.7.attention.self.key.weight,bert.encoder.layer.7.attention.self.key.bias,bert.encoder.layer.7.attention.self.value.weight,bert.encoder.layer.7.attention.self.value.bias,bert.encoder.layer.7.attention.output.dense.weight,bert.encoder.layer.7.attention.output.dense.bias,bert.encoder.layer.7.attention.output.LayerNorm.weight,bert.encoder.layer.7.attention.output.LayerNorm.bias,bert.encoder.layer.7.intermediate.dense.weight,bert.encoder.layer.7.intermediate.dense.bias,bert.encoder.layer.7.output.dense.weight,bert.encoder.layer.7.output.dense.bias,bert.encoder.layer.7.output.LayerNorm.weight,bert.encoder.layer.7.output.LayerNorm.bias,bert.encoder.layer.8.attention.self.query.weight,bert.encoder.layer.8.attention.self.query.bias,bert.encoder.layer.8.attention.self.key.weight,bert.encoder.layer.8.attention.self.key.bias,bert.encoder.layer.8.attention.self.value.weight,bert.encoder.layer.8.attention.self.value.bias,bert.encoder.layer.8.attention.output.dense.weight,bert.encoder.layer.8.attention.output.dense.bias,bert.encoder.layer.8.attention.output.LayerNorm.weight,bert.encoder.layer.8.attention.output.LayerNorm.bias,bert.encoder.layer.8.intermediate.dense.weight,bert.encoder.layer.8.intermediate.dense.bias,bert.encoder.layer.8.output.dense.weight,bert.encoder.layer.8.output.dense.bias,bert.encoder.layer.8.output.LayerNorm.weight,bert.encoder.layer.8.output.LayerNorm.bias,bert.encoder.layer.9.attention.self.query.weight,bert.encoder.layer.9.attention.self.query.bias,bert.encoder.layer.9.attention.self.key.weight,bert.encoder.layer.9.attention.self.key.bias,bert.encoder.layer.9.attention.self.value.weight,bert.encoder.layer.9.attention.self.value.bias,bert.encoder.layer.9.attention.output.dense.weight,bert.encoder.layer.9.attention.output.dense.bias,bert.encoder.layer.9.attention.output.LayerNorm.weight,bert.encoder.layer.9.attention.output.LayerNorm.bias,bert.encoder.layer.9.intermediate.dense.weight,bert.encoder.layer.9.intermediate.dense.bias,bert.encoder.layer.9.output.dense.weight,bert.encoder.layer.9.output.dense.bias,bert.encoder.layer.9.output.LayerNorm.weight,bert.encoder.layer.9.output.LayerNorm.bias,bert.encoder.layer.10.attention.self.query.weight,bert.encoder.layer.10.attention.self.query.bias,bert.encoder.layer.10.attention.self.key.weight,bert.encoder.layer.10.attention.self.key.bias,bert.encoder.layer.10.attention.self.value.weight,bert.encoder.layer.10.attention.self.value.bias,bert.encoder.layer.10.attention.output.dense.weight,bert.encoder.layer.10.attention.output.dense.bias,bert.encoder.layer.10.attention.output.LayerNorm.weight,bert.encoder.layer.10.attention.output.LayerNorm.bias,bert.encoder.layer.10.intermediate.dense.weight,bert.encoder.layer.10.intermediate.dense.bias,bert.encoder.layer.10.output.dense.weight,bert.encoder.layer.10.output.dense.bias,bert.encoder.layer.10.output.LayerNorm.weight,bert.encoder.layer.10.output.LayerNorm.bias,bert.encoder.layer.11.attention.self.query.weight,bert.encoder.layer.11.attention.self.query.bias,bert.encoder.layer.11.attention.self.key.weight,bert.encoder.layer.11.attention.self.key.bias,bert.encoder.layer.11.attention.self.value.weight,bert.encoder.layer.11.attention.self.value.bias,bert.encoder.layer.11.attention.output.dense.weight,bert.encoder.layer.11.attention.output.dense.bias,bert.encoder.layer.11.attention.output.LayerNorm.weight,bert.encoder.layer.11.attention.output.LayerNorm.bias,bert.encoder.layer.11.intermediate.dense.weight,bert.encoder.layer.11.intermediate.dense.bias,bert.encoder.layer.11.output.dense.weight,bert.encoder.layer.11.output.dense.bias,bert.encoder.layer.11.output.LayerNorm.weight,bert.encoder.layer.11.output.LayerNorm.bias,bert.pooler.dense.weight,bert.pooler.dense.bias,cls.predictions.bias,cls.predictions.transform.dense.weight,cls.predictions.transform.dense.bias,cls.predictions.transform.LayerNorm.weight,cls.predictions.transform.LayerNorm.bias,cls.predictions.decoder.weight,cls.seq_relationship.weight,cls.seq_relationship.bias].
 Unloaded weights of the model:
 [cls.predictions.transform.dense.weight,cls.seq_relationship.bias,cls.predictions.transform.LayerNorm.bias,cls.predictions.transform.dense.bias,cls.predictions.bias,cls.seq_relationship.weight,cls.predictions.transform.LayerNorm.weight,cls.predictions.decoder.weight]. 
 This IS expected if you initialize A model from B.
 This IS NOT expected if you initialize A model from A.
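As a sanity check on num_labels, the total parameter count can be reproduced by hand. This sketch assumes the model is the BERT-base encoder described by the model config printed when training starts (hidden_size 768, intermediate_size 3072, 12 layers, vocab_size 21128, max_position_embeddings 512, type_vocab_size 2), plus a pooler and a single linear token-classification layer mapping 768 hidden units to the 7 labels. With that assumption, the sum equals the "num model params = 102,273,031" line in the training log. The structure of the head is inferred from this match, not from EasyNLP's source.

```python
# Sizes taken from the model config printed in the training log
vocab_size, hidden, intermediate = 21128, 768, 3072
max_positions, type_vocab, num_layers, num_labels = 512, 2, 12, 7


def linear(n_in: int, n_out: int) -> int:
    """Parameters of a dense layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


layer_norm = 2 * hidden  # LayerNorm weight (gamma) and bias (beta)

embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm
per_layer = (
    4 * linear(hidden, hidden)        # query, key, value and attention.output.dense
    + layer_norm                      # attention.output.LayerNorm
    + linear(hidden, intermediate)    # intermediate.dense
    + linear(intermediate, hidden)    # output.dense
    + layer_norm                      # output.LayerNorm
)
pooler = linear(hidden, hidden)
classifier = linear(hidden, num_labels)  # assumed token-classification head

total = embeddings + num_layers * per_layer + pooler + classifier
print(f"{total:,}")  # 102,273,031
```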

Build the trainer and train
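Each evaluation in the training log reports entity-level scores. "processed ... tokens with N phrases" gives the number of gold entities, "found" the number of predicted entities, and "correct" the number of predicted entities that match a gold entity. The format resembles the CoNLL conlleval script, where a match requires both the same span and the same type; that EasyNLP's evaluator uses the same rule is an assumption. Precision, recall and F1 follow directly from the three counts. The sketch below uses the numbers from the step-50 evaluation (6192 gold phrases, 6991 found, 3378 correct).

```python
from typing import Tuple


def phrase_metrics(gold: int, found: int, correct: int) -> Tuple[float, float, float]:
    """Entity-level precision, recall and F1 (all in %) from phrase counts."""
    precision = 100.0 * correct / found if found else 0.0
    recall = 100.0 * correct / gold if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


p, r, f1 = phrase_metrics(gold=6192, found=6991, correct=3378)
print(f"precision {p:.2f}%, recall {r:.2f}%, F1 {f1:.2f}")
# precision 48.32%, recall 54.55%, F1 51.25
```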

trainer = Trainer(model=model, train_dataset=train_dataset,
                          evaluator=get_application_evaluator(app_name=args.app_name, valid_dataset=valid_dataset,
                                                              eval_batch_size=args.micro_batch_size,
                                                              user_defined_parameters=user_defined_parameters))
trainer.train()
[2022-07-14 14:03:32,720 INFO] ========== Initializing Tensorboard ==========
[2022-07-14 14:03:32,747 INFO] ========== Training Start ==========
[2022-07-14 14:03:32,748 INFO]   Num of GPUs (all)       = 1
[2022-07-14 14:03:32,751 INFO]   Num of CPUs per worker  = 1
[2022-07-14 14:03:32,752 INFO]   Num dataset examples    = 10000
[2022-07-14 14:03:32,752 INFO]   Num training examples   = 10000
[2022-07-14 14:03:32,752 INFO]   Num validation examples = 4631
[2022-07-14 14:03:32,753 INFO]   Train. batch size       = 32
[2022-07-14 14:03:32,753 INFO]   Train. micro batch size = 32
[2022-07-14 14:03:32,754 INFO]   Train. batch no.        = 312
[2022-07-14 14:03:32,757 INFO]   Evaluation batch size   = 32
[2022-07-14 14:03:32,758 INFO]   Total training steps    = 313
[2022-07-14 14:03:32,758 INFO]   Sequence length         = 128
[2022-07-14 14:03:32,759 INFO]   Saving steps            = 50
[2022-07-14 14:03:32,760 INFO]   Distributed_backend     = nccl
[2022-07-14 14:03:32,760 INFO]   Worker Count            = 1
[2022-07-14 14:03:32,761 INFO]   Worker CPU              = -1
[2022-07-14 14:03:32,761 INFO]   Worker data threads     = 10
[2022-07-14 14:03:32,763 INFO]   num model params        = 102,273,031
[2022-07-14 14:03:32,765 INFO]   num trainable params    = 102,273,031
[2022-07-14 14:03:32,766 INFO] 
[2022-07-14 14:03:32,766 INFO] ========== Model Config ==========
[2022-07-14 14:03:32,767 INFO] {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "directionality": "bidi",
  "easynlp_version": "0.0.3",
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
optimizer type: AdamW
/root/.local/lib/python3.6/site-packages/pai_easynlp-0.0.6-py3.6.egg/easynlp/core/optimizers.py:441: UserWarning: This overload of add_ is deprecated:
  add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
  add_(Tensor other, *, Number alpha) (Triggered internally at  /workspace/artifacts/paipytorch1.8/dist/ubuntu18.04-py3.6-cuda10.1/build/src/torch/csrc/utils/python_arg_parser.cpp:1005.)
  exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
/home/pai/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:247: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  warnings.warn("To get the last learning rate computed by the scheduler, "
[2022-07-14 14:03:46,497 INFO] ========== Evaluation at global step 50 ==========
[2022-07-14 14:04:01,182 INFO] Eval: 100/144 steps finished
[2022-07-14 14:04:07,556 INFO] Inference time = 2.10s, [0.4528 ms / sample] 
[2022-07-14 14:04:07,557 INFO] Eval loss: 0.1645107916572356
[2022-07-14 14:04:07,852 INFO] processed 177232 tokens with 6192 phrases; 
[2022-07-14 14:04:07,853 INFO] found: 6991 phrases; correct: 3378.
[2022-07-14 14:04:07,853 INFO] 
accuracy:  65.86%; (non-O)accuracy:  95.66%; precision:  48.32%; recall:  54.55%; FB1:  51.25
[2022-07-14 14:04:07,854 INFO] 
              LOC: precision:  51.02%; recall:  58.85%; FB1:  54.66  3318
[2022-07-14 14:04:07,854 INFO] 
              ORG: precision:  16.01%; recall:  21.86%; FB1:  18.48  1818
[2022-07-14 14:04:07,855 INFO] 
              PER: precision:  75.15%; recall:  70.26%; FB1:  72.62  1855
[2022-07-14 14:04:07,855 INFO] Labeling F1:        51.24781916104073
[2022-07-14 14:04:07,856 INFO] Labeling Precision: 48.31926762980975
[2022-07-14 14:04:07,856 INFO] Labeling Recall:     54.55426356589147
[2022-07-14 14:04:07,857 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:04:15,383 INFO] Best score: 51.24781916104073
[2022-07-14 14:04:15,384 INFO] Learning rate: 0.00002809
[2022-07-14 14:04:27,051 INFO] Epoch [ 0/ 1], step [100/313], lr 0.000023, 54.28 s
[2022-07-14 14:04:27,053 INFO]   loss      : 0.4126 
[2022-07-14 14:04:27,053 INFO] ========== Evaluation at global step 100 ==========
[2022-07-14 14:04:41,792 INFO] Eval: 100/144 steps finished
[2022-07-14 14:04:48,179 INFO] Inference time = 2.07s, [0.4458 ms / sample] 
[2022-07-14 14:04:48,180 INFO] Eval loss: 0.07174456688607561
[2022-07-14 14:04:48,492 INFO] processed 177232 tokens with 6192 phrases; 
[2022-07-14 14:04:48,493 INFO] found: 6405 phrases; correct: 4872.
[2022-07-14 14:04:48,494 INFO] 
accuracy:  83.33%; (non-O)accuracy:  97.70%; precision:  76.07%; recall:  78.68%; FB1:  77.35
[2022-07-14 14:04:48,494 INFO] 
              LOC: precision:  77.14%; recall:  77.89%; FB1:  77.52  2905
[2022-07-14 14:04:48,495 INFO] 
              ORG: precision:  55.13%; recall:  62.96%; FB1:  58.79  1520
[2022-07-14 14:04:48,495 INFO] 
              PER: precision:  90.56%; recall:  90.37%; FB1:  90.46  1980
[2022-07-14 14:04:48,496 INFO] Labeling F1:        77.3517504167659
[2022-07-14 14:04:48,496 INFO] Labeling Precision: 76.0655737704918
[2022-07-14 14:04:48,498 INFO] Labeling Recall:     78.68217054263566
[2022-07-14 14:04:48,499 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:04:59,747 INFO] Best score: 77.3517504167659
[2022-07-14 14:04:59,748 INFO] Learning rate: 0.00002277
[2022-07-14 14:05:11,421 INFO] ========== Evaluation at global step 150 ==========
[2022-07-14 14:05:26,126 INFO] Eval: 100/144 steps finished
[2022-07-14 14:05:32,512 INFO] Inference time = 2.09s, [0.4515 ms / sample] 
[2022-07-14 14:05:32,513 INFO] Eval loss: 0.06119475209892824
[2022-07-14 14:05:32,829 INFO] processed 177232 tokens with 6192 phrases; 
[2022-07-14 14:05:32,830 INFO] found: 6424 phrases; correct: 5190.
[2022-07-14 14:05:32,830 INFO] 
accuracy:  87.22%; (non-O)accuracy:  98.19%; precision:  80.79%; recall:  83.82%; FB1:  82.28
[2022-07-14 14:05:32,831 INFO] 
              LOC: precision:  82.48%; recall:  87.56%; FB1:  84.94  3054
[2022-07-14 14:05:32,831 INFO] 
              ORG: precision:  57.24%; recall:  59.43%; FB1:  58.31  1382
[2022-07-14 14:05:32,831 INFO] 
              PER: precision:  94.57%; recall:  94.76%; FB1:  94.66  1988
[2022-07-14 14:05:32,832 INFO] Labeling F1:        82.27647431832595
[2022-07-14 14:05:32,832 INFO] Labeling Precision: 80.79078455790784
[2022-07-14 14:05:32,833 INFO] Labeling Recall:     83.81782945736434
[2022-07-14 14:05:32,834 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:05:40,305 INFO] Best score: 82.27647431832595
[2022-07-14 14:05:40,306 INFO] Learning rate: 0.00001745
[2022-07-14 14:05:51,957 INFO] Epoch [ 0/ 1], step [200/313], lr 0.000012, 84.90 s
[2022-07-14 14:05:51,959 INFO]   loss      : 0.2268 
[2022-07-14 14:05:51,960 INFO] ========== Evaluation at global step 200 ==========
[2022-07-14 14:06:06,691 INFO] Eval: 100/144 steps finished
[2022-07-14 14:06:13,084 INFO] Inference time = 2.10s, [0.4528 ms / sample] 
[2022-07-14 14:06:13,085 INFO] Eval loss: 0.052420847586773595
[2022-07-14 14:06:13,402 INFO] processed 177232 tokens with 6192 phrases; 
[2022-07-14 14:06:13,403 INFO] found: 6516 phrases; correct: 5320.
[2022-07-14 14:06:13,403 INFO] 
accuracy:  89.25%; (non-O)accuracy:  98.37%; precision:  81.65%; recall:  85.92%; FB1:  83.73
[2022-07-14 14:06:13,404 INFO] 
              LOC: precision:  82.97%; recall:  87.87%; FB1:  85.35  3047
[2022-07-14 14:06:13,404 INFO] 
              ORG: precision:  61.63%; recall:  68.90%; FB1:  65.06  1488
[2022-07-14 14:06:13,404 INFO] 
              PER: precision:  94.65%; recall:  94.51%; FB1:  94.58  1981
[2022-07-14 14:06:13,405 INFO] Labeling F1:        83.72678627636135
[2022-07-14 14:06:13,405 INFO] Labeling Precision: 81.64518109269491
[2022-07-14 14:06:13,406 INFO] Labeling Recall:     85.91731266149871
[2022-07-14 14:06:13,407 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:06:20,924 INFO] Best score: 83.72678627636135
[2022-07-14 14:06:20,925 INFO] Learning rate: 0.00001213
[2022-07-14 14:06:32,569 INFO] ========== Evaluation at global step 250 ==========
[2022-07-14 14:06:47,393 INFO] Eval: 100/144 steps finished
[2022-07-14 14:06:53,945 INFO] Inference time = 2.13s, [0.4597 ms / sample] 
[2022-07-14 14:06:53,946 INFO] Eval loss: 0.05124919564896745
[2022-07-14 14:06:54,288 INFO] processed 177232 tokens with 6192 phrases; 
[2022-07-14 14:06:54,288 INFO] found: 6510 phrases; correct: 5328.
[2022-07-14 14:06:54,289 INFO] 
accuracy:  89.44%; (non-O)accuracy:  98.40%; precision:  81.84%; recall:  86.05%; FB1:  83.89
[2022-07-14 14:06:54,290 INFO] 
              LOC: precision:  82.30%; recall:  87.77%; FB1:  84.95  3068
[2022-07-14 14:06:54,290 INFO] 
              ORG: precision:  63.30%; recall:  69.20%; FB1:  66.12  1455
[2022-07-14 14:06:54,290 INFO] 
              PER: precision:  94.72%; recall:  94.86%; FB1:  94.79  1987
[2022-07-14 14:06:54,291 INFO] Labeling F1:        83.89230042512989
[2022-07-14 14:06:54,291 INFO] Labeling Precision: 81.84331797235022
[2022-07-14 14:06:54,292 INFO] Labeling Recall:     86.04651162790698
[2022-07-14 14:06:54,293 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:07:01,860 INFO] Best score: 83.89230042512989
[2022-07-14 14:07:01,861 INFO] Learning rate: 0.00000681
[2022-07-14 14:07:13,487 INFO] Epoch [ 0/ 1], step [300/313], lr 0.000001, 81.53 s
[2022-07-14 14:07:13,488 INFO]   loss      : 0.1621 
[2022-07-14 14:07:13,489 INFO] ========== Evaluation at global step 300 ==========
[2022-07-14 14:07:28,173 INFO] Eval: 100/144 steps finished
[2022-07-14 14:07:34,564 INFO] Inference time = 2.09s, [0.4499 ms / sample] 
[2022-07-14 14:07:34,565 INFO] Eval loss: 0.05110448669834897
[2022-07-14 14:07:34,882 INFO] processed 177232 tokens with 6192 phrases; 
[2022-07-14 14:07:34,883 INFO] found: 6463 phrases; correct: 5300.
[2022-07-14 14:07:34,884 INFO] 
accuracy:  89.03%; (non-O)accuracy:  98.42%; precision:  82.01%; recall:  85.59%; FB1:  83.76
[2022-07-14 14:07:34,884 INFO] 
              LOC: precision:  83.34%; recall:  86.41%; FB1:  84.85  2983
[2022-07-14 14:07:34,885 INFO] 
              ORG: precision:  62.64%; recall:  69.80%; FB1:  66.03  1483
[2022-07-14 14:07:34,885 INFO] 
              PER: precision:  94.39%; recall:  95.01%; FB1:  94.70  1997
[2022-07-14 14:07:34,885 INFO] Labeling F1:        83.76135914658238
[2022-07-14 14:07:34,886 INFO] Labeling Precision: 82.00526071483831
[2022-07-14 14:07:34,886 INFO] Labeling Recall:     85.59431524547804
[2022-07-14 14:07:34,888 INFO] Best score: 83.89230042512989
[2022-07-14 14:07:34,890 INFO] Learning rate: 0.00000149
Training Time: 245.53464770317078, rank 0, gsteps 313
[2022-07-14 14:07:52,881 INFO] Eval: 100/144 steps finished
[2022-07-14 14:07:59,261 INFO] Inference time = 2.09s, [0.4500 ms / sample] 
[2022-07-14 14:07:59,262 INFO] Eval loss: 0.05115903893421436
[2022-07-14 14:07:59,581 INFO] processed 177232 tokens with 6192 phrases; 
[2022-07-14 14:07:59,582 INFO] found: 6462 phrases; correct: 5298.
[2022-07-14 14:07:59,582 INFO] 
accuracy:  88.96%; (non-O)accuracy:  98.41%; precision:  81.99%; recall:  85.56%; FB1:  83.74
[2022-07-14 14:07:59,583 INFO] 
              LOC: precision:  83.22%; recall:  86.37%; FB1:  84.77  2986
[2022-07-14 14:07:59,583 INFO] 
              ORG: precision:  62.79%; recall:  69.72%; FB1:  66.07  1478
[2022-07-14 14:07:59,584 INFO] 
              PER: precision:  94.34%; recall:  95.01%; FB1:  94.68  1998
[2022-07-14 14:07:59,584 INFO] Labeling F1:        83.73636794689426
[2022-07-14 14:07:59,585 INFO] Labeling Precision: 81.98700092850511
[2022-07-14 14:07:59,585 INFO] Labeling Recall:     85.56201550387597
[2022-07-14 14:07:59,586 INFO] Best score: 83.89230042512989
[2022-07-14 14:07:59,587 INFO] Training Time: 266.98195219039917

Model Evaluation

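After training, the model is saved in the checkpoint_dir specified at the beginning, i.e. the local path "./seq_labeling/". We can then evaluate the trained model: we initialize an evaluator with EasyNLP's get_application_evaluator, move the model to the GPU, and run the evaluation.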

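# Evaluate on the dev set; valid_dataset is the dev-set dataset created before training
args.tables = "dev.csv"
evaluator = get_application_evaluator(app_name=args.app_name, valid_dataset=valid_dataset,
                                      eval_batch_size=args.micro_batch_size,
                                      user_defined_parameters=user_defined_parameters)
# Move the model to the current GPU before evaluating
model.to(torch.cuda.current_device())
evaluator.evaluate(model=model)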
[2022-07-14 14:08:28,310 INFO] Eval: 100/144 steps finished
[2022-07-14 14:08:34,709 INFO] Inference time = 2.12s, [0.4567 ms / sample] 
[2022-07-14 14:08:34,710 INFO] Eval loss: 0.05115903893421436
[2022-07-14 14:08:35,037 INFO] processed 177232 tokens with 6192 phrases; 
[2022-07-14 14:08:35,038 INFO] found: 6462 phrases; correct: 5298.
[2022-07-14 14:08:35,038 INFO] 
accuracy:  88.96%; (non-O)accuracy:  98.41%; precision:  81.99%; recall:  85.56%; FB1:  83.74
[2022-07-14 14:08:35,039 INFO] 
              LOC: precision:  83.22%; recall:  86.37%; FB1:  84.77  2986
[2022-07-14 14:08:35,040 INFO] 
              ORG: precision:  62.79%; recall:  69.72%; FB1:  66.07  1478
[2022-07-14 14:08:35,040 INFO] 
              PER: precision:  94.34%; recall:  95.01%; FB1:  94.68  1998
[2022-07-14 14:08:35,041 INFO] Labeling F1:        83.73636794689426
[2022-07-14 14:08:35,041 INFO] Labeling Precision: 81.98700092850511
[2022-07-14 14:08:35,041 INFO] Labeling Recall:     85.56201550387597
[('labeling_f1', 83.73636794689426),
 ('labeling_precision', 81.98700092850511),
 ('labeling_recall', 85.56201550387597)]
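The labeling metrics above are entity-level (chunk-level) scores: the log format matches that of the CoNLL conlleval script, where a predicted entity counts as correct only if both its type and its span match a gold entity. As a sanity check, the final numbers can be reproduced from the counts in the log line "processed 177232 tokens with 6192 phrases; found: 6462 phrases; correct: 5298." This is a minimal sketch; `prf_from_counts` is a hypothetical helper, not an EasyNLP function.

```python
from typing import Tuple


def prf_from_counts(correct: int, found: int, gold: int) -> Tuple[float, float, float]:
    """Return chunk-level precision, recall and F1, all in percent."""
    precision = 100.0 * correct / found if found else 0.0  # share of predicted entities that are right
    recall = 100.0 * correct / gold if gold else 0.0       # share of gold entities that were found
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Counts from the final evaluation log: 6192 gold phrases,
# 6462 predicted ("found") phrases, 5298 of them correct.
p, r, f1 = prf_from_counts(correct=5298, found=6462, gold=6192)
print(f"precision={p:.2f} recall={r:.2f} F1={f1:.2f}")
# → precision=81.99 recall=85.56 F1=83.74
```

Because precision is correct/found and recall is correct/gold, F1 reduces to 2·correct/(found + gold) = 2·5298/12654.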

Model Prediction

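We can also use the trained model for sequence labeling (here, named entity recognition). We first create a predictor and then use it to instantiate a PredictorManager. The predictions are written to dev.pred.csv.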

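args.tables = "dev.csv"          # input file to label
args.outputs = "dev.pred.csv"    # file the predictions are written to
args.output_schema = "output"    # name of the prediction column
args.append_cols = "label"       # copy the gold label column into the output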
predictor = get_application_predictor(
            app_name=args.app_name, model_dir=args.checkpoint_dir,
            first_sequence=args.first_sequence,
            second_sequence=args.second_sequence,
            sequence_length=args.sequence_length,
            user_defined_parameters=user_defined_parameters)
predictor_manager = PredictorManager(
    predictor=predictor,
    input_file=args.tables.split(",")[-1],
    input_schema=args.input_schema,
    output_file=args.outputs,
    output_schema=args.output_schema,
    append_cols=args.append_cols,
    batch_size=args.micro_batch_size
)
predictor_manager.run()
[2022-07-14 14:08:39,533 INFO] Using PyTorch .bin model to predict...
[2022-07-14 14:08:41,742 INFO] Loading model...
[2022-07-14 14:08:41,768 INFO] Load finished!
All keys are initialized.
[2022-07-14 14:08:41,866 INFO] Using SimplePredict to predict...
145it [00:30,  4.68it/s]
print('Labeled samples:')
! tail -n 5 dev.csv
print('Predicted results:')
! tail -n 5 dev.pred.csv
Labeled samples:
人 们 愿 意 与 他 做 生 意 , 有 时 商 业 事 务 通 过 电 话 即 可 办 理 。 O O O O O O O O O O O O O O O O O O O O O O O O O
经 过 十 几 年 的 努 力 , 他 已 成 为 世 界 最 大 的 私 人 集 装 箱 船 船 主 。 O O O O O O O O O O O O O O O O O O O O O O O O O O O
妻 贤 子 孝 家 庭 幸 福 O O O O O O O O
希 腊 人 将 瓦 西 里 斯 与 奥 纳 西 斯 比 较 时 总 不 忘 补 充 一 句 : 他 和 奥 纳 西 斯 不 同 , 他 没 有 改 组 家 庭 。 B-LOC I-LOC O O B-PER I-PER I-PER I-PER O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O
重 视 传 统 家 庭 观 念 的 希 腊 人 , 对 瓦 西 里 斯 幸 福 的 家 庭 充 满 赞 誉 。 O O O O O O O O O B-LOC I-LOC O O O B-PER I-PER I-PER I-PER O O O O O O O O O O
Predicted results:
[]  O O O O O O O O O O O O O O O O O O O O O O O O O
[]  O O O O O O O O O O O O O O O O O O O O O O O O O O O
[]  O O O O O O O O
[{'word': '希', 'tag': 'LOC', 'start': 0, 'end': 1}, {'word': '腊', 'tag': 'LOC', 'start': 3, 'end': 4}, {'word': '瓦', 'tag': 'PER', 'start': 12, 'end': 13}, {'word': '奥', 'tag': 'PER', 'start': 27, 'end': 28}, {'word': '奥', 'tag': 'PER', 'start': 78, 'end': 79}] B-LOC I-LOC O O B-PER I-PER I-PER I-PER O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O
[{'word': '希', 'tag': 'LOC', 'start': 27, 'end': 28}] O O O O O O O O O B-LOC I-LOC O O O B-PER I-PER I-PER I-PER O O O O O O O O O O

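The output above shows the last five samples of dev.csv together with the model's predictions. Each predicted row contains the list of extracted entities, followed by the gold labels appended through append_cols=label. For these samples the predictions are largely consistent with the gold labels: nothing is predicted for the three sentences without entities, and the entities predicted for the last two sentences have the expected types (LOC, PER), although the person name 瓦西里斯 in the last sentence is missed. Note that in this output the word field of each entity shows only a single character.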
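Each labeled row pairs a space-separated character sequence with a BIO tag sequence of the same length: B- marks the first character of an entity, I- a following character of the same entity, and O a character outside any entity. Chunk-level scoring decodes both the gold and the predicted tags into entity spans and counts exact matches. The sketch below illustrates this on the last dev sample shown above. `decode_bio` and `count_matches` are hypothetical helpers, not EasyNLP functions; their start/end values are token indices (end exclusive) rather than the offsets EasyNLP prints, and the prediction in the example is made up for illustration.

```python
from typing import List, Optional, Tuple

Span = Tuple[str, int, int]  # (entity type, start index, end index, exclusive)


def decode_bio(tags: List[str]) -> List[Span]:
    """Decode a BIO tag sequence into (type, start, end) entity spans."""
    spans: List[Span] = []
    start: Optional[int] = None
    etype = ""
    for i, tag in enumerate(tags + ["O"]):  # the sentinel "O" closes a trailing entity
        # The open entity ends at O, at a new B-, or at an I- of another type.
        if start is not None and (tag == "O" or tag.startswith("B-") or tag[2:] != etype):
            spans.append((etype, start, i))
            start = None
        # B- opens an entity; a stray I- (no open entity) is treated like B-.
        if tag.startswith("B-") or (tag.startswith("I-") and start is None):
            start, etype = i, tag[2:]
    return spans


def count_matches(gold: List[str], pred: List[str]) -> Tuple[int, int, int]:
    """Return (correct, found, gold_total); a match needs the same type and span."""
    gold_spans, pred_spans = set(decode_bio(gold)), set(decode_bio(pred))
    return len(gold_spans & pred_spans), len(pred_spans), len(gold_spans)


# Last labeled sample from dev.csv shown above.
chars = "重 视 传 统 家 庭 观 念 的 希 腊 人 , 对 瓦 西 里 斯 幸 福 的 家 庭 充 满 赞 誉 。".split()
gold = "O O O O O O O O O B-LOC I-LOC O O O B-PER I-PER I-PER I-PER O O O O O O O O O O".split()
assert len(chars) == len(gold)

entities = [(t, "".join(chars[s:e])) for t, s, e in decode_bio(gold)]
print(entities)  # → [('LOC', '希腊'), ('PER', '瓦西里斯')]

# Hypothetical prediction that finds the location but misses the person name.
pred = ["O"] * len(gold)
pred[9], pred[10] = "B-LOC", "I-LOC"
print(count_matches(gold, pred))  # → (1, 1, 2)
```

Summing these three counts over all sentences gives the "correct", "found" and "phrases" numbers reported in the evaluation log.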

One-Step Execution

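Note that all of the training, evaluation, and prediction code above is already integrated in EasyNLP/examples/appzoo_tutorials/sequence_labeling/main.py, and we also provide several ready-to-run scripts. You can perform all of these steps in one go either by running main.py with arguments or by executing the script run_user_defined_local.sh.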

One-step execution with main.py

Run main.py with the arguments below to train, evaluate, or predict with the model directly. The parameters are explained earlier in this article and are not repeated here.

Training command:

! python main.py \
    --mode train \
    --tables=./train.csv,./dev.csv \
    --input_schema=content:str:1,label:str:1 \
    --first_sequence=content \
    --label_name=label \
    --label_enumerate_values=B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O \
    --checkpoint_dir=./seq_labeling/ \
    --learning_rate=3e-5  \
    --epoch_num=1  \
    --save_checkpoint_steps=50 \
    --sequence_length=128 \
    --micro_batch_size=32 \
    --app_name=sequence_labeling \
    --user_defined_parameters='pretrain_model_name_or_path=hfl/chinese-roberta-wwm-ext'

Evaluation command:

! python main.py \
    --mode=evaluate \
    --tables=dev.csv \
    --input_schema=content:str:1,label:str:1 \
    --first_sequence=content \
    --label_name=label \
    --label_enumerate_values=B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O \
    --checkpoint_dir=./seq_labeling/ \
    --sequence_length=128 \
    --micro_batch_size=32 \
    --app_name=sequence_labeling

Prediction command:

! python main.py \
    --mode=predict \
    --tables=dev.csv \
    --outputs=dev.pred.csv \
    --input_schema=content:str:1,label:str:1 \
    --output_schema=output \
    --append_cols=label \
    --first_sequence=content \
    --checkpoint_path=./seq_labeling/ \
    --micro_batch_size 32 \
    --sequence_length=128 \
    --app_name=sequence_labeling
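In the commands above, `--input_schema` appears to describe each input column as a `name:type:length` triple, and `--label_enumerate_values` lists the BIO tag set for LOC (location), ORG (organization) and PER (person) entities, seven labels in total. The sketch below only shows how these two strings can be read; `parse_input_schema` and `build_label_map` are hypothetical helpers, and numbering the labels in the listed order is an assumption, not necessarily what EasyNLP does internally.

```python
from typing import Dict, List, Tuple


def parse_input_schema(schema: str) -> List[Tuple[str, str, int]]:
    """Split 'name:type:length,...' into (name, type, length) triples."""
    fields = []
    for item in schema.split(","):
        name, dtype, length = item.split(":")
        fields.append((name, dtype, int(length)))
    return fields


def build_label_map(values: str) -> Dict[str, int]:
    """Map each tag to an integer id in the order it is listed (assumed order)."""
    return {label: idx for idx, label in enumerate(values.split(","))}


schema = parse_input_schema("content:str:1,label:str:1")
labels = build_label_map("B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O")
print(schema)                    # → [('content', 'str', 1), ('label', 'str', 1)]
print(len(labels), labels["O"])  # → 7 6
```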

Running a bash script from the command line

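The EasyNLP/examples/appzoo_tutorials/sequence_labeling folder contains several ready-to-run scripts, so training, evaluation, and prediction can also be done in one step by running a script with arguments. Take run_user_defined_local.sh as an example. It takes two arguments: the first is the ID of the GPU to run on, usually 0; the second is the mode: train, evaluate, or predict.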

Training:

%cd examples/appzoo_tutorials/sequence_labeling
! bash run_user_defined_local.sh 0 train

Evaluation:

! bash run_user_defined_local.sh 0 evaluate

Prediction:

! bash run_user_defined_local.sh 0 predict
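In a notebook each `!` line runs in its own subshell, so changing directory in one `!` line does not carry over to the next. If you prefer to drive the script from Python, the working directory can be passed explicitly. A minimal sketch, assuming the repository was cloned to ./EasyNLP as in the installation step; `script_command` and `run_stage` are hypothetical helpers:

```python
import subprocess
from pathlib import Path
from typing import List

# Assumed location: the repository cloned into ./EasyNLP during installation.
SCRIPT_DIR = Path("EasyNLP/examples/appzoo_tutorials/sequence_labeling")
MODES = ("train", "evaluate", "predict")


def script_command(gpu_id: int, mode: str) -> List[str]:
    """Build the command line: GPU id first, then the mode."""
    if mode not in MODES:
        raise ValueError(f"mode must be one of {MODES}, got {mode!r}")
    return ["bash", "run_user_defined_local.sh", str(gpu_id), mode]


def run_stage(gpu_id: int, mode: str, workdir: Path = SCRIPT_DIR) -> None:
    """Run one stage inside the script's folder; raises CalledProcessError on failure."""
    subprocess.run(script_command(gpu_id, mode), cwd=workdir, check=True)


# Usage (not executed here): run all three stages on GPU 0 in order.
# for mode in MODES:
#     run_stage(0, mode)
```

Passing `cwd` makes the script run from its own folder regardless of the notebook's current directory, and `check=True` stops the sequence as soon as a stage exits with an error.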

