Direct Use
Open "Multi-scenario Text Generation with Pretrained Models (News Title Generation as an Example)" and click "Open in DSW" in the upper-right corner.
Multi-scenario Text Generation with Pretrained Models (News Title Generation as an Example)
Text generation aims to produce a target text span from a given text prompt. It has rich applications, including text summarization, news title generation, copywriting generation, question generation, essay generation, classical poetry generation, text correction, couplet writing, and more. The open-source EasyNLP library integrates state-of-the-art models fine-tuned for the scenarios above, so that users can easily continue training and run prediction on their own business data. Below we use news title generation (short summary generation) as an example to walk through the full text-generation pipeline with EasyNLP. The table below lists the models EasyNLP currently provides for each scenario, together with the demo datasets available for them; feel free to try them out. Note: the latest model information can be found in the model list.
Demo data URLs follow the format http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/<dataset name>.tsv (a short download sketch follows the table below).
| Task | Available models | Demo dataset |
| --- | --- | --- |
| Text generation | hfl/bart-generation-base-zh, hfl/bart-generation-large-zh | / |
| News title generation | alibaba-pai/mt5-title-generation-zh, alibaba-pai/randeng-essay-generation-base-zh, alibaba-pai/randeng-question-generation-base-zh | |
| Copywriting generation | alibaba-pai/randeng-advertise-generation-base-zh | |
| Question generation | alibaba-pai/randeng-question-generation-base-zh, alibaba-pai/bart-question-generation-large-zh | |
| Essay generation | alibaba-pai/randeng-essay-generation-base-zh, alibaba-pai/glm-essay-generation-large-zh | |
| Classical poetry generation | alibaba-pai/bart-poem-generation-large-zh, alibaba-pai/randeng-poem-generation-base-zh | |
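For example, the demo file cn_train.tsv used later in this tutorial can be fetched by substituting its name into the URL template above; a minimal standard-library sketch:

```
# Minimal sketch: build a demo-data URL from the template above and download it.
# cn_train.tsv is the news-title demo file used later in this tutorial.
import urllib.request

BASE = "http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/"
name = "cn_train.tsv"  # substitute any demo dataset name here
urllib.request.urlretrieve(BASE + name, name)
```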
Text summarization aims to extract, distill, or summarize the key information in long, redundant text sequences. T5 is a sequence-to-sequence pretrained model proposed by Google; it unifies different generation tasks under one framework and achieves state-of-the-art performance in text generation while remaining highly transferable. mT5 is the multilingual version of T5, pretrained on a corpus covering 101 languages.
EasyNLP provides a fine-tuned mT5 so that users can benefit from the model's strong modeling capability; it was obtained by fine-tuning mT5 on news data. Taking news title generation as the example task, this tutorial builds a title-generation model on top of mT5 and shows how to use EasyNLP for model construction, training, evaluation, and prediction.
Runtime Requirements
PAI-PyTorch 1.7/1.8 image, P100 or V100 GPU, 32 GB memory
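As a quick sanity check of the environment (an illustrative sketch, not part of the original tutorial), you can verify the PyTorch version and GPU availability:

```
# Illustrative environment check: confirm the PyTorch version and GPU.
import torch

print(torch.__version__)          # expect 1.7.x or 1.8.x on the PAI-PyTorch image
print(torch.cuda.is_available())  # expect True on a P100/V100 machine
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```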
Installing EasyNLP
We recommend downloading the EasyNLP source code from GitHub and installing it as follows:
```
! echo y | pip uninstall pai-easynlp easynlp
! git clone https://github.com/alibaba/EasyNLP.git
! pip install -r EasyNLP/requirements.txt -i http://mirrors.aliyun.com/pypi/simple/
! cd EasyNLP && python setup.py install
```
After installing easynlp, we recommend restarting the notebook so that no stale modules remain cached in the environment.
You can verify the installation with the following commands:
```
import easynlp
easynlp.__file__
```
If the import succeeds (or if the easynlp CLI tool is already available on your system), the EasyNLP codebase is installed.
Data Preparation
First, download the training and validation sets used in this example:
```
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/cn_train.tsv
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/cn_dev.tsv
```
```
--2022-11-01 10:26:29--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/cn_train.tsv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 351006 (343K) [text/tab-separated-values]
Saving to: ‘cn_train.tsv’

cn_train.tsv        100%[===================>] 342.78K  2.02MB/s    in 0.2s

2022-11-01 10:26:29 (2.02 MB/s) - ‘cn_train.tsv’ saved [351006/351006]

--2022-11-01 10:26:29--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/cn_dev.tsv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 175713 (172K) [text/tab-separated-values]
Saving to: ‘cn_dev.tsv’

cn_dev.tsv          100%[===================>] 171.59K  --.-KB/s    in 0.07s

2022-11-01 10:26:30 (2.30 MB/s) - ‘cn_dev.tsv’ saved [175713/175713]
```
After the download completes, you can view the first record with the code below. In both the training and validation sets, each line is one news item, consisting of the news title and the news content separated by a tab (\t).
```
print('Training data sample:')
! head -n 1 cn_train.tsv
print('Development set data sample:')
! head -n 1 cn_dev.tsv
```
```
Training data sample:
远离烟草!不止二手烟,还有三手烟! 在广州第一人民医院,一个上午6名患者做支气管镜检查,5人查出肺癌,且4人是老烟民!专家称,吸烟和被动吸烟是肺癌的主要元凶,而二手烟、三手烟(即吸烟后滞留在室内或衣服、头发等的微粒和气体)与吸烟危害性一样大!
Development set data sample:
#全国低碳城市十强#排行榜合肥居榜首 中科院上海高等研究院日前发布《中国低碳城市建设报告》,合肥、广州、南京排全国低碳城市十强榜前三。城市评价的五个领域:经济社会特征、基础设施建设特征、城市能源消耗特征、城市交通运输和城市环境影响特征。同意的点赞,质疑的转发
```
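Since each line is simply the title and the content separated by a tab, the file can also be parsed with a few lines of plain Python (a minimal sketch, independent of EasyNLP):

```
# Minimal sketch: parse the news TSV, where each line is "title<TAB>content".
with open("cn_train.tsv", encoding="utf-8") as f:
    for line in f:
        title, content = line.rstrip("\n").split("\t", 1)
        print("title:", title)
        print("content length:", len(content), "characters")
        break  # show only the first record
```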
Initialization
In a Python 3.6 environment, we first import the libraries the model needs from the newly installed EasyNLP and perform some initialization. In this tutorial, we use mt5-title-generation-zh as the pretrained model base.
```
# To avoid conflicts between EasyNLP's args and Jupyter's own arguments,
# sys.argv must be set manually here; otherwise initialization fails.
# Skip this cell when running the code from the command line or a .py file.
import sys

sys.argv = ['main.py']
```
```
import sys
import os

import torch.cuda

sys.path.append('./')

from easynlp.core import Trainer, PredictorManager
from easynlp.appzoo.sequence_generation.data import SequenceGenerationDataset
from easynlp.appzoo.sequence_generation.model import SequenceGeneration
from easynlp.appzoo.sequence_generation.evaluator import SequenceGenerationEvaluator
from easynlp.appzoo.sequence_generation.predictor import SequenceGenerationPredictor
from easynlp.appzoo import get_application_model_for_evaluation
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parameters

initialize_easynlp()
args = get_args()

user_defined_parameters = ('pretrain_model_name_or_path=alibaba-pai/mt5-title-generation-zh '
                           'copy=false max_encoder_length=512 min_decoder_length=12 '
                           'max_decoder_length=32 no_repeat_ngram_size=2 num_beams=5 '
                           'num_return_sequences=5')
user_defined_parameters = parse_user_defined_parameters(user_defined_parameters)
args.checkpoint_dir = "./finetuned_zh_model"
```
[2022-12-08 17:49:44,415.415 dsw34730-bbd74fb77-5264h:173618 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
[2022-12-08 17:49:44,781] [WARNING] [partition_parameters.py:61:<module>] unable to find torch.distributed._all_gather_base. will fall back to torch.distributed.all_gather which will result in suboptimal performance. please consider upgrading your pytorch installation.
```
/home/pai/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:523: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  ... (the same FutureWarning repeats for the quint8, qint16, quint16, qint32, and resource dtypes) ...
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
  from cryptography import x509
Please ignore the following import error if you are using tunnel table io.
No module named '_common_io'
```
```
No module named 'easy_predict'
The following parameters are not recognized: []
------------------------ arguments ------------------------
  ... (the full list of several hundred default arguments — optimizer settings,
  batch sizes, sequence lengths, decoding options, and so on — is printed here;
  omitted for brevity) ...
-------------------- end of arguments ---------------------
> initializing torch distributed ...
```
[2022-12-08 17:49:47,752.752 dsw34730-bbd74fb77-5264h:173618 INFO distributed_c10d.py:195] Added key: store_based_barrier_key:1 to store for rank: 0
Init dist done. World size: 1, rank 0, l_rank 0
> setting random seeds to 1234 ...
Note: if the code above fails with an "Address already in use" error, run the following commands to kill whatever is running on the port (6000 by default):
```
netstat -tunlp | grep 6000
kill -9 PID   # replace PID with the process ID shown by the previous command
```
Loading the Data
We use EasyNLP's built-in SequenceGenerationDataset to load the training and validation data. Its main parameters are:
- pretrained_model_name_or_path: the pretrained model name or path; here we use the helper function get_pretrain_model_path to resolve the model name "mt5-title-generation-zh" and download the model automatically
- max_seq_length: maximum text length; longer sequences are truncated and shorter ones are padded
- input_schema: the format of the input data, given as comma-separated items, one per tab-separated column in each line of the data file; each item starts with its field name, e.g. label, sent1 (see the sketch after this list)
- first_sequence, label_name: specify which fields in input_schema serve as the input sentence and the label column
- label_enumerate_values: enumeration of the label values
- is_training: whether the dataset is used for training; True for train_dataset, False for valid_dataset
- app_name: the task to run, e.g. text classification, sequence labeling, text matching, or text generation
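The following toy sketch (illustrative only, not EasyNLP internals) shows how an input_schema string of the form name:type:length maps the tab-separated columns of one line to named fields:

```
# Illustrative only: map tab-separated columns to the field names declared
# in an input_schema string of the form "name:type:length".
input_schema = "title_tokens:str:1,content_tokens:str:1"
field_names = [spec.split(":")[0] for spec in input_schema.split(",")]

line = "某新闻标题\t某新闻正文"   # one record: title<TAB>content
record = dict(zip(field_names, line.split("\t")))
print(record)  # {'title_tokens': '某新闻标题', 'content_tokens': '某新闻正文'}
```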
Below we set the parameters manually for this experiment.
```
args.tables = "./cn_train.tsv,./cn_dev.tsv"
args.input_schema = "title_tokens:str:1,content_tokens:str:1"
args.first_sequence = "content_tokens"
args.second_sequence = "title_tokens"
args.label_name = "title_tokens"
args.learning_rate = 3e-5
args.epoch_num = 1
args.save_checkpoint_steps = 150
args.sequence_length = 512
args.micro_batch_size = 8
args.export_tf_checkpoint_type = "none"
args.app_name = "sequence_generation"

args.pretrained_model_name_or_path = user_defined_parameters.get('pretrain_model_name_or_path', None)
args.pretrained_model_name_or_path = get_pretrain_model_path(args.pretrained_model_name_or_path)

train_dataset = SequenceGenerationDataset(
    pretrained_model_name_or_path=args.pretrained_model_name_or_path,
    data_file=args.tables.split(",")[0],
    max_seq_length=args.sequence_length,
    input_schema=args.input_schema,
    first_sequence=args.first_sequence,
    second_sequence=args.second_sequence,
    user_defined_parameters=user_defined_parameters,
    is_training=True)

valid_dataset = SequenceGenerationDataset(
    pretrained_model_name_or_path=args.pretrained_model_name_or_path,
    data_file=args.tables.split(",")[-1],
    max_seq_length=args.sequence_length,
    input_schema=args.input_schema,
    first_sequence=args.first_sequence,
    second_sequence=args.second_sequence,
    user_defined_parameters=user_defined_parameters,
    is_training=False)
```
`/root/.easynlp/modelzoo/alibaba-pai/mt5-title-generation-zh.tgz` already exists
****./cn_train.tsv
****./cn_dev.tsv
Model Training
With the data processed and loaded, we can start training. We use EasyNLP's SequenceGeneration wrapper to build the model for training. Its parameters are:
- pretrained_model_name_or_path: the pretrained model name or path; as above, we use the helper function get_pretrain_model_path to resolve the model name "mt5-title-generation-zh" and download the model automatically
- user_defined_parameters: user-defined parameters; pass in the user_defined_parameters dictionary we just parsed
Build and load the model
```
model = SequenceGeneration(pretrained_model_name_or_path=args.pretrained_model_name_or_path,
                           user_defined_parameters=user_defined_parameters)
```
**language** parameter is not provided in user defined parameters, using zh as default.
Loaded weights of the model: [shared.weight, encoder.embed_tokens.weight, encoder.block.0.layer.0.SelfAttention.q.weight, ... (the full list of encoder/decoder block weights is omitted here for brevity) ..., decoder.final_layer_norm.weight, lm_head.weight]. All weights are initialized.
Build the trainer and train
```
extra_para = {'pretrained_model_name_or_path': args.pretrained_model_name_or_path}
evaluator = SequenceGenerationEvaluator(valid_dataset=valid_dataset,
                                        user_defined_parameters=user_defined_parameters,
                                        **extra_para)
trainer = Trainer(model=model,
                  train_dataset=train_dataset,
                  user_defined_parameters=user_defined_parameters,
                  evaluator=evaluator)
trainer.train()
```
```
[2022-12-08 17:49:54,705 INFO] ========== Initializing Tensorboard ==========
[2022-12-08 17:49:54,729 INFO] ========== Training Start ==========
[2022-12-08 17:49:54,731 INFO]   Num of GPUs (all)       = 1
[2022-12-08 17:49:54,732 INFO]   Num of CPUs per worker  = 1
[2022-12-08 17:49:54,732 INFO]   Num dataset examples    = 1000
[2022-12-08 17:49:54,733 INFO]   Num training examples   = 1000
[2022-12-08 17:49:54,733 INFO]   Num validation examples = 500
[2022-12-08 17:49:54,734 INFO]   Train. batch size       = 8
[2022-12-08 17:49:54,734 INFO]   Train. micro batch size = 8
[2022-12-08 17:49:54,734 INFO]   Train. batch no.        = 125
[2022-12-08 17:49:54,736 INFO]   Evaluation batch size   = 8
[2022-12-08 17:49:54,737 INFO]   Total training steps    = 125
[2022-12-08 17:49:54,737 INFO]   Sequence length         = 512
[2022-12-08 17:49:54,738 INFO]   Saving steps            = 150
[2022-12-08 17:49:54,739 INFO]   Distributed_backend     = nccl
[2022-12-08 17:49:54,739 INFO]   Worker Count            = 1
[2022-12-08 17:49:54,739 INFO]   Worker CPU              = -1
[2022-12-08 17:49:54,740 INFO]   Worker data threads     = 10
[2022-12-08 17:49:54,743 INFO]   num model params        = 275,029,248
[2022-12-08 17:49:54,744 INFO]   num trainable params    = 275,029,248
[2022-12-08 17:49:54,744 INFO]
[2022-12-08 17:49:54,744 INFO] ========== Model Config ==========
[2022-12-08 17:49:54,745 INFO] {
  "architectures": ["T5ForConditionalGeneration"],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "easynlp_version": "0.0.3",
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "mt5",
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "tokenizer_class": "T5Tokenizer",
  "use_cache": true,
  "vocab_size": 50000
}
```
optimizer type: AdamW
```
/home/pai/lib/python3.6/site-packages/pai_easynlp-0.0.9-py3.6.egg/easynlp/core/optimizers.py:441: UserWarning: This overload of add_ is deprecated:
    add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
    add_(Tensor other, *, Number alpha) (Triggered internally at /workspace/artifacts/paipytorch1.8/dist/ubuntu18.04-py3.6-cuda10.1/build/src/torch/csrc/utils/python_arg_parser.cpp:1005.)
  exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
/home/pai/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:247: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  warnings.warn("To get the last learning rate computed by the scheduler, "
[2022-12-08 17:50:14,705 INFO] Epoch [ 1/ 1], step [100/125], lr 0.000007, 19.96 s
[2022-12-08 17:50:14,707 INFO] loss : 3.6508
```
Training Time: 25.2710120677948, rank 0, gsteps 125
100%|██████████| 500/500 [03:19<00:00, 2.51it/s]
[2022-12-08 17:53:39,253 INFO] Saving best model to ./finetuned_zh_model/pytorch_model.bin...
Rouge 1/2/L: 35.91/21.85/32.54
[2022-12-08 17:53:57,904 INFO] Best score: 32.535939404637126
[2022-12-08 17:53:57,907 INFO] Training Time: 243.49750781059265
Model Evaluation
After training finishes, the model is saved under the checkpoint_dir specified at the beginning, i.e. the local path "./finetuned_zh_model/". We can now evaluate the trained model: we use EasyNLP's SequenceGenerationEvaluator to initialize an evaluator, move the model to the GPU, and run the evaluation.
```
args.tables = "cn_dev.tsv"
extra_para = {'pretrained_model_name_or_path': args.pretrained_model_name_or_path}
evaluator = SequenceGenerationEvaluator(valid_dataset=valid_dataset,
                                        user_defined_parameters=user_defined_parameters,
                                        **extra_para)
if args.n_gpu > 0:
    model.to(torch.cuda.current_device())
else:
    model.to("cpu")
evaluator.evaluate(model=model)
```
/root/.easynlp/modelzoo/alibaba-pai/mt5-title-generation-zh
100%|██████████| 500/500 [02:50<00:00, 2.93it/s]
Rouge 1/2/L: 33.51/18.77/30.09
[('rouge-l', 30.08704702646614), ('rouge-1', 33.50990449925101), ('rouge-2', 18.76897464541499)]
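EasyNLP computes these ROUGE scores internally. For an independent spot check, the open-source rouge package (pip install rouge) can score individual pairs; for Chinese text, space-separating characters so the word-level scorer operates on characters is a common workaround, assumed here rather than EasyNLP's exact evaluation recipe:

```
# Independent spot check of ROUGE, using the open-source `rouge` package.
# Space-separating characters is an assumption for Chinese scoring, not
# necessarily EasyNLP's exact evaluation recipe.
from rouge import Rouge

def char_tokenize(text):
    return " ".join(text)

hyp = char_tokenize("嘉定一小学门口发生7车连撞事故1人受伤")   # a model prediction
ref = char_tokenize("上海嘉定实验小学门口今晨7车连撞1人受伤")  # the reference title
print(Rouge().get_scores([hyp], [ref], avg=True))
```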
Model Prediction
The trained model can likewise be used to generate news titles. We first create a predictor and instantiate a PredictorManager from it, directing the prediction results to cn.preds.txt.
```
args.tables = "cn_dev.tsv"
args.outputs = "cn.preds.txt"
args.input_schema = "title_tokens:str:1,content_tokens:str:1"
args.output_schema = "predictions,beams"
args.append_cols = "title_tokens,content_tokens"
args.micro_batch_size = 32

extra_para = {'pretrained_model_name_or_path': args.pretrained_model_name_or_path,
              'max_encoder_length': args.sequence_length}
predictor = SequenceGenerationPredictor(model_dir=args.checkpoint_dir,
                                        model_cls=SequenceGeneration,
                                        first_sequence=args.first_sequence,
                                        user_defined_parameters=user_defined_parameters,
                                        **extra_para)
predictor_manager = PredictorManager(
    predictor=predictor,
    input_file=args.tables.split(",")[0],
    input_schema=args.input_schema,
    output_file=args.outputs,
    output_schema=args.output_schema,
    append_cols=args.append_cols,
    batch_size=args.micro_batch_size
)
predictor_manager.run()
```
**language** parameter is not provided in user defined parameters, using zh as default.
Loaded weights of the model: [shared.weight, encoder.embed_tokens.weight, ... (identical to the weight-loading output shown in the training section above) ..., lm_head.weight]. All weights are initialized.
[2022-12-08 17:54:07,569 INFO] Using SimplePredict to predict...
16it [04:16, 16.06s/it]
```
print('Labeled samples:')
! tail -n 3 cn_dev.tsv
print('Predicted results:')
! tail -n 3 cn.preds.txt
```
```
Labeled samples:
上海嘉定实验小学门口今晨7车连撞1人受伤 今晨7时45分左右,位于平城路上的嘉定区实验小学门口发生7车连撞事故。事发时正值学生上学高峰,路上车流量大,车速慢。据悉,先是两车相撞,随后引发7车连环撞。事故造成1人受伤。警方正对涉事驾驶员进行验血调查。
王中军:投资就像追女孩并购银汉幕后军师是腾讯 “从投资到现在为止没有像我想的那么顺利在接受专访时,华谊兄弟董事长王中军还打趣说,投资谈判遇到阻碍很正常,“就像我要追一个女孩,不是说我今年把她追到手就真能追到手,人家可能三年后才同意,完全有这个可能。
营销不宜过度“快消化” 对于快消品来说,其在选择数字媒体平台时需要考虑:1、确保能够接触到主流受众的规模;2、有效地提高受众接触度并运用整合性的媒体投放;3、迎合目标受众的需求。如果营销过度“快消化”,则不利于长品牌的长期价值构建和与消费者之间的沟通。
Predicted results:
嘉定一小学门口发生7车连撞事故1人受伤 嘉定一小学门口发生7车连撞事故1人受伤||嘉定一小学门口发生7车连撞事故1人受伤||嘉定实验小学门口发生7车连撞事故1人受伤||嘉定实验小学7车连撞事故1人受伤||嘉定实验小学门口发生7车连撞事故 上海嘉定实验小学门口今晨7车连撞1人受伤 今晨7时45分左右,位于平城路上的嘉定区实验小学门口发生7车连撞事故。事发时正值学生上学高峰,路上车流量大,车速慢。据悉,先是两车相撞,随后引发7车连环撞。事故造成1人受伤。警方正对涉事驾驶员进行验血调查。
王中军:从投资到现在为止没有像我想的那么顺利 王中军:从投资到现在为止没有像我想的那么顺利||华谊董事长:从投资到现在为止没有像我想的那么顺利||华谊董事长王中军:从投资到现在为止没有像想的那么顺利||“从投资到现在为止没有像我想的那么顺利”||“从投资到现在为止没有像我想的那么顺利 王中军:投资就像追女孩并购银汉幕后军师是腾讯 “从投资到现在为止没有像我想的那么顺利在接受专访时,华谊兄弟董事长王中军还打趣说,投资谈判遇到阻碍很正常,“就像我要追一个女孩,不是说我今年把她追到手就真能追到手,人家可能三年后才同意,完全有这个可能。
选择数字媒体平台需考虑 选择数字媒体平台需考虑||选择数字媒体平台需谨慎||快消品选择数字媒体平台||快消品选择数字媒体需考虑||快消品选择数字媒体平台需考虑 营销不宜过度“快消化” 对于快消品来说,其在选择数字媒体平台时需要考虑:1、确保能够接触到主流受众的规模;2、有效地提高受众接触度并运用整合性的媒体投放;3、迎合目标受众的需求。如果营销过度“快消化”,则不利于长品牌的长期价值构建和与消费者之间的沟通。
```
The output above shows the last three records of the dataset together with the trained model's predictions. The first column is the predicted title; the second column contains the five beam-search results, separated by ||. Each output line also appends the original news title and news body, with columns separated by \t.
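Given the output_schema and append_cols used above, each line of cn.preds.txt can be split back into its four columns with a few lines of Python (a minimal sketch):

```
# Minimal sketch: parse cn.preds.txt. Per output_schema/append_cols above, each
# line holds: prediction, "||"-joined beams, original title, original content.
with open("cn.preds.txt", encoding="utf-8") as f:
    for line in f:
        prediction, beams, title, content = line.rstrip("\n").split("\t")
        print("prediction:", prediction)
        print("beam candidates:", beams.split("||"))
        print("reference title:", title)
        break  # show only the first record
```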
One-step Execution
Notably, all of the training/evaluation/prediction code above is already integrated in EasyNLP/examples/appzoo_tutorials/sequence_generation/main.py, and we also provide ready-to-run scripts. You can run main.py with command-line arguments, or execute the script run_user_defined_local_zh.sh directly, to perform all of the above training/evaluation/prediction steps in one go.
One-step execution with main.py
Run main.py with the arguments below to train, evaluate, or predict directly. The parameters are explained above and not repeated here.
Training command:
```
! python EasyNLP/examples/appzoo_tutorials/sequence_generation/main.py \
    --mode train \
    --app_name=sequence_generation \
    --worker_gpu=1 \
    --tables=./cn_train.tsv,./cn_dev.tsv \
    --input_schema=target:str:1,source:str:1 \
    --first_sequence=source \
    --second_sequence=target \
    --label_name=target \
    --checkpoint_dir=./finetuned_zh_model \
    --micro_batch_size=8 \
    --sequence_length=512 \
    --epoch_num=1 \
    --save_checkpoint_steps=150 \
    --export_tf_checkpoint_type none \
    --user_defined_parameters 'pretrain_model_name_or_path=hfl/randeng-summary-generation-base-zh language=zh copy=false max_encoder_length=512 min_decoder_length=12 max_decoder_length=32 no_repeat_ngram_size=2 num_beams=5 num_return_sequences=5'
```
Evaluation command:
```
! python EasyNLP/examples/appzoo_tutorials/sequence_generation/main.py \
    --mode=evaluate \
    --app_name=sequence_generation \
    --worker_gpu=1 \
    --tables=./cn_dev.tsv \
    --input_schema=target:str:1,source:str:1 \
    --first_sequence=source \
    --second_sequence=target \
    --label_name=target \
    --checkpoint_dir=./finetuned_zh_model \
    --micro_batch_size=8 \
    --sequence_length=512 \
    --epoch_num=1 \
    --save_checkpoint_steps=150 \
    --export_tf_checkpoint_type none \
    --user_defined_parameters 'language=zh copy=false max_encoder_length=512 min_decoder_length=12 max_decoder_length=32 no_repeat_ngram_size=2 num_beams=5 num_return_sequences=5'
```
Prediction command:
```
! python EasyNLP/examples/appzoo_tutorials/sequence_generation/main.py \
    --mode=predict \
    --app_name=sequence_generation \
    --worker_gpu=1 \
    --tables=./cn_dev.tsv \
    --outputs=./cn.preds.txt \
    --input_schema=target:str:1,source:str:1 \
    --output_schema=predictions,beams \
    --append_cols=target,source \
    --first_sequence=source \
    --checkpoint_dir=./finetuned_zh_model \
    --micro_batch_size=32 \
    --sequence_length=512 \
    --user_defined_parameters 'language=zh copy=false max_encoder_length=512 min_decoder_length=12 max_decoder_length=32 no_repeat_ngram_size=2 num_beams=5 num_return_sequences=5'
```
Running via bash script
We provide several ready-to-run scripts under EasyNLP/examples/appzoo_tutorials/sequence_generation; you can likewise complete training/evaluation/prediction in one step by running a script with arguments. Take run_user_defined_local_zh.sh as an example: it takes two arguments, the first being the index of the GPU to run on (usually 0), and the second selecting the mode (train/evaluate/predict).
Note: the training, evaluation, and prediction parameters are defined inside run_user_defined_local_zh.sh; to use your own data and settings, edit the parameters in that file.
Training:
```
! cd EasyNLP/examples/appzoo_tutorials/sequence_generation && bash run_user_defined_local_zh.sh 0 train
```
Evaluation:
```
! cd EasyNLP/examples/appzoo_tutorials/sequence_generation && bash run_user_defined_local_zh.sh 0 evaluate
```
Prediction:
```
! cd EasyNLP/examples/appzoo_tutorials/sequence_generation && bash run_user_defined_local_zh.sh 0 predict
```