Quick start
Open the notebook “Sequence labeling (named entity recognition) based on EasyNLP” and click “Open in DSW” in the upper-right corner.
Chinese Sequence Labeling Based on RoBERTa-zh
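Sequence labeling is a fundamental NLP task with a wide range of applications, such as word segmentation, part-of-speech tagging, named entity recognition (NER), keyword extraction, and semantic role labeling (SRL). Its goal is to assign a tag to each text unit in a given sequence, which provides additional information about those units. RoBERTa (Robustly Optimized BERT Pretraining Approach) is a pretrained language model developed by Facebook AI (now Meta AI). It builds on BERT by training on additional data with an improved pretraining strategy, and it significantly outperforms BERT on natural language understanding tasks.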
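EasyNLP provides a Chinese version of RoBERTa so that users can benefit from its strong modeling capability. Using named entity recognition as an example, this tutorial builds an NER model on top of RoBERTa-zh and shows how to use EasyNLP to build, train, evaluate, and run predictions with the model.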
Runtime requirements
PAI-PyTorch 1.7/1.8 image, GPU instance with a P100 or V100, 16 GB or 32 GB of memory
Installing EasyNLP
We recommend downloading the EasyNLP source code from GitHub and installing it with the following commands:
! git clone https://github.com/alibaba/EasyNLP.git
! pip install -r EasyNLP/requirements.txt
# Every "!" line runs in its own subshell, so "cd" must be chained with the install command
! cd EasyNLP && python setup.py install
You can verify the installation with the following command:
! which easynlp
If this prints the path of the easynlp CLI tool, EasyNLP has been installed successfully.
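If you prefer to check from Python inside the notebook, the following sketch performs the same lookup as `which easynlp` via `shutil.which` and additionally checks whether the easynlp package is importable with `importlib.util.find_spec`. The function name `check_easynlp_installation` is our own illustration, not part of EasyNLP.

```python
import importlib.util
import shutil


def check_easynlp_installation() -> dict:
    """Report where the easynlp CLI and Python package were found (None if missing)."""
    cli_path = shutil.which("easynlp")          # same lookup as `which easynlp`
    spec = importlib.util.find_spec("easynlp")  # None if the package cannot be imported
    return {
        "cli": cli_path,
        "package": spec.origin if spec is not None else None,
    }


if __name__ == "__main__":
    print(check_easynlp_installation())
```

Both entries should be non-empty after a successful installation; a missing `cli` with a present `package` usually means the install location is not on `PATH`.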
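Data preparation
First, download the training and development sets used in this example with the commands below. (The folder for saving the model, ./seq_labeling/, is specified later through args.checkpoint_dir.)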
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling/train.csv
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling/dev.csv
--2022-07-14 15:14:38--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling/train.csv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2565663 (2.4M) [text/csv]
Saving to: ‘train.csv.1’
train.csv.1 100%[===================>] 2.45M 9.30MB/s in 0.3s
2022-07-14 15:14:39 (9.30 MB/s) - ‘train.csv.1’ saved [2565663/2565663]
--2022-07-14 15:14:39--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling/dev.csv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1109771 (1.1M) [text/csv]
Saving to: ‘dev.csv.1’
dev.csv.1 100%[===================>] 1.06M 5.96MB/s in 0.2s
2022-07-14 15:14:40 (5.96 MB/s) - ‘dev.csv.1’ saved [1109771/1109771]
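After the download finishes, you can view the first five lines of each file with the code below. Each line is one sample: a sentence for named entity recognition together with the tag of each of its characters, and combining the character tags yields complete entity labels. For example, in the development set dev.csv, the characters '中 共 中 央' are tagged B-ORG I-ORG I-ORG I-ORG. B-ORG marks the first character of an organization name and I-ORG marks a character inside or at the end of one, so together they indicate that '中共中央' is an organization name.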
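To make the BIO convention concrete, here is a minimal sketch that groups character-level tags back into entity spans. `bio_to_entities` is a hypothetical helper written for this tutorial, not part of EasyNLP. It assumes exactly one tag per character and, as a design choice, treats an I- tag that does not continue an open entity of the same type as the start of a new entity.

```python
from typing import List, Optional, Tuple


def bio_to_entities(chars: List[str], tags: List[str]) -> List[Tuple[str, str, int, int]]:
    """Group character-level BIO tags into (text, type, start, end) spans; end is exclusive."""
    if len(chars) != len(tags):
        raise ValueError("each character needs exactly one tag")
    entities: List[Tuple[str, str, int, int]] = []
    start: Optional[int] = None     # index where the open entity begins
    ent_type: Optional[str] = None  # type of the open entity, e.g. "ORG"

    for i, tag in enumerate(tags):
        if tag.startswith("I-") and start is not None and tag[2:] == ent_type:
            continue  # an I- tag of the same type extends the open entity
        # Any other tag closes the open entity, if there is one.
        if start is not None:
            entities.append(("".join(chars[start:i]), ent_type, start, i))
            start, ent_type = None, None
        if tag.startswith(("B-", "I-")):
            # B- opens a new entity; a stray I- is leniently treated like B-.
            start, ent_type = i, tag[2:]
    if start is not None:
        entities.append(("".join(chars[start:]), ent_type, start, len(chars)))
    return entities
```

Applied to the first dev.csv sample, `bio_to_entities(list("中共中央致中国致公党十一大的贺词"), tags)` returns `[("中共中央", "ORG", 0, 4), ("中国致公党十一大", "ORG", 5, 13)]`.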
print('Training data sample:')
! head -n 5 train.csv
print('Development set data sample:')
! head -n 5 dev.csv
Training data sample:
当 希 望 工 程 救 助 的 百 万 儿 童 成 长 起 来 , 科 教 兴 国 蔚 然 成 风 时 , 今 天 有 收 藏 价 值 的 书 你 没 买 , 明 日 就 叫 你 悔 不 当 初 ! O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
藏 书 本 来 就 是 所 有 传 统 收 藏 门 类 中 的 第 一 大 户 , 只 是 我 们 结 束 温 饱 的 时 间 太 短 而 已 。 O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
因 有 关 日 寇 在 京 掠 夺 文 物 详 情 , 藏 界 较 为 重 视 , 也 是 我 们 收 藏 北 京 史 料 中 的 要 件 之 一 。 O O O B-LOC O O B-LOC O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC O O O O O O O O O
我 们 藏 有 一 册 1 9 4 5 年 6 月 油 印 的 《 北 京 文 物 保 存 保 管 状 态 之 调 查 报 告 》 , 调 查 范 围 涉 及 故 宫 、 历 博 、 古 研 所 、 北 大 清 华 图 书 馆 、 北 图 、 日 伪 资 料 库 等 二 十 几 家 , 言 及 文 物 二 十 万 件 以 上 , 洋 洋 三 万 余 言 , 是 珍 贵 的 北 京 史 料 。 O O O O O O O O O O O O O O O O O B-LOC I-LOC O O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC O B-LOC I-LOC O B-ORG I-ORG I-ORG O B-LOC I-LOC I-LOC I-LOC I-LOC I-LOC I-LOC O B-LOC I-LOC O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC O O O
以 家 乡 的 历 史 文 献 、 特 定 历 史 时 期 书 刊 、 某 一 名 家 或 名 著 的 多 种 出 版 物 为 专 题 , 注 意 精 品 、 非 卖 品 、 纪 念 品 , 集 成 系 列 , 那 收 藏 的 过 程 就 已 经 够 您 玩 味 无 穷 了 。 O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
Development set data sample:
中 共 中 央 致 中 国 致 公 党 十 一 大 的 贺 词 B-ORG I-ORG I-ORG I-ORG O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG O O O
各 位 代 表 、 各 位 同 志 : O O O O O O O O O O
在 中 国 致 公 党 第 十 一 次 全 国 代 表 大 会 隆 重 召 开 之 际 , 中 国 共 产 党 中 央 委 员 会 谨 向 大 会 表 示 热 烈 的 祝 贺 , 向 致 公 党 的 同 志 们 O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O O O O O O O B-ORG I-ORG I-ORG O O O O
致 以 亲 切 的 问 候 ! O O O O O O O O
在 过 去 的 五 年 中 , 致 公 党 在 邓 小 平 理 论 指 引 下 , 遵 循 社 会 主 义 初 级 阶 段 的 基 本 路 线 , 努 力 实 践 致 公 党 十 大 提 出 的 发 挥 参 政 党 职 能 、 加 强 自 身 建 设 的 基 本 任 务 。 O O O O O O O O B-ORG I-ORG I-ORG O B-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O O O O O O O O O O O O O O O O O
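Before training, it is worth checking that every line has as many tags as characters. The sketch below assumes, as the input_schema description later in this tutorial states, that each line holds two tab-separated columns: the space-separated characters and the space-separated tags. `read_ner_lines` is a hypothetical helper, not an EasyNLP API.

```python
from typing import Iterable, List, Tuple


def read_ner_lines(lines: Iterable[str]) -> List[Tuple[List[str], List[str]]]:
    """Parse 'characters<TAB>tags' lines into (chars, tags) pairs and check that they align."""
    samples = []
    for line_no, line in enumerate(lines, start=1):
        line = line.rstrip("\r\n")
        if not line.strip():
            continue  # ignore blank lines
        fields = line.split("\t")  # assumption: exactly one tab separates the two columns
        if len(fields) != 2:
            raise ValueError(f"line {line_no}: expected 2 tab-separated columns, got {len(fields)}")
        chars, tags = fields[0].split(), fields[1].split()
        if len(chars) != len(tags):
            raise ValueError(f"line {line_no}: {len(chars)} characters but {len(tags)} tags")
        samples.append((chars, tags))
    return samples

# Usage (in the notebook):
#   with open("dev.csv", encoding="utf-8") as f:
#       samples = read_ner_lines(f)
```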
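Initialization
In a Python 3.6 environment, we first import the libraries needed to run the model from the freshly installed EasyNLP and perform some initialization. In this tutorial we use chinese-roberta-wwm-ext (hfl/chinese-roberta-wwm-ext). This checkpoint uses the BERT architecture, which is why the model config printed in the training log below reports "model_type": "bert". EasyNLP integrates a rich library of pretrained models; to try another pretrained model such as BERT or ALBERT, change the model name in user_defined_parameters accordingly. For the available model names, see the model list.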
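EasyNLP parses the user_defined_parameters string with parse_user_defined_parameters. The sketch below only illustrates the key=value convention used in this tutorial's string: `parse_key_values` is a hypothetical, simplified stand-in that assumes pairs are separated by whitespace, and EasyNLP's own parser may accept more forms. To switch backbones, you would change only the value after `pretrain_model_name_or_path=`, using a name from the model list.

```python
def parse_key_values(text: str) -> dict:
    """Turn 'k1=v1 k2=v2' into {'k1': 'v1', 'k2': 'v2'} (simplified illustration)."""
    params = {}
    for item in text.split():  # assumption: pairs are separated by whitespace
        key, sep, value = item.partition("=")  # split on the first '=' only
        if not sep or not key:
            raise ValueError(f"expected key=value, got {item!r}")
        params[key] = value
    return params


params = parse_key_values("pretrain_model_name_or_path=hfl/chinese-roberta-wwm-ext")
print(params["pretrain_model_name_or_path"])  # hfl/chinese-roberta-wwm-ext
```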
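# EasyNLP parses command-line arguments, which conflict with the arguments Jupyter passes to the kernel.
# Set sys.argv manually; otherwise initialization fails.
# If you run the code in this tutorial from the command line or a .py file, skip this cell.
import sys
sys.argv = ['main.py']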
import torch.cuda
from easynlp.appzoo import SequenceLabelingDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator
from easynlp.appzoo import get_application_model_for_evaluation
from easynlp.core import PredictorManager
from easynlp.core import Trainer
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parameters

initialize_easynlp()
args = get_args()
# Pretrained backbone; change the model name here to try another model from the model list
user_defined_parameters = "pretrain_model_name_or_path=hfl/chinese-roberta-wwm-ext"
user_defined_parameters = parse_user_defined_parameters(user_defined_parameters)
# Directory where the best checkpoint (pytorch_model.bin) is saved
args.checkpoint_dir = "./seq_labeling/"
[2022-07-14 14:02:21,641.641 dsw34730-66c85d4cdb-qbz2c:30060 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off. Please ignore the following import error if you are using tunnel table io. No module named '_common_io'
No module named 'easy_predict'
------------------------ arguments ------------------------
app_name ........................................ text_classify
append_cols ..................................... None
buckets ......................................... None
checkpoint_dir .................................. None
chief_hosts .....................................
data_threads .................................... 10
distributed_backend ............................. nccl
do_lower_case ................................... False
epoch_num ....................................... 3.0
export_tf_checkpoint_type ....................... easytransfer
first_sequence .................................. None
gradient_accumulation_steps ..................... 1
input_schema .................................... None
is_chief ........................................
is_master_node .................................. True
job_name ........................................ None
label_enumerate_values .......................... None
label_name ...................................... None
learning_rate ................................... 5e-05
local_rank ...................................... None
logging_steps ................................... 100
master_port ..................................... 23456
max_grad_norm ................................... 1.0
micro_batch_size ................................ 2
mode ............................................ train
modelzoo_base_dir ...............................
n_cpu ........................................... 1
n_gpu ........................................... 1
odps_config ..................................... None
optimizer_type .................................. AdamW
output_schema ...................................
outputs ......................................... None
predict_queue_size .............................. 1024
predict_slice_size .............................. 4096
predict_table_read_thread_num ................... 16
predict_thread_num .............................. 2
ps_hosts ........................................
random_seed ..................................... 1234
rank ............................................ 0
read_odps ....................................... False
restore_works_dir ............................... ./.easynlp_predict_restore_works_dir
resume_from_checkpoint .......................... None
save_all_checkpoints ............................ False
save_checkpoint_steps ........................... None
second_sequence ................................. None
sequence_length ................................. 16
skip_first_line ................................. False
tables .......................................... None
task_count ...................................... 1
task_index ...................................... 0
use_amp ......................................... False
use_torchacc .................................... False
user_defined_parameters ......................... None
user_entry_file ................................. None
user_script ..................................... None
warmup_proportion ............................... 0.1
weight_decay .................................... 0.0001
worker_count .................................... 1
worker_cpu ...................................... -1
worker_gpu ...................................... -1
worker_hosts .................................... None
world_size ...................................... 1
-------------------- end of arguments ---------------------
> initializing torch distributed ...
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
  from cryptography import x509
[2022-07-14 14:02:23,720.720 dsw34730-66c85d4cdb-qbz2c:30060 INFO distributed_c10d.py:195] Added key: store_based_barrier_key:1 to store for rank: 0
Init dist done. World size: 1, rank 0, l_rank 0
> setting random seeds to 1234 ...
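Note: if the code above fails with an “Address already in use” error, run the following commands to find and stop the process occupying the port. The commands use port 6000, the default given in this tutorial; if your error refers to a different port (for example master_port, which is 23456 in the argument dump above), substitute that number.
! netstat -tunlp | grep 6000
! kill -9 PID  # replace PID with the process ID shown in the output of the previous command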
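As a Python-only check that you can run in the notebook before calling initialize_easynlp(), the sketch below tries to connect to a local port; a successful connection means some process is already listening there. It does not identify the process (use netstat for that), and a port lingering in TIME_WAIT may still report as free, so treat the result as a quick hint. `port_in_use` is a hypothetical helper, not part of EasyNLP.

```python
import socket


def port_in_use(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP server is already accepting connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1.0)
        # connect_ex returns 0 on success instead of raising, so 0 means a listener exists
        return sock.connect_ex((host, port)) == 0


if __name__ == "__main__":
    for port in (6000, 23456):  # the example port and the master_port from the argument dump
        print(port, "in use" if port_in_use(port) else "free")
```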
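Loading the data
We use the SequenceLabelingDataset class that ships with EasyNLP to load the training and development data. Its main parameters are:
- pretrained_model_name_or_path: name or path of the pretrained model. Here we use the helper function get_pretrain_model_path to resolve the model name “hfl/chinese-roberta-wwm-ext” and download the model automatically
- max_seq_length: maximum text length; longer texts are truncated and shorter ones are padded
- input_schema: format of the input data. Each comma-separated item describes one tab-separated (\t) column of a line in the data file and starts with that column's field name, such as label or sent1 (in this tutorial: content:str:1,label:str:1)
- first_sequence, label_name: specify which fields in input_schema serve as the input sentence and the label column
- label_enumerate_values: comma-separated list of all label types
- is_training: whether the dataset is used for training; True for train_dataset and False for valid_dataset
- app_name: the task to run, such as text classification, sequence labeling, or text matching. It is set on args here and used later when building the model and the evaluator, rather than passed to SequenceLabelingDataset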
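To show how input_schema lines up with a data file, the sketch below pairs each schema field with the matching tab-separated column of one line. `parse_schema` and `map_line` are hypothetical helpers for illustration, not EasyNLP's implementation. We read the third part of each item (the `1` in `content:str:1`) as a length or count field; its exact semantics in EasyNLP are an assumption here and are not used by the mapping.

```python
from typing import Dict, List, Tuple


def parse_schema(schema: str) -> List[Tuple[str, str, int]]:
    """Split 'content:str:1,label:str:1' into (name, type, length) triples."""
    fields = []
    for item in schema.split(","):
        name, dtype, length = item.split(":")
        fields.append((name, dtype, int(length)))
    return fields


def map_line(schema: str, line: str) -> Dict[str, str]:
    """Pair each schema field with the matching tab-separated column of one data line."""
    fields = parse_schema(schema)
    columns = line.rstrip("\r\n").split("\t")
    if len(columns) != len(fields):
        raise ValueError(f"schema has {len(fields)} fields but the line has {len(columns)} columns")
    return {name: column for (name, _, _), column in zip(fields, columns)}


row = map_line("content:str:1,label:str:1", "中 共 中 央\tB-ORG I-ORG I-ORG I-ORG\n")
print(row["content"], "|", row["label"])  # 中 共 中 央 | B-ORG I-ORG I-ORG I-ORG
```

With first_sequence="content" and label_name="label", the dataset reads the sentence from the first column and the tags from the second.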
Next, we set some parameters manually for the experiment.
# Data files and their format
args.tables = "./train.csv,./dev.csv"
args.input_schema = "content:str:1,label:str:1"
args.first_sequence = "content"
args.label_name = "label"
args.label_enumerate_values = "B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O"
# Training hyperparameters
args.learning_rate = 3e-5
args.epoch_num = 1
args.save_checkpoint_steps = 50
args.sequence_length = 128
args.micro_batch_size = 32
args.app_name = "sequence_labeling"

# Resolve the model name and download the pretrained model if needed
args.pretrained_model_name_or_path = user_defined_parameters.get('pretrain_model_name_or_path', None)
args.pretrained_model_name_or_path = get_pretrain_model_path(args.pretrained_model_name_or_path)

train_dataset = SequenceLabelingDataset(
    pretrained_model_name_or_path=args.pretrained_model_name_or_path,
    data_file=args.tables.split(",")[0],   # ./train.csv
    max_seq_length=args.sequence_length,
    input_schema=args.input_schema,
    first_sequence=args.first_sequence,
    label_name=args.label_name,
    label_enumerate_values=args.label_enumerate_values,
    is_training=True)

valid_dataset = SequenceLabelingDataset(
    pretrained_model_name_or_path=args.pretrained_model_name_or_path,
    data_file=args.tables.split(",")[-1],  # ./dev.csv
    max_seq_length=args.sequence_length,
    input_schema=args.input_schema,
    first_sequence=args.first_sequence,
    label_name=args.label_name,
    label_enumerate_values=args.label_enumerate_values,
    is_training=False)
Trying downloading name_mapping.json
Success
Downloading `hfl/chinese-roberta-wwm-ext` to /root/.easynlp/modelzoo/public/hfl/chinese-roberta-wwm-ext.tgz
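With sequence_length set to 128, longer sentences are truncated (see max_seq_length above). The sketch below estimates how many samples that affects. It assumes one token per character and that two positions are reserved for special tokens such as [CLS] and [SEP], as in BERT-style models; both are assumptions about EasyNLP's preprocessing, and `length_report` is a hypothetical helper.

```python
from typing import Iterable, List


def length_report(lines: Iterable[str], max_seq_length: int = 128, reserved: int = 2) -> dict:
    """Count how many 'chars<TAB>tags' lines would not fit into max_seq_length tokens.

    Assumes one token per character and `reserved` positions for special tokens.
    """
    lengths: List[int] = []
    for line in lines:
        if line.strip():
            lengths.append(len(line.split("\t")[0].split()))
    budget = max_seq_length - reserved
    return {
        "num_samples": len(lengths),
        "max_length": max(lengths, default=0),
        "truncated": sum(1 for n in lengths if n > budget),
    }

# Usage (in the notebook):
#   with open("train.csv", encoding="utf-8") as f:
#       print(length_report(f))
```

If many samples are truncated, a larger sequence_length may help; the model config in the training log reports max_position_embeddings of 512, which caps it.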
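Model training
With the data processed and the pretrained model loaded, we can start training. We use EasyNLP's get_application_model function to build the model for training. Its parameters are:
- app_name: task name; here we choose sequence labeling, “sequence_labeling”
- pretrained_model_name_or_path: name or path of the pretrained model. Here we use the helper function get_pretrain_model_path to resolve the model name “hfl/chinese-roberta-wwm-ext” and download the model automatically
- user_defined_parameters: user-defined parameters; pass in the user_defined_parameters dictionary parsed earlier
- num_labels: number of label types, here len(valid_dataset.label_enumerate_values), which is 7 (matching the seven entries of id2label in the model config printed during training)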
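The model config printed in the training log lists seven label ids under generic names (LABEL_0 to LABEL_6), which do not say which id corresponds to which tag. The sketch below assumes that ids are assigned in the order the labels appear in args.label_enumerate_values, a common convention but an assumption about EasyNLP; `build_label_maps` is a hypothetical helper that makes the mapping explicit and decodes predicted ids back into tags.

```python
from typing import Dict, Tuple


def build_label_maps(label_enumerate_values: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Map each comma-separated label to an id in listed order (assumed convention)."""
    labels = [label.strip() for label in label_enumerate_values.split(",")]
    label2id = {label: idx for idx, label in enumerate(labels)}
    id2label = {idx: label for label, idx in label2id.items()}
    return label2id, id2label


label2id, id2label = build_label_maps("B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O")
print(len(label2id))                            # 7, the value passed as num_labels
print([id2label[i] for i in [1, 4, 4, 4, 6]])   # ['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O']
```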
Build and load the model
model = get_application_model(app_name=args.app_name,
                              pretrained_model_name_or_path=args.pretrained_model_name_or_path,
                              num_labels=len(valid_dataset.label_enumerate_values),
                              # the parsed dictionary; args.user_defined_parameters is None here
                              user_defined_parameters=user_defined_parameters)
Loaded weights of the model: [bert.embeddings.word_embeddings.weight,bert.embeddings.position_embeddings.weight,bert.embeddings.token_type_embeddings.weight,bert.embeddings.LayerNorm.weight,bert.embeddings.LayerNorm.bias,bert.encoder.layer.0.attention.self.query.weight,...,bert.encoder.layer.11.output.LayerNorm.bias,bert.pooler.dense.weight,bert.pooler.dense.bias,cls.predictions.bias,cls.predictions.transform.dense.weight,cls.predictions.transform.dense.bias,cls.predictions.transform.LayerNorm.weight,cls.predictions.transform.LayerNorm.bias,cls.predictions.decoder.weight,cls.seq_relationship.weight,cls.seq_relationship.bias].
Unloaded weights of the model: [cls.predictions.transform.dense.weight,cls.seq_relationship.bias,cls.predictions.transform.LayerNorm.bias,cls.predictions.transform.dense.bias,cls.predictions.bias,cls.seq_relationship.weight,cls.predictions.transform.LayerNorm.weight,cls.predictions.decoder.weight]. This IS expected if you initialize A model from B. This IS NOT expected if you initialize A model from A.
The unloaded weights belong to the pretraining heads (cls.predictions for masked language modeling and cls.seq_relationship for next sentence prediction), which the sequence labeling model does not use, so this message is expected here.
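The training log below reports 10,000 training examples. With micro_batch_size=32, epoch_num=1 and save_checkpoint_steps=50, and assuming a single GPU with gradient_accumulation_steps=1 (as in the argument dump), a quick calculation reproduces the schedule in the log: 313 total steps, because the last batch is partial (hence "Train. batch no. = 312" next to "Total training steps = 313"), and evaluations at steps 50 through 300, followed by a final evaluation after training ends. `training_schedule` is a hypothetical helper, and the rounding rule is inferred from the log.

```python
import math
from typing import List, Tuple


def training_schedule(num_examples: int, micro_batch_size: int, epoch_num: int,
                      save_checkpoint_steps: int) -> Tuple[int, List[int]]:
    """Return total optimizer steps and the steps at which periodic evaluation runs."""
    steps_per_epoch = math.ceil(num_examples / micro_batch_size)  # a partial last batch still counts
    total_steps = steps_per_epoch * epoch_num
    eval_steps = list(range(save_checkpoint_steps, total_steps + 1, save_checkpoint_steps))
    return total_steps, eval_steps


total, evals = training_schedule(10000, micro_batch_size=32, epoch_num=1, save_checkpoint_steps=50)
print(total, evals)  # 313 [50, 100, 150, 200, 250, 300]
```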
Build the trainer and train
trainer = Trainer(model=model,
                  train_dataset=train_dataset,
                  evaluator=get_application_evaluator(app_name=args.app_name,
                                                      valid_dataset=valid_dataset,
                                                      eval_batch_size=args.micro_batch_size,
                                                      user_defined_parameters=user_defined_parameters))
trainer.train()
[2022-07-14 14:03:32,720 INFO] ========== Initializing Tensorboard ==========
[2022-07-14 14:03:32,747 INFO] ========== Training Start ==========
[2022-07-14 14:03:32,748 INFO] Num of GPUs (all) = 1
[2022-07-14 14:03:32,751 INFO] Num of CPUs per worker = 1
[2022-07-14 14:03:32,752 INFO] Num dataset examples = 10000
[2022-07-14 14:03:32,752 INFO] Num training examples = 10000
[2022-07-14 14:03:32,752 INFO] Num validation examples = 4631
[2022-07-14 14:03:32,753 INFO] Train. batch size = 32
[2022-07-14 14:03:32,753 INFO] Train. micro batch size = 32
[2022-07-14 14:03:32,754 INFO] Train. batch no. = 312
[2022-07-14 14:03:32,757 INFO] Evaluation batch size = 32
[2022-07-14 14:03:32,758 INFO] Total training steps = 313
[2022-07-14 14:03:32,758 INFO] Sequence length = 128
[2022-07-14 14:03:32,759 INFO] Saving steps = 50
[2022-07-14 14:03:32,760 INFO] Distributed_backend = nccl
[2022-07-14 14:03:32,760 INFO] Worker Count = 1
[2022-07-14 14:03:32,761 INFO] Worker CPU = -1
[2022-07-14 14:03:32,761 INFO] Worker data threads = 10
[2022-07-14 14:03:32,763 INFO] num model params = 102,273,031
[2022-07-14 14:03:32,765 INFO] num trainable params = 102,273,031
[2022-07-14 14:03:32,766 INFO]
[2022-07-14 14:03:32,766 INFO] ========== Model Config ==========
[2022-07-14 14:03:32,767 INFO] {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "directionality": "bidi",
  "easynlp_version": "0.0.3",
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
optimizer type: AdamW
/root/.local/lib/python3.6/site-packages/pai_easynlp-0.0.6-py3.6.egg/easynlp/core/optimizers.py:441: UserWarning: This overload of add_ is deprecated:
  add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
  add_(Tensor other, *, Number alpha) (Triggered internally at /workspace/artifacts/paipytorch1.8/dist/ubuntu18.04-py3.6-cuda10.1/build/src/torch/csrc/utils/python_arg_parser.cpp:1005.)
  exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
/home/pai/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:247: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  warnings.warn("To get the last learning rate computed by the scheduler, "
[2022-07-14 14:03:46,497 INFO] ========== Evaluation at global step 50 ==========
[2022-07-14 14:04:01,182 INFO] Eval: 100/144 steps finished
[2022-07-14 14:04:07,556 INFO] Inference time = 2.10s, [0.4528 ms / sample]
[2022-07-14 14:04:07,557 INFO] Eval loss: 0.1645107916572356
[2022-07-14 14:04:07,852 INFO] processed 177232 tokens with 6192 phrases;
[2022-07-14 14:04:07,853 INFO] found: 6991 phrases; correct: 3378.
[2022-07-14 14:04:07,853 INFO] accuracy: 65.86%; (non-O)accuracy: 95.66%; precision: 48.32%; recall: 54.55%; FB1: 51.25
[2022-07-14 14:04:07,854 INFO] LOC: precision: 51.02%; recall: 58.85%; FB1: 54.66 3318
[2022-07-14 14:04:07,854 INFO] ORG: precision: 16.01%; recall: 21.86%; FB1: 18.48 1818
[2022-07-14 14:04:07,855 INFO] PER: precision: 75.15%; recall: 70.26%; FB1: 72.62 1855
[2022-07-14 14:04:07,855 INFO] Labeling F1: 51.24781916104073
[2022-07-14 14:04:07,856 INFO] Labeling Precision: 48.31926762980975
[2022-07-14 14:04:07,856 INFO] Labeling Recall: 54.55426356589147
[2022-07-14 14:04:07,857 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:04:15,383 INFO] Best score: 51.24781916104073
[2022-07-14 14:04:15,384 INFO] Learning rate: 0.00002809
[2022-07-14 14:04:27,051 INFO] Epoch [ 0/ 1], step [100/313], lr 0.000023, 54.28 s
[2022-07-14 14:04:27,053 INFO] loss : 0.4126
[2022-07-14 14:04:27,053 INFO] ========== Evaluation at global step 100 ==========
[2022-07-14 14:04:41,792 INFO] Eval: 100/144 steps finished
[2022-07-14 14:04:48,179 INFO] Inference time = 2.07s, [0.4458 ms / sample]
[2022-07-14 14:04:48,180 INFO] Eval loss: 0.07174456688607561
[2022-07-14 14:04:48,492 INFO] processed 177232 tokens with 6192 phrases;
[2022-07-14 14:04:48,493 INFO] found: 6405 phrases; correct: 4872.
[2022-07-14 14:04:48,494 INFO] accuracy: 83.33%; (non-O)accuracy: 97.70%; precision: 76.07%; recall: 78.68%; FB1: 77.35
[2022-07-14 14:04:48,494 INFO] LOC: precision: 77.14%; recall: 77.89%; FB1: 77.52 2905
[2022-07-14 14:04:48,495 INFO] ORG: precision: 55.13%; recall: 62.96%; FB1: 58.79 1520
[2022-07-14 14:04:48,495 INFO] PER: precision: 90.56%; recall: 90.37%; FB1: 90.46 1980
[2022-07-14 14:04:48,496 INFO] Labeling F1: 77.3517504167659
[2022-07-14 14:04:48,496 INFO] Labeling Precision: 76.0655737704918
[2022-07-14 14:04:48,498 INFO] Labeling Recall: 78.68217054263566
[2022-07-14 14:04:48,499 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:04:59,747 INFO] Best score: 77.3517504167659
[2022-07-14 14:04:59,748 INFO] Learning rate: 0.00002277
[2022-07-14 14:05:11,421 INFO] ========== Evaluation at global step 150 ==========
[2022-07-14 14:05:26,126 INFO] Eval: 100/144 steps finished
[2022-07-14 14:05:32,512 INFO] Inference time = 2.09s, [0.4515 ms / sample]
[2022-07-14 14:05:32,513 INFO] Eval loss: 0.06119475209892824
[2022-07-14 14:05:32,829 INFO] processed 177232 tokens with 6192 phrases;
[2022-07-14 14:05:32,830 INFO] found: 6424 phrases; correct: 5190.
[2022-07-14 14:05:32,830 INFO] accuracy: 87.22%; (non-O)accuracy: 98.19%; precision: 80.79%; recall: 83.82%; FB1: 82.28
[2022-07-14 14:05:32,831 INFO] LOC: precision: 82.48%; recall: 87.56%; FB1: 84.94 3054
[2022-07-14 14:05:32,831 INFO] ORG: precision: 57.24%; recall: 59.43%; FB1: 58.31 1382
[2022-07-14 14:05:32,831 INFO] PER: precision: 94.57%; recall: 94.76%; FB1: 94.66 1988
[2022-07-14 14:05:32,832 INFO] Labeling F1: 82.27647431832595
[2022-07-14 14:05:32,832 INFO] Labeling Precision: 80.79078455790784
[2022-07-14 14:05:32,833 INFO] Labeling Recall: 83.81782945736434
[2022-07-14 14:05:32,834 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:05:40,305 INFO] Best score: 82.27647431832595
[2022-07-14 14:05:40,306 INFO] Learning rate: 0.00001745
[2022-07-14 14:05:51,957 INFO] Epoch [ 0/ 1], step [200/313], lr 0.000012, 84.90 s
[2022-07-14 14:05:51,959 INFO] loss : 0.2268
[2022-07-14 14:05:51,960 INFO] ========== Evaluation at global step 200 ==========
[2022-07-14 14:06:06,691 INFO] Eval: 100/144 steps finished
[2022-07-14 14:06:13,084 INFO] Inference time = 2.10s, [0.4528 ms / sample]
[2022-07-14 14:06:13,085 INFO] Eval loss: 0.052420847586773595
[2022-07-14 14:06:13,402 INFO] processed 177232 tokens with 6192 phrases;
[2022-07-14 14:06:13,403 INFO] found: 6516 phrases; correct: 5320.
[2022-07-14 14:06:13,403 INFO] accuracy: 89.25%; (non-O)accuracy: 98.37%; precision: 81.65%; recall: 85.92%; FB1: 83.73
[2022-07-14 14:06:13,404 INFO] LOC: precision: 82.97%; recall: 87.87%; FB1: 85.35 3047
[2022-07-14 14:06:13,404 INFO] ORG: precision: 61.63%; recall: 68.90%; FB1: 65.06 1488
[2022-07-14 14:06:13,404 INFO] PER: precision: 94.65%; recall: 94.51%; FB1: 94.58 1981
[2022-07-14 14:06:13,405 INFO] Labeling F1: 83.72678627636135
[2022-07-14 14:06:13,405 INFO] Labeling Precision: 81.64518109269491
[2022-07-14 14:06:13,406 INFO] Labeling Recall: 85.91731266149871
[2022-07-14 14:06:13,407 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:06:20,924 INFO] Best score: 83.72678627636135
[2022-07-14 14:06:20,925 INFO] Learning rate: 0.00001213
[2022-07-14 14:06:32,569 INFO] ========== Evaluation at global step 250 ==========
[2022-07-14 14:06:47,393 INFO] Eval: 100/144 steps finished
[2022-07-14 14:06:53,945 INFO] Inference time = 2.13s, [0.4597 ms / sample]
[2022-07-14 14:06:53,946 INFO] Eval loss: 0.05124919564896745
[2022-07-14 14:06:54,288 INFO] processed 177232 tokens with 6192 phrases;
[2022-07-14 14:06:54,288 INFO] found: 6510 phrases; correct: 5328.
[2022-07-14 14:06:54,289 INFO] accuracy: 89.44%; (non-O)accuracy: 98.40%; precision: 81.84%; recall: 86.05%; FB1: 83.89
[2022-07-14 14:06:54,290 INFO] LOC: precision: 82.30%; recall: 87.77%; FB1: 84.95 3068
[2022-07-14 14:06:54,290 INFO] ORG: precision: 63.30%; recall: 69.20%; FB1: 66.12 1455
[2022-07-14 14:06:54,290 INFO] PER: precision: 94.72%; recall: 94.86%; FB1: 94.79 1987
[2022-07-14 14:06:54,291 INFO] Labeling F1: 83.89230042512989
[2022-07-14 14:06:54,291 INFO] Labeling Precision: 81.84331797235022
[2022-07-14 14:06:54,292 INFO] Labeling Recall: 86.04651162790698
[2022-07-14 14:06:54,293 INFO] Saving best model to ./seq_labeling/pytorch_model.bin...
[2022-07-14 14:07:01,860 INFO] Best score: 83.89230042512989 [2022-07-14 14:07:01,861 INFO] Learning rate: 0.00000681 [2022-07-14 14:07:13,487 INFO] Epoch [ 0/ 1], step [300/313], lr 0.000001, 81.53 s [2022-07-14 14:07:13,488 INFO] loss : 0.1621 [2022-07-14 14:07:13,489 INFO] ========== Evaluation at global step 300 ========== [2022-07-14 14:07:28,173 INFO] Eval: 100/144 steps finished [2022-07-14 14:07:34,564 INFO] Inference time = 2.09s, [0.4499 ms / sample] [2022-07-14 14:07:34,565 INFO] Eval loss: 0.05110448669834897 [2022-07-14 14:07:34,882 INFO] processed 177232 tokens with 6192 phrases; [2022-07-14 14:07:34,883 INFO] found: 6463 phrases; correct: 5300. [2022-07-14 14:07:34,884 INFO] accuracy: 89.03%; (non-O)accuracy: 98.42%; precision: 82.01%; recall: 85.59%; FB1: 83.76 [2022-07-14 14:07:34,884 INFO] LOC: precision: 83.34%; recall: 86.41%; FB1: 84.85 2983 [2022-07-14 14:07:34,885 INFO] ORG: precision: 62.64%; recall: 69.80%; FB1: 66.03 1483 [2022-07-14 14:07:34,885 INFO] PER: precision: 94.39%; recall: 95.01%; FB1: 94.70 1997 [2022-07-14 14:07:34,885 INFO] Labeling F1: 83.76135914658238 [2022-07-14 14:07:34,886 INFO] Labeling Precision: 82.00526071483831 [2022-07-14 14:07:34,886 INFO] Labeling Recall: 85.59431524547804 [2022-07-14 14:07:34,888 INFO] Best score: 83.89230042512989 [2022-07-14 14:07:34,890 INFO] Learning rate: 0.00000149
Training Time: 245.53464770317078, rank 0, gsteps 313
[2022-07-14 14:07:52,881 INFO] Eval: 100/144 steps finished [2022-07-14 14:07:59,261 INFO] Inference time = 2.09s, [0.4500 ms / sample] [2022-07-14 14:07:59,262 INFO] Eval loss: 0.05115903893421436 [2022-07-14 14:07:59,581 INFO] processed 177232 tokens with 6192 phrases; [2022-07-14 14:07:59,582 INFO] found: 6462 phrases; correct: 5298. [2022-07-14 14:07:59,582 INFO] accuracy: 88.96%; (non-O)accuracy: 98.41%; precision: 81.99%; recall: 85.56%; FB1: 83.74 [2022-07-14 14:07:59,583 INFO] LOC: precision: 83.22%; recall: 86.37%; FB1: 84.77 2986 [2022-07-14 14:07:59,583 INFO] ORG: precision: 62.79%; recall: 69.72%; FB1: 66.07 1478 [2022-07-14 14:07:59,584 INFO] PER: precision: 94.34%; recall: 95.01%; FB1: 94.68 1998 [2022-07-14 14:07:59,584 INFO] Labeling F1: 83.73636794689426 [2022-07-14 14:07:59,585 INFO] Labeling Precision: 81.98700092850511 [2022-07-14 14:07:59,585 INFO] Labeling Recall: 85.56201550387597 [2022-07-14 14:07:59,586 INFO] Best score: 83.89230042512989 [2022-07-14 14:07:59,587 INFO] Training Time: 266.98195219039917
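The log shows a "Saving best model" line only when the new Labeling F1 beats the best score so far: the evaluations reaching 82.28, 83.73 and 83.89 trigger a save, while the one at step 300 (83.76) does not. The sketch below reproduces that selection rule from the log text. It is an illustration inferred from the log output, not EasyNLP's Trainer code; select_best is a hypothetical helper.

```python
import re

# Matches log entries such as "[...] Labeling F1: 83.72678627636135".
F1_PATTERN = re.compile(r"Labeling F1: ([0-9.]+)")


def select_best(log_text):
    """Return (best_score, indices of evaluations that would trigger a save).

    A checkpoint counts as "saved" only when its Labeling F1 is strictly
    higher than every earlier score, mirroring the behaviour in the log.
    """
    best = float("-inf")
    saved_at = []
    for i, match in enumerate(F1_PATTERN.finditer(log_text)):
        score = float(match.group(1))
        if score > best:
            best = score
            saved_at.append(i)
    return best, saved_at


# The four Labeling F1 entries visible in the training log above.
log = """
[2022-07-14 14:05:32,832 INFO] Labeling F1: 82.27647431832595
[2022-07-14 14:06:13,405 INFO] Labeling F1: 83.72678627636135
[2022-07-14 14:06:54,291 INFO] Labeling F1: 83.89230042512989
[2022-07-14 14:07:34,885 INFO] Labeling F1: 83.76135914658238
"""

best, saved_at = select_best(log)
print(best, saved_at)
```

Only the first three evaluations are selected, and the best score stays at the step-250 value, matching the "Best score" lines in the log.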
Model Evaluation
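After training finishes, the best model is saved to the checkpoint_dir specified at the start, i.e. the local path "./seq_labeling/". We can then evaluate the trained model: we initialize an evaluator with EasyNLP's get_application_evaluator, move the model to the GPU, and run the evaluation. Note that evaluator.evaluate(model=model) evaluates the model object currently in memory, i.e. the weights after the last training step. This is why the results below (F1 83.74, eval loss 0.05115903893421436) match the final evaluation in the training log rather than the best saved score (83.89).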
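import torch
from easynlp.appzoo import get_application_evaluator

# args, valid_dataset, user_defined_parameters and model come from the training cells above
args.tables = "dev.csv"
evaluator = get_application_evaluator(app_name=args.app_name,
                                      valid_dataset=valid_dataset,
                                      eval_batch_size=args.micro_batch_size,
                                      user_defined_parameters=user_defined_parameters)
# Move the model to the current GPU, then evaluate it on the validation set
model.to(torch.cuda.current_device())
evaluator.evaluate(model=model)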
[2022-07-14 14:08:28,310 INFO] Eval: 100/144 steps finished
[2022-07-14 14:08:34,709 INFO] Inference time = 2.12s, [0.4567 ms / sample]
[2022-07-14 14:08:34,710 INFO] Eval loss: 0.05115903893421436
[2022-07-14 14:08:35,037 INFO] processed 177232 tokens with 6192 phrases;
[2022-07-14 14:08:35,038 INFO] found: 6462 phrases; correct: 5298.
[2022-07-14 14:08:35,038 INFO] accuracy: 88.96%; (non-O)accuracy: 98.41%; precision: 81.99%; recall: 85.56%; FB1: 83.74
[2022-07-14 14:08:35,039 INFO] LOC: precision: 83.22%; recall: 86.37%; FB1: 84.77 2986
[2022-07-14 14:08:35,040 INFO] ORG: precision: 62.79%; recall: 69.72%; FB1: 66.07 1478
[2022-07-14 14:08:35,040 INFO] PER: precision: 94.34%; recall: 95.01%; FB1: 94.68 1998
[2022-07-14 14:08:35,041 INFO] Labeling F1: 83.73636794689426
[2022-07-14 14:08:35,041 INFO] Labeling Precision: 81.98700092850511
[2022-07-14 14:08:35,041 INFO] Labeling Recall: 85.56201550387597
[('labeling_f1', 83.73636794689426), ('labeling_precision', 81.98700092850511), ('labeling_recall', 85.56201550387597)]
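The evaluation log reports phrase counts in the conlleval format ("processed ... tokens with 6192 phrases; found: 6462 phrases; correct: 5298"), and the three returned metrics follow from those counts: precision = correct / found, recall = correct / gold, and F1 is their harmonic mean, which simplifies to 2 * correct / (found + gold). The small check below recomputes them; phrase_prf is a hypothetical helper written for illustration, not an EasyNLP function.

```python
def phrase_prf(correct, found, gold):
    """Phrase-level precision, recall and F1 in percent.

    correct: predicted phrases whose span and type both match a gold phrase
    found:   number of phrases the model predicted
    gold:    number of phrases in the reference labels
    """
    precision = 100.0 * correct / found if found else 0.0
    recall = 100.0 * correct / gold if gold else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Counts from the evaluation log: 6192 gold phrases, 6462 found, 5298 correct.
p, r, f = phrase_prf(correct=5298, found=6462, gold=6192)
print(f"P={p:.2f} R={r:.2f} F1={f:.2f}")
```

The printed values (P=81.99, R=85.56, F1=83.74) agree with labeling_precision, labeling_recall and labeling_f1 above.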
Model Prediction
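We can also use the trained model for sequence labeling (here, named entity recognition). We first create a predictor and then use it to instantiate a PredictorManager. The predictions are written to dev.pred.csv.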
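from easynlp.appzoo import get_application_predictor
from easynlp.core import PredictorManager

args.tables = "dev.csv"
args.outputs = "dev.pred.csv"
args.output_schema = "output"
# Copy the gold label column next to each prediction in the output file
args.append_cols = "label"
# Load the fine-tuned model from checkpoint_dir and wrap it in a predictor
predictor = get_application_predictor(
    app_name=args.app_name, model_dir=args.checkpoint_dir,
    first_sequence=args.first_sequence,
    second_sequence=args.second_sequence,
    sequence_length=args.sequence_length,
    user_defined_parameters=user_defined_parameters)
# Read the input file in batches, run prediction, and write the results to args.outputs
predictor_manager = PredictorManager(
    predictor=predictor,
    input_file=args.tables.split(",")[-1],
    input_schema=args.input_schema,
    output_file=args.outputs,
    output_schema=args.output_schema,
    append_cols=args.append_cols,
    batch_size=args.micro_batch_size
)
predictor_manager.run()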
[2022-07-14 14:08:39,533 INFO] Using PyTorch .bin model to predict...
[2022-07-14 14:08:41,742 INFO] Loading model...
[2022-07-14 14:08:41,768 INFO] Load finished! Inited keys of the model: [backbone.encoder.layer.7.intermediate.dense.bias,backbone.encoder.layer.0.attention.output.LayerNorm.weight,classifier.bias,...]. All keys are initialized.
[2022-07-14 14:08:41,866 INFO] Using SimplePredict to predict...
145it [00:30, 4.68it/s]
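The input_schema "content:str:1,label:str:1" names the columns of each input line, first_sequence=content selects the text column, and append_cols=label copies the gold labels next to each prediction in dev.pred.csv. The sketch below illustrates that mapping. It assumes tab-separated columns and reads each schema field as name:type:length; parse_schema, row_to_dict and output_row are hypothetical helpers, not EasyNLP functions.

```python
def parse_schema(schema):
    """Split a schema string into (name, type, length) tuples.

    Assumption: each comma-separated field has the form name:type:length,
    as in "content:str:1,label:str:1".
    """
    fields = []
    for field in schema.split(","):
        name, dtype, length = field.split(":")
        fields.append((name, dtype, int(length)))
    return fields


def row_to_dict(line, schema, sep="\t"):
    """Map one input line to {column_name: value} (tab separator assumed)."""
    names = [name for name, _, _ in parse_schema(schema)]
    values = line.rstrip("\n").split(sep)
    if len(values) != len(names):
        raise ValueError(f"expected {len(names)} columns, got {len(values)}")
    return dict(zip(names, values))


def output_row(prediction, row, append_cols, sep="\t"):
    """Build one output line: the model output followed by the appended input columns."""
    extra = [row[col] for col in append_cols.split(",")] if append_cols else []
    return sep.join([prediction] + extra)


schema = "content:str:1,label:str:1"
row = row_to_dict("妻 贤 子 孝 家 庭 幸 福\tO O O O O O O O", schema)
print(row["content"])               # text column selected by first_sequence
print(output_row("[]", row, "label"))
```

The second printed line has the same shape as the third line of "Predicted results" shown below: an empty entity list followed by the appended gold labels.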
print('Labeled samples:')
! tail -n 5 dev.csv
print('Predicted results:')
! tail -n 5 dev.pred.csv
Labeled samples:
人 们 愿 意 与 他 做 生 意 , 有 时 商 业 事 务 通 过 电 话 即 可 办 理 。 O O O O O O O O O O O O O O O O O O O O O O O O O
经 过 十 几 年 的 努 力 , 他 已 成 为 世 界 最 大 的 私 人 集 装 箱 船 船 主 。 O O O O O O O O O O O O O O O O O O O O O O O O O O O
妻 贤 子 孝 家 庭 幸 福 O O O O O O O O
希 腊 人 将 瓦 西 里 斯 与 奥 纳 西 斯 比 较 时 总 不 忘 补 充 一 句 : 他 和 奥 纳 西 斯 不 同 , 他 没 有 改 组 家 庭 。 B-LOC I-LOC O O B-PER I-PER I-PER I-PER O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O
重 视 传 统 家 庭 观 念 的 希 腊 人 , 对 瓦 西 里 斯 幸 福 的 家 庭 充 满 赞 誉 。 O O O O O O O O O B-LOC I-LOC O O O B-PER I-PER I-PER I-PER O O O O O O O O O O
Predicted results:
[] O O O O O O O O O O O O O O O O O O O O O O O O O
[] O O O O O O O O O O O O O O O O O O O O O O O O O O O
[] O O O O O O O O
[{'word': '希', 'tag': 'LOC', 'start': 0, 'end': 1}, {'word': '腊', 'tag': 'LOC', 'start': 3, 'end': 4}, {'word': '瓦', 'tag': 'PER', 'start': 12, 'end': 13}, {'word': '奥', 'tag': 'PER', 'start': 27, 'end': 28}, {'word': '奥', 'tag': 'PER', 'start': 78, 'end': 79}] B-LOC I-LOC O O B-PER I-PER I-PER I-PER O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O
[{'word': '希', 'tag': 'LOC', 'start': 27, 'end': 28}] O O O O O O O O O B-LOC I-LOC O O O B-PER I-PER I-PER I-PER O O O O O O O O O O
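The output above shows the last five samples of the dataset and the trained model's predictions for them. In dev.pred.csv, the first column is the model output (a list of recognized entities, or [] when none is found), and the second column is the gold label sequence copied over by append_cols. The model correctly returns no entities for the first three samples and detects LOC and PER entities in the last two. However, the output column lists individual characters (for example 希 and 腊 as two separate LOC entries, and only 瓦 or 奥 for the PER names) rather than complete entities such as 希腊 and 瓦西里斯, and for the last sample it contains a single LOC entry and no PER entry. It is worth inspecting these outputs more closely before relying on the predicted entity strings.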
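In the gold labels, an entity is a B- tag followed by I- tags of the same type, so 希 腊 tagged B-LOC I-LOC is one LOC entity 希腊. A minimal sketch of merging BIO tags into whole entities is shown below. It assumes character-level tokens, and its start/end values are token indices (end exclusive), which may not match the offset convention used in EasyNLP's output column. bio_to_entities is a hypothetical helper, not an EasyNLP function.

```python
def bio_to_entities(tokens, tags):
    """Merge BIO tags into entity spans.

    Returns dicts with word, tag, start and end, where start/end are token
    indices (end exclusive). An I- tag that does not continue an entity of
    the same type is treated as the start of a new entity.
    """
    if len(tokens) != len(tags):
        raise ValueError("tokens and tags must have the same length")
    entities = []
    current = None  # (entity_type, start_index) of the entity being built
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel "O" flushes the last entity
        if tag.startswith("I-") and current is not None and current[0] == tag[2:]:
            continue  # the current entity continues
        if current is not None:
            etype, start = current
            entities.append({"word": "".join(tokens[start:i]), "tag": etype,
                             "start": start, "end": i})
            current = None
        if tag.startswith(("B-", "I-")):
            current = (tag[2:], i)
    return entities


# Opening of the fourth labeled sample above.
tokens = "希 腊 人 将 瓦 西 里 斯 与 奥 纳 西 斯".split()
tags = "B-LOC I-LOC O O B-PER I-PER I-PER I-PER O B-PER I-PER I-PER I-PER".split()
for entity in bio_to_entities(tokens, tags):
    print(entity)
```

Applied to the gold tags, this yields 希腊 (LOC), 瓦西里斯 (PER) and 奥纳西斯 (PER) as complete entities, which is the granularity one would expect from the model output as well.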
One-Step Execution
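Note that all of the training, evaluation, and prediction code above is already integrated into EasyNLP/examples/appzoo_tutorials/sequence_labeling/main.py, and we also provide several ready-to-run scripts. You can perform all of the training, evaluation, and prediction steps above in one step, either by running main.py with arguments or by executing the script run_user_defined_local.sh.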
One-step execution with main.py
Run main.py with the arguments below to train, evaluate, or predict directly. The arguments are the same as those explained above, so they are not described again here.
The model training command is as follows:
! python main.py \
    --mode train \
    --tables=./train.csv,./dev.csv \
    --input_schema=content:str:1,label:str:1 \
    --first_sequence=content \
    --label_name=label \
    --label_enumerate_values=B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O \
    --checkpoint_dir=./seq_labeling/ \
    --learning_rate=3e-5 \
    --epoch_num=1 \
    --save_checkpoint_steps=50 \
    --sequence_length=128 \
    --micro_batch_size=32 \
    --app_name=sequence_labeling \
    --user_defined_parameters='pretrain_model_name_or_path=hfl/chinese-roberta-wwm-ext'
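The command sets --learning_rate=3e-5, and the training log reports 313 global steps for one epoch. The learning rates printed in the log (0.00001745, 0.00001213, 0.00000681, 0.00000149) are consistent with a linear warmup followed by linear decay to zero, if one assumes a 10% warmup (int(0.1 * 313) = 31 steps) and evaluates the schedule at steps 149, 199, 249 and 299. The sketch below is only a consistency check under those assumptions; it does not claim that EasyNLP's trainer implements its scheduler exactly this way.

```python
def linear_warmup_decay(step, peak_lr, total_steps, warmup_steps):
    """Learning rate at `step`: linear warmup to peak_lr, then linear decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, remaining)


total = 313                  # "gsteps 313" in the training log
warmup = int(0.1 * total)    # assumed 10% warmup -> 31 steps
for step in (149, 199, 249, 299):
    print(step, f"{linear_warmup_decay(step, 3e-5, total, warmup):.8f}")
```

With these assumptions the four printed values match the logged learning rates to eight decimal places.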
The model evaluation command is as follows:
! python main.py \
    --mode=evaluate \
    --tables=dev.csv \
    --input_schema=content:str:1,label:str:1 \
    --first_sequence=content \
    --label_name=label \
    --label_enumerate_values=B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O \
    --checkpoint_dir=./seq_labeling/ \
    --sequence_length=128 \
    --micro_batch_size=32 \
    --app_name=sequence_labeling
The model prediction command is as follows:
! python main.py \
    --mode=predict \
    --tables=dev.csv \
    --outputs=dev.pred.csv \
    --input_schema=content:str:1,label:str:1 \
    --output_schema=output \
    --append_cols=label \
    --first_sequence=content \
    --checkpoint_dir=./seq_labeling/ \
    --micro_batch_size=32 \
    --sequence_length=128 \
    --app_name=sequence_labeling
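The three commands share most of their flags and differ only in a few mode-specific ones: training needs the training table, label set and optimizer settings; evaluation needs the label set; prediction needs the output file, output schema and appended columns. The sketch below collects the flags from the commands above into one place and builds the argument list per mode, using --checkpoint_dir in every mode as the training and evaluation commands do. build_command is a hypothetical helper, not part of EasyNLP.

```python
import shlex

# Flags used by all three modes, copied from the commands above.
SHARED = [
    "--input_schema=content:str:1,label:str:1",
    "--first_sequence=content",
    "--checkpoint_dir=./seq_labeling/",
    "--sequence_length=128",
    "--micro_batch_size=32",
    "--app_name=sequence_labeling",
]

LABELS = "--label_enumerate_values=B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O"

# Flags that only some modes need.
MODE_SPECIFIC = {
    "train": [
        "--tables=./train.csv,./dev.csv",
        "--label_name=label",
        LABELS,
        "--learning_rate=3e-5",
        "--epoch_num=1",
        "--save_checkpoint_steps=50",
        "--user_defined_parameters=pretrain_model_name_or_path=hfl/chinese-roberta-wwm-ext",
    ],
    "evaluate": ["--tables=dev.csv", "--label_name=label", LABELS],
    "predict": [
        "--tables=dev.csv",
        "--outputs=dev.pred.csv",
        "--output_schema=output",
        "--append_cols=label",
    ],
}


def build_command(mode):
    """Return the argv list for running main.py in the given mode."""
    if mode not in MODE_SPECIFIC:
        raise ValueError(f"unknown mode: {mode!r}")
    return ["python", "main.py", f"--mode={mode}"] + SHARED + MODE_SPECIFIC[mode]


print(shlex.join(build_command("predict")))
```

The resulting list can be passed to subprocess.run(build_command("train"), check=True) from the directory that contains main.py; keeping the shared flags in one list avoids the three commands drifting apart.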
Running from the command line with a bash script
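We have packaged several ready-to-run scripts in the EasyNLP/examples/appzoo_tutorials/sequence_labeling folder, so you can also train, evaluate, or predict in one step by running a script with arguments. The examples below use run_user_defined_local.sh. The script takes two arguments: the first is the index of the GPU to run on, usually 0; the second selects the mode: train, evaluate, or predict.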
Model training:
# Use the %cd magic: "! cd" runs in a subshell, so the directory change would not persist
%cd examples/appzoo_tutorials/sequence_labeling
! bash run_user_defined_local.sh 0 train
Model evaluation:
! bash run_user_defined_local.sh 0 evaluate
Model prediction:
! bash run_user_defined_local.sh 0 predict
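If you prefer to drive the script from Python, the sketch below validates the two positional arguments described above (a non-negative GPU index, then one of train, evaluate or predict) and runs the three modes in order. Passing cwd= to subprocess.run sets the working directory for each call explicitly, so nothing depends on an earlier cd. script_command and run_all are hypothetical helpers, and the default directory is assumed to be relative to the EasyNLP repository root.

```python
import subprocess

VALID_MODES = ("train", "evaluate", "predict")


def script_command(gpu_id, mode, script="run_user_defined_local.sh"):
    """Build the argv for the tutorial script: GPU index first, then the mode."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    gpu = int(gpu_id)
    if gpu < 0:
        raise ValueError("gpu_id must be a non-negative integer")
    return ["bash", script, str(gpu), mode]


def run_all(gpu_id=0, cwd="examples/appzoo_tutorials/sequence_labeling"):
    """Run train, evaluate and predict in order, stopping at the first failure."""
    for mode in VALID_MODES:
        subprocess.run(script_command(gpu_id, mode), cwd=cwd, check=True)
```

For example, run_all(0) is equivalent to the three bash commands above, and check=True raises CalledProcessError as soon as one stage fails, so prediction never runs on a model whose training did not complete.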