【DSW Gallery】基于预训练模型的多场景文本生成(以新闻标题生成为例)

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: EasyNLP提供多种模型的训练及预测功能,旨在帮助自然语言开发者方便快捷地构建模型并应用于生产。本文以中文文本生成为例,为您介绍如何在PAI-DSW中使用EasyNLP。

直接使用

请打开基于预训练模型的多场景文本生成(以新闻标题生成为例),并点击右上角 “ 在DSW中打开” 。

image.png

基于预训练模型的多场景文本生成(以新闻标题生成为例)

文本生成的目标是基于给定文本指引,由模型生成对应的文本段,具有丰富的应用场景,包括文本摘要、新闻标题生成、文案生成、问题生成、作文生成、古诗生成、文本纠错、写对联等。在开源代码库EasyNLP中,我们集成了目前先进模型在上述应用场景中微调过的模型,方便用户进行具体业务场景的进一步训练和预测。以下将以新闻标题生成(生成短摘要)为例,展示如何利用EasyNLP进行文本生成相关任务的全链路过程。同时,下表展示了目前EasyNLP提供的各场景模型以及对应场景中可以使用的demo数据集,欢迎进行尝试。注:最新的模型信息可见列表

数据地址格式为:http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/替换为数据名称.tsv

任务

可选模型

demo数据集

文本生成

hfl/bart-generation-base-zh,hfl/bart-generation-large-zh

/

新闻标题生成

alibaba-pai/mt5-title-generation-zh,alibaba-pai/randeng-essay-generation-base-zh,alibaba-pai/randeng-question-generation-base-zh

cn_train/cn_dev

文案生成

alibaba-pai/randeng-advertise-generation-base-zh

advertise_train/advertise_dev

问题生成

alibaba-pai/randeng-question-generation-base-zh,alibaba-pai/bart-question-generation-large-zh

question_train/question_dev

作文生成

alibaba-pai/randeng-essay-generation-base-zh,alibaba-pai/glm-essay-generation-large-zh

essay_train/essay_dev

古诗生成

alibaba-pai/bart-poem-generation-large-zh,alibaba-pai/randeng-poem-generation-base-zh

poem_train/poem_dev

文本摘要(Text Summarization)旨在从冗长、重复的文本序列中抽取、精炼或总结出其中的要点信息。T5是由谷歌提出的一个序列到序列预训练模型,它将不同的生成任务进行统一,在兼顾迁移性的前提下取得了文本生成领域的最佳性能。mT5是T5的多语言版本,该模型利用包含101种语言的语料训练得到多语言预训练模型。

EasyNLP中,我们提供了经过训练的mT5,以便用户能够受益于模型强大的建模能力。该模型是在mT5的基础上利用新闻数据进行微调得到。本文将以新闻标题生成任务为例,将mT5作为模型底座构建标题生成模型,展示如何利用EasyNLP进行模型构建、训练、评估、预测。

运行环境要求

PAI-Pytorch 1.7/1.8镜像, GPU机型 P100 or V100, 内存32G

EasyNLP安装

建议从GitHub下载EasyNLP源代码进行安装,命令如下:

! echo y | pip uninstall pai-easynlp easynlp
! git clone https://github.com/alibaba/EasyNLP.git
! pip install -r EasyNLP/requirements.txt -i http://mirrors.aliyun.com/pypi/simple/
! cd EasyNLP && python setup.py install

安装完成easynlp之后,建议重启notebook,防止环境存在缓存,未更新

您可以使用如下命令验证是否安装成功:

import easynlp
easynlp.__file__

如果您系统内已经安装完easynlp的CLI工具,则说明EasyNLP代码库已经安装。

数据准备

首先,您需要下载用于本示例的训练和测试集,并创建保存模型的文件夹,命令如下:

! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/cn_train.tsv
! wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/cn_dev.tsv
--2022-11-01 10:26:29--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/cn_train.tsv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 351006 (343K) [text/tab-separated-values]
Saving to: ‘cn_train.tsv’
cn_train.tsv        100%[===================>] 342.78K  2.02MB/s    in 0.2s    
2022-11-01 10:26:29 (2.02 MB/s) - ‘cn_train.tsv’ saved [351006/351006]
--2022-11-01 10:26:29--  http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/generation/cn_dev.tsv
Resolving atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)... 47.101.88.27
Connecting to atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com (atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com)|47.101.88.27|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 175713 (172K) [text/tab-separated-values]
Saving to: ‘cn_dev.tsv’
cn_dev.tsv          100%[===================>] 171.59K  --.-KB/s    in 0.07s   
2022-11-01 10:26:30 (2.30 MB/s) - ‘cn_dev.tsv’ saved [175713/175713]

数据下载完成后,可以通过以下代码查看第一条数据。在训练集验证集中,每一行为一条新闻数据,包括新闻标题和新闻内容,两者通过制表符(\t)隔开。

print('Training data sample:')
! head -n 1 cn_train.tsv
print('Development set data sample:')
! head -n 1 cn_dev.tsv
Training data sample:
远离烟草!不止二手烟,还有三手烟! 在广州第一人民医院,一个上午6名患者做支气管镜检查,5人查出肺癌,且4人是老烟民!专家称,吸烟和被动吸烟是肺癌的主要元凶,而二手烟、三手烟(即吸烟后滞留在室内或衣服、头发等的微粒和气体)与吸烟危害性一样大!
Development set data sample:
#全国低碳城市十强#排行榜合肥居榜首  中科院上海高等研究院日前发布《中国低碳城市建设报告》,合肥、广州、南京排全国低碳城市十强榜前三。城市评价的五个领域:经济社会特征、基础设施建设特征、城市能源消耗特征、城市交通运输和城市环境影响特征。同意的点赞,质疑的转发

初始化

在Python 3.6环境下,我们首先从刚刚安装好的EasyNLP中引入模型运行需要的各种库,并做一些初始化。在本教程中,我们使用mt5-title-generation-zh作为预训练模型底座。

# 为了避免EasyNLP中的args与Jupyter系统的冲突,需要手动设置,否则无法进行初始化。
# 在命令行或py文件中运行文中代码则可忽略下述代码。
import sys
sys.argv = ['main.py']
import imp
import sys
import os
import torch.cuda
sys.path.append('./')
from easynlp.core import Trainer
from easynlp.appzoo.sequence_generation.data import SequenceGenerationDataset
from easynlp.appzoo.sequence_generation.model import SequenceGeneration
from easynlp.appzoo.sequence_generation.evaluator import SequenceGenerationEvaluator
from easynlp.appzoo.sequence_generation.predictor import SequenceGenerationPredictor
from easynlp.appzoo import get_application_model_for_evaluation
from easynlp.utils import initialize_easynlp, get_args
from easynlp.utils.global_vars import parse_user_defined_parameters
from easynlp.core import PredictorManager
from easynlp.utils import get_pretrain_model_path
initialize_easynlp()
args = get_args()
user_defined_parameters='pretrain_model_name_or_path=hfl/randeng-summary-generation-base-zh copy=false max_encoder_length=512 min_decoder_length=12 max_decoder_length=32 no_repeat_ngram_size=2 num_beams=5 num_return_sequences=5'
user_defined_parameters = parse_user_defined_parameters(user_defined_parameters)
args.checkpoint_dir = "./finetuned_zh_model"
[2022-12-08 17:49:44,415.415 dsw34730-bbd74fb77-5264h:173618 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
[2022-12-08 17:49:44,781] [WARNING] [partition_parameters.py:61:<module>] unable to find torch.distributed._all_gather_base. will fall back to torch.distributed.all_gather which will result in suboptimal performance. please consider upgrading your pytorch installation.
/home/pai/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:523: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/pai/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:524: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/pai/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/pai/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/pai/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/pai/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:532: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
  from cryptography import x509
Please ignore the following import error if you are using tunnel table io.
No module named '_common_io'
No module named 'easy_predict'
The following parameters are not recognized: []
------------------------ arguments ------------------------
  adam_beta1 ...................................... 0.9
  adam_beta2 ...................................... 0.999
  adam_eps ........................................ 1e-08
  adapet .......................................... False
  app_name ........................................ text_classify
  append_cols ..................................... None
  attention_dropout ............................... 0.1
  attention_scale ................................. 1.0
  avg_block_length ................................ 3
  batch_size ...................................... 4
  bert_prob ....................................... 0.5
  blank_maskratio ................................. 0.1
  block_lm ........................................ False
  block_lm_ratio .................................. 0.0
  block_mask_prob ................................. 0.0
  buckets ......................................... None
  cache_dir ....................................... None
  checkpoint_activations .......................... False
  checkpoint_dir .................................. None
  checkpoint_num_layers ........................... 1
  chief_hosts ..................................... 
  clip_grad ....................................... 1.0
  cloze_eval ...................................... False
  context_mask_ratio .............................. 0.0
  continuous_prompt ............................... False
  cpu_optimizer ................................... False
  cpu_torch_adam .................................. False
  data_dir ........................................ ./
  data_threads .................................... 10
  DDP_impl ........................................ torch
  deep_init ....................................... False
  deepspeed_activation_checkpointing .............. False
  delim ........................................... ,
  distributed_backend ............................. nccl
  do_lower_case ................................... False
  encoder_decoder ................................. False
  epoch_num ....................................... 3
  epochs .......................................... None
  eval_batch_size ................................. None
  eval_epoch ...................................... 1
  eval_interval ................................... 1000
  eval_iters ...................................... 100
  eval_max_preds_per_seq .......................... None
  eval_seq_length ................................. None
  eval_text_key ................................... None
  eval_valid ...................................... False
  experiment_name ................................. glm-356M
  export_tf_checkpoint_type ....................... easytransfer
  fast_decode ..................................... False
  few_superglue ................................... False
  filter_english .................................. False
  finetune ........................................ False
  first_sequence .................................. None
  fix_command_token ............................... False
  fp16 ............................................ False
  fp32_allreduce .................................. False
  fp32_embedding .................................. False
  fp32_layernorm .................................. False
  fp32_tokentypes ................................. False
  freeze_transformer .............................. False
  gap_sentence_prob ............................... 0.0
  gap_sentence_ratio .............................. 0.15
  gpt_infill_prob ................................. 0.5
  gpt_min_ratio ................................... 0.5
  gradient_accumulation_steps ..................... 1
  half_lazy_loader ................................ False
  hidden_dropout .................................. 0.1
  hidden_size ..................................... 1024
  hysteresis ...................................... 2
  input_data_sizes_file ........................... sizes.txt
  input_schema .................................... None
  intermediate_size ............................... None
  is_chief ........................................ 
  is_master_node .................................. True
  job_name ........................................ None
  label_enumerate_values .......................... None
  label_name ...................................... None
  label_smoothing ................................. 0.0
  layernorm_epsilon ............................... 1e-05
  learning_rate ................................... 5e-05
  length_penalty .................................. 0.0
  load ............................................ None
  load_pretrained ................................. None
  load_splits ..................................... None
  loader_scatter .................................. None
  local_rank ...................................... None
  log_interval .................................... 100
  logging_steps ................................... 100
  loose_json ...................................... False
  loss_func ....................................... cross_entropy
  loss_scale ...................................... None
  loss_scale_window ............................... 1000
  lr .............................................. 0.0001
  lr_decay_iters .................................. None
  lr_decay_ratio .................................. 0.1
  lr_decay_style .................................. linear
  make_vocab_size_divisible_by .................... 128
  masked_lm ....................................... False
  master_port ..................................... 23456
  max_grad_norm ................................... 1.0
  max_position_embeddings ......................... 512
  max_preds_per_seq ............................... None
  mem_length ...................................... 0
  mg_model ........................................ False
  micro_batch_size ................................ 2
  min_scale ....................................... 1
  min_tgt_length .................................. 0
  mode ............................................ train
  model_parallel_size ............................. 1
  modelzoo_base_dir ............................... 
  multi_batch_size ................................ None
  multi_seq_length ................................ None
  multi_task_data ................................. None
  multi_task_ratio ................................ 0.0
  multi_token ..................................... False
  n_cpu ........................................... 1
  n_gpu ........................................... 1
  new_save_directory .............................. False
  no_block_position ............................... False
  no_deepspeed_load ............................... False
  no_lazy_loader .................................. False
  no_load_lr_scheduler ............................ False
  no_load_optim ................................... False
  no_load_rng ..................................... False
  no_pre_tokenize ................................. False
  no_repeat_ngram_size ............................ 0
  no_save_optim ................................... False
  no_save_rng ..................................... False
  no_shuffle_block ................................ False
  no_validation ................................... False
  non_sentence_start .............................. 0.0
  num_attention_heads ............................. 16
  num_beams ....................................... 1
  num_layers ...................................... 24
  num_prompt_tokens ............................... 0
  num_workers ..................................... 2
  odps_config ..................................... None
  optimizer_type .................................. AdamW
  out_seq_length .................................. 256
  output_dropout .................................. 0.1
  output_schema ................................... 
  outputs ......................................... None
  overlapping_eval ................................ 32
  overwrite ....................................... False
  pattern_id ...................................... 0
  pool_token ...................................... cls
  predict_queue_size .............................. 1024
  predict_slice_size .............................. 4096
  predict_table_read_thread_num ................... 16
  predict_thread_num .............................. 2
  prefix_prompt ................................... 0
  presplit_sentences .............................. False
  pretrained_bert ................................. False
  prompt_func ..................................... lstm
  prompt_init ..................................... False
  ps_hosts ........................................ 
  random_position ................................. False
  random_seed ..................................... 1234
  rank ............................................ 0
  read_odps ....................................... False
  reset_attention_mask ............................ False
  reset_position_ids .............................. False
  restore_works_dir ............................... ./.easynlp_predict_restore_works_dir
  resume_dataloader ............................... False
  resume_from_checkpoint .......................... None
  sample_one_document ............................. False
  save ............................................ None
  save_all_checkpoints ............................ False
  save_checkpoint_steps ........................... None
  save_epoch ...................................... 1
  save_interval ................................... 5000
  save_splits ..................................... None
  save_test_data .................................. None
  second_sequence ................................. None
  seed ............................................ 1234
  segment_length .................................. 0
  select_topk ..................................... False
  sentinel_token .................................. False
  seq_length ...................................... 512
  sequence_length ................................. 16
  short_seq_prob .................................. 0.0
  shuffle ......................................... False
  single_span_prob ................................ 0.0
  skip_first_line ................................. False
  split ........................................... 1000,1,1
  src_seq_length .................................. None
  summary_dir ..................................... 
  switch_linear ................................... False
  tables .......................................... None
  task ............................................ chinesegen
  task_count ...................................... 1
  task_index ...................................... 0
  task_mask ....................................... False
  temperature ..................................... 1.0
  test_data ....................................... None
  text_key ........................................ sentence
  tgt_seq_length .................................. None
  tokenizer_model_type ............................ None
  tokenizer_path .................................. tokenizer.model
  tokenizer_type .................................. ChineseSPTokenizer
  top_k ........................................... 0
  top_p ........................................... 0.0
  train_data ...................................... None
  train_iters ..................................... 0
  transformer_xl .................................. False
  tune_prefix_layers .............................. None
  unidirectional .................................. False
  use_amp ......................................... False
  use_tfrecords ................................... False
  use_torchacc .................................... False
  user_defined_parameters ......................... None
  user_entry_file ................................. None
  user_script ..................................... None
  valid_data ...................................... None
  validation_metric ............................... None
  vocab_size ...................................... 30522
  warmup .......................................... 0.01
  warmup_proportion ............................... 0.1
  weight_decay .................................... 0.0001
  worker_count .................................... 1
  worker_cpu ...................................... -1
  worker_gpu ...................................... -1
  worker_hosts .................................... None
  world_size ...................................... 1
  wsc_negative .................................... False
-------------------- end of arguments ---------------------
> initializing torch distributed ...
[2022-12-08 17:49:47,752.752 dsw34730-bbd74fb77-5264h:173618 INFO distributed_c10d.py:195] Added key: store_based_barrier_key:1 to store for rank: 0
Init dist done. World size: 1, rank 0, l_rank 0
> setting random seeds to 1234 ...

注意:上述代码如果出现“Address already in use”错误,则需要运行以下代码清理端口(默认为6000)上正在执行的程序。

netstat -tunlp|grep 6000

kill -9 PID (需要替换成上一行代码执行结果中对应的程序ID)

载入数据

我们使用EasyNLP中自带的SequenceGenerationDataset,对训练和测试数据进行载入。主要参数如下:

  • pretrained_model_name_or_path:预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称”mt5-title-generation-zh”,并自动下载模型
  • max_seq_length:文本最大长度,超过将截断,不足将padding
  • input_schema:输入数据的格式,逗号分隔的每一项对应数据文件中每行以\t分隔的一项,每项开头为其字段标识,如label、sent1等
  • first_sequence、label_name:用于说明input_schema中哪些字段用于作为输入句子和标签列等
  • label_enumerate_values:label类型列举
  • is_training:是否为训练过程,train_dataset为True,valid_dataset为False
  • app_name:指定当前需要执行的任务,如文本分类、序列标注、文本匹配、文本生成等

下面我们将手动设置一些参数以便进行实验。

args.tables = "./cn_train.tsv,./cn_dev.tsv"
args.input_schema = "title_tokens:str:1,content_tokens:str:1"
args.first_sequence = "content_tokens"
args.second_sequence = "title_tokens" 
args.label_name = "title_tokens"
args.learning_rate = 3e-5
args.epoch_num = 1
args.save_checkpoint_steps = 150
args.sequence_length = 512
args.micro_batch_size = 8
args.export_tf_checkpoint_type = "none"
args.app_name = "sequence_generation"
args.pretrained_model_name_or_path = user_defined_parameters.get('pretrain_model_name_or_path', None)
args.pretrained_model_name_or_path = get_pretrain_model_path(args.pretrained_model_name_or_path)
train_dataset = SequenceGenerationDataset(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        data_file=args.tables.split(",")[0],
        max_seq_length=args.sequence_length,
        input_schema=args.input_schema,
        first_sequence=args.first_sequence,
        second_sequence=args.second_sequence,
        user_defined_parameters=user_defined_parameters,
        is_training=True)
valid_dataset = SequenceGenerationDataset(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        data_file=args.tables.split(",")[-1],
        max_seq_length=args.sequence_length,
        input_schema=args.input_schema,
        first_sequence=args.first_sequence,
        second_sequence=args.second_sequence,
        user_defined_parameters=user_defined_parameters,
        is_training=False)
`/root/.easynlp/modelzoo/alibaba-pai/mt5-title-generation-zh.tgz` already exists
****./cn_train.tsv
****./cn_dev.tsv

模型训练

处理好数据与模型载入后,我们开始训练模型。 我们使用EasyNLP中封装好的SequenceGeneration函数进行训练时的模型构建,其参数如下:

  • pretrained_model_name_or_path:预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称”mt5-title-generation-zh”,并自动下载模型
  • user_defined_parameters:用户自定义参数,直接填入刚刚处理好的自定义参数user_defined_parameters

构建模型并读取

model = SequenceGeneration(pretrained_model_name_or_path=args.pretrained_model_name_or_path, user_defined_parameters=user_defined_parameters)
**language** parameter is not provided in user defined parameters, using zh as default.
 Loaded weights of the model:
 [shared.weight,encoder.embed_tokens.weight,encoder.block.0.layer.0.SelfAttention.q.weight,encoder.block.0.layer.0.SelfAttention.k.weight,encoder.block.0.layer.0.SelfAttention.v.weight,encoder.block.0.layer.0.SelfAttention.o.weight,encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight,encoder.block.0.layer.0.layer_norm.weight,encoder.block.0.layer.1.DenseReluDense.wi_0.weight,encoder.block.0.layer.1.DenseReluDense.wi_1.weight,encoder.block.0.layer.1.DenseReluDense.wo.weight,encoder.block.0.layer.1.layer_norm.weight,encoder.block.1.layer.0.SelfAttention.q.weight,encoder.block.1.layer.0.SelfAttention.k.weight,encoder.block.1.layer.0.SelfAttention.v.weight,encoder.block.1.layer.0.SelfAttention.o.weight,encoder.block.1.layer.0.layer_norm.weight,encoder.block.1.layer.1.DenseReluDense.wi_0.weight,encoder.block.1.layer.1.DenseReluDense.wi_1.weight,encoder.block.1.layer.1.DenseReluDense.wo.weight,encoder.block.1.layer.1.layer_norm.weight,encoder.block.2.layer.0.SelfAttention.q.weight,encoder.block.2.layer.0.SelfAttention.k.weight,encoder.block.2.layer.0.SelfAttention.v.weight,encoder.block.2.layer.0.SelfAttention.o.weight,encoder.block.2.layer.0.layer_norm.weight,encoder.block.2.layer.1.DenseReluDense.wi_0.weight,encoder.block.2.layer.1.DenseReluDense.wi_1.weight,encoder.block.2.layer.1.DenseReluDense.wo.weight,encoder.block.2.layer.1.layer_norm.weight,encoder.block.3.layer.0.SelfAttention.q.weight,encoder.block.3.layer.0.SelfAttention.k.weight,encoder.block.3.layer.0.SelfAttention.v.weight,encoder.block.3.layer.0.SelfAttention.o.weight,encoder.block.3.layer.0.layer_norm.weight,encoder.block.3.layer.1.DenseReluDense.wi_0.weight,encoder.block.3.layer.1.DenseReluDense.wi_1.weight,encoder.block.3.layer.1.DenseReluDense.wo.weight,encoder.block.3.layer.1.layer_norm.weight,encoder.block.4.layer.0.SelfAttention.q.weight,encoder.block.4.layer.0.SelfAttention.k.weight,encoder.block.4.layer.0.SelfAttention.v.weight,encoder.block.4.layer.0.SelfAttention.o.weight,encoder.block.4.layer.0.layer_norm.weight,encoder.block.4.layer.1.DenseReluDense.wi_0.weight,encoder.block.4.layer.1.DenseReluDense.wi_1.weight,encoder.block.4.layer.1.DenseReluDense.wo.weight,encoder.block.4.layer.1.layer_norm.weight,encoder.block.5.layer.0.SelfAttention.q.weight,encoder.block.5.layer.0.SelfAttention.k.weight,encoder.block.5.layer.0.SelfAttention.v.weight,encoder.block.5.layer.0.SelfAttention.o.weight,encoder.block.5.layer.0.layer_norm.weight,encoder.block.5.layer.1.DenseReluDense.wi_0.weight,encoder.block.5.layer.1.DenseReluDense.wi_1.weight,encoder.block.5.layer.1.DenseReluDense.wo.weight,encoder.block.5.layer.1.layer_norm.weight,encoder.block.6.layer.0.SelfAttention.q.weight,encoder.block.6.layer.0.SelfAttention.k.weight,encoder.block.6.layer.0.SelfAttention.v.weight,encoder.block.6.layer.0.SelfAttention.o.weight,encoder.block.6.layer.0.layer_norm.weight,encoder.block.6.layer.1.DenseReluDense.wi_0.weight,encoder.block.6.layer.1.DenseReluDense.wi_1.weight,encoder.block.6.layer.1.DenseReluDense.wo.weight,encoder.block.6.layer.1.layer_norm.weight,encoder.block.7.layer.0.SelfAttention.q.weight,encoder.block.7.layer.0.SelfAttention.k.weight,encoder.block.7.layer.0.SelfAttention.v.weight,encoder.block.7.layer.0.SelfAttention.o.weight,encoder.block.7.layer.0.layer_norm.weight,encoder.block.7.layer.1.DenseReluDense.wi_0.weight,encoder.block.7.layer.1.DenseReluDense.wi_1.weight,encoder.block.7.layer.1.DenseReluDense.wo.weight,encoder.block.7.layer.1.layer_norm.weight,encoder.block.8.layer.0.SelfAttention.q.weight,encoder.block.8.layer.0.SelfAttention.k.weight,encoder.block.8.layer.0.SelfAttention.v.weight,encoder.block.8.layer.0.SelfAttention.o.weight,encoder.block.8.layer.0.layer_norm.weight,encoder.block.8.layer.1.DenseReluDense.wi_0.weight,encoder.block.8.layer.1.DenseReluDense.wi_1.weight,encoder.block.8.layer.1.DenseReluDense.wo.weight,encoder.block.8.layer.1.layer_norm.weight,encoder.block.9.layer.0.SelfAttention.q.weight,encoder.block.9.layer.0.SelfAttention.k.weight,encoder.block.9.layer.0.SelfAttention.v.weight,encoder.block.9.layer.0.SelfAttention.o.weight,encoder.block.9.layer.0.layer_norm.weight,encoder.block.9.layer.1.DenseReluDense.wi_0.weight,encoder.block.9.layer.1.DenseReluDense.wi_1.weight,encoder.block.9.layer.1.DenseReluDense.wo.weight,encoder.block.9.layer.1.layer_norm.weight,encoder.block.10.layer.0.SelfAttention.q.weight,encoder.block.10.layer.0.SelfAttention.k.weight,encoder.block.10.layer.0.SelfAttention.v.weight,encoder.block.10.layer.0.SelfAttention.o.weight,encoder.block.10.layer.0.layer_norm.weight,encoder.block.10.layer.1.DenseReluDense.wi_0.weight,encoder.block.10.layer.1.DenseReluDense.wi_1.weight,encoder.block.10.layer.1.DenseReluDense.wo.weight,encoder.block.10.layer.1.layer_norm.weight,encoder.block.11.layer.0.SelfAttention.q.weight,encoder.block.11.layer.0.SelfAttention.k.weight,encoder.block.11.layer.0.SelfAttention.v.weight,encoder.block.11.layer.0.SelfAttention.o.weight,encoder.block.11.layer.0.layer_norm.weight,encoder.block.11.layer.1.DenseReluDense.wi_0.weight,encoder.block.11.layer.1.DenseReluDense.wi_1.weight,encoder.block.11.layer.1.DenseReluDense.wo.weight,encoder.block.11.layer.1.layer_norm.weight,encoder.final_layer_norm.weight,decoder.embed_tokens.weight,decoder.block.0.layer.0.SelfAttention.q.weight,decoder.block.0.layer.0.SelfAttention.k.weight,decoder.block.0.layer.0.SelfAttention.v.weight,decoder.block.0.layer.0.SelfAttention.o.weight,decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight,decoder.block.0.layer.0.layer_norm.weight,decoder.block.0.layer.1.EncDecAttention.q.weight,decoder.block.0.layer.1.EncDecAttention.k.weight,decoder.block.0.layer.1.EncDecAttention.v.weight,decoder.block.0.layer.1.EncDecAttention.o.weight,decoder.block.0.layer.1.layer_norm.weight,decoder.block.0.layer.2.DenseReluDense.wi_0.weight,decoder.block.0.layer.2.DenseReluDense.wi_1.weight,decoder.block.0.layer.2.DenseReluDense.wo.weight,decoder.block.0.layer.2.layer_norm.weight,decoder.block.1.layer.0.SelfAttention.q.weight,decoder.block.1.layer.0.SelfAttention.k.weight,decoder.block.1.layer.0.SelfAttention.v.weight,decoder.block.1.layer.0.SelfAttention.o.weight,decoder.block.1.layer.0.layer_norm.weight,decoder.block.1.layer.1.EncDecAttention.q.weight,decoder.block.1.layer.1.EncDecAttention.k.weight,decoder.block.1.layer.1.EncDecAttention.v.weight,decoder.block.1.layer.1.EncDecAttention.o.weight,decoder.block.1.layer.1.layer_norm.weight,decoder.block.1.layer.2.DenseReluDense.wi_0.weight,decoder.block.1.layer.2.DenseReluDense.wi_1.weight,decoder.block.1.layer.2.DenseReluDense.wo.weight,decoder.block.1.layer.2.layer_norm.weight,decoder.block.2.layer.0.SelfAttention.q.weight,decoder.block.2.layer.0.SelfAttention.k.weight,decoder.block.2.layer.0.SelfAttention.v.weight,decoder.block.2.layer.0.SelfAttention.o.weight,decoder.block.2.layer.0.layer_norm.weight,decoder.block.2.layer.1.EncDecAttention.q.weight,decoder.block.2.layer.1.EncDecAttention.k.weight,decoder.block.2.layer.1.EncDecAttention.v.weight,decoder.block.2.layer.1.EncDecAttention.o.weight,decoder.block.2.layer.1.layer_norm.weight,decoder.block.2.layer.2.DenseReluDense.wi_0.weight,decoder.block.2.layer.2.DenseReluDense.wi_1.weight,decoder.block.2.layer.2.DenseReluDense.wo.weight,decoder.block.2.layer.2.layer_norm.weight,decoder.block.3.layer.0.SelfAttention.q.weight,decoder.block.3.layer.0.SelfAttention.k.weight,decoder.block.3.layer.0.SelfAttention.v.weight,decoder.block.3.layer.0.SelfAttention.o.weight,decoder.block.3.layer.0.layer_norm.weight,decoder.block.3.layer.1.EncDecAttention.q.weight,decoder.block.3.layer.1.EncDecAttention.k.weight,decoder.block.3.layer.1.EncDecAttention.v.weight,decoder.block.3.layer.1.EncDecAttention.o.weight,decoder.block.3.layer.1.layer_norm.weight,decoder.block.3.layer.2.DenseReluDense.wi_0.weight,decoder.block.3.layer.2.DenseReluDense.wi_1.weight,decoder.block.3.layer.2.DenseReluDense.wo.weight,decoder.block.3.layer.2.layer_norm.weight,decoder.block.4.layer.0.SelfAttention.q.weight,decoder.block.4.layer.0.SelfAttention.k.weight,decoder.block.4.layer.0.SelfAttention.v.weight,decoder.block.4.layer.0.SelfAttention.o.weight,decoder.block.4.layer.0.layer_norm.weight,decoder.block.4.layer.1.EncDecAttention.q.weight,decoder.block.4.layer.1.EncDecAttention.k.weight,decoder.block.4.layer.1.EncDecAttention.v.weight,decoder.block.4.layer.1.EncDecAttention.o.weight,decoder.block.4.layer.1.layer_norm.weight,decoder.block.4.layer.2.DenseReluDense.wi_0.weight,decoder.block.4.layer.2.DenseReluDense.wi_1.weight,decoder.block.4.layer.2.DenseReluDense.wo.weight,decoder.block.4.layer.2.layer_norm.weight,decoder.block.5.layer.0.SelfAttention.q.weight,decoder.block.5.layer.0.SelfAttention.k.weight,decoder.block.5.layer.0.SelfAttention.v.weight,decoder.block.5.layer.0.SelfAttention.o.weight,decoder.block.5.layer.0.layer_norm.weight,decoder.block.5.layer.1.EncDecAttention.q.weight,decoder.block.5.layer.1.EncDecAttention.k.weight,decoder.block.5.layer.1.EncDecAttention.v.weight,decoder.block.5.layer.1.EncDecAttention.o.weight,decoder.block.5.layer.1.layer_norm.weight,decoder.block.5.layer.2.DenseReluDense.wi_0.weight,decoder.block.5.layer.2.DenseReluDense.wi_1.weight,decoder.block.5.layer.2.DenseReluDense.wo.weight,decoder.block.5.layer.2.layer_norm.weight,decoder.block.6.layer.0.SelfAttention.q.weight,decoder.block.6.layer.0.SelfAttention.k.weight,decoder.block.6.layer.0.SelfAttention.v.weight,decoder.block.6.layer.0.SelfAttention.o.weight,decoder.block.6.layer.0.layer_norm.weight,decoder.block.6.layer.1.EncDecAttention.q.weight,decoder.block.6.layer.1.EncDecAttention.k.weight,decoder.block.6.layer.1.EncDecAttention.v.weight,decoder.block.6.layer.1.EncDecAttention.o.weight,decoder.block.6.layer.1.layer_norm.weight,decoder.block.6.layer.2.DenseReluDense.wi_0.weight,decoder.block.6.layer.2.DenseReluDense.wi_1.weight,decoder.block.6.layer.2.DenseReluDense.wo.weight,decoder.block.6.layer.2.layer_norm.weight,decoder.block.7.layer.0.SelfAttention.q.weight,decoder.block.7.layer.0.SelfAttention.k.weight,decoder.block.7.layer.0.SelfAttention.v.weight,decoder.block.7.layer.0.SelfAttention.o.weight,decoder.block.7.layer.0.layer_norm.weight,decoder.block.7.layer.1.EncDecAttention.q.weight,decoder.block.7.layer.1.EncDecAttention.k.weight,decoder.block.7.layer.1.EncDecAttention.v.weight,decoder.block.7.layer.1.EncDecAttention.o.weight,decoder.block.7.layer.1.layer_norm.weight,decoder.block.7.layer.2.DenseReluDense.wi_0.weight,decoder.block.7.layer.2.DenseReluDense.wi_1.weight,decoder.block.7.layer.2.DenseReluDense.wo.weight,decoder.block.7.layer.2.layer_norm.weight,decoder.block.8.layer.0.SelfAttention.q.weight,decoder.block.8.layer.0.SelfAttention.k.weight,decoder.block.8.layer.0.SelfAttention.v.weight,decoder.block.8.layer.0.SelfAttention.o.weight,decoder.block.8.layer.0.layer_norm.weight,decoder.block.8.layer.1.EncDecAttention.q.weight,decoder.block.8.layer.1.EncDecAttention.k.weight,decoder.block.8.layer.1.EncDecAttention.v.weight,decoder.block.8.layer.1.EncDecAttention.o.weight,decoder.block.8.layer.1.layer_norm.weight,decoder.block.8.layer.2.DenseReluDense.wi_0.weight,decoder.block.8.layer.2.DenseReluDense.wi_1.weight,decoder.block.8.layer.2.DenseReluDense.wo.weight,decoder.block.8.layer.2.layer_norm.weight,decoder.block.9.layer.0.SelfAttention.q.weight,decoder.block.9.layer.0.SelfAttention.k.weight,decoder.block.9.layer.0.SelfAttention.v.weight,decoder.block.9.layer.0.SelfAttention.o.weight,decoder.block.9.layer.0.layer_norm.weight,decoder.block.9.layer.1.EncDecAttention.q.weight,decoder.block.9.layer.1.EncDecAttention.k.weight,decoder.block.9.layer.1.EncDecAttention.v.weight,decoder.block.9.layer.1.EncDecAttention.o.weight,decoder.block.9.layer.1.layer_norm.weight,decoder.block.9.layer.2.DenseReluDense.wi_0.weight,decoder.block.9.layer.2.DenseReluDense.wi_1.weight,decoder.block.9.layer.2.DenseReluDense.wo.weight,decoder.block.9.layer.2.layer_norm.weight,decoder.block.10.layer.0.SelfAttention.q.weight,decoder.block.10.layer.0.SelfAttention.k.weight,decoder.block.10.layer.0.SelfAttention.v.weight,decoder.block.10.layer.0.SelfAttention.o.weight,decoder.block.10.layer.0.layer_norm.weight,decoder.block.10.layer.1.EncDecAttention.q.weight,decoder.block.10.layer.1.EncDecAttention.k.weight,decoder.block.10.layer.1.EncDecAttention.v.weight,decoder.block.10.layer.1.EncDecAttention.o.weight,decoder.block.10.layer.1.layer_norm.weight,decoder.block.10.layer.2.DenseReluDense.wi_0.weight,decoder.block.10.layer.2.DenseReluDense.wi_1.weight,decoder.block.10.layer.2.DenseReluDense.wo.weight,decoder.block.10.layer.2.layer_norm.weight,decoder.block.11.layer.0.SelfAttention.q.weight,decoder.block.11.layer.0.SelfAttention.k.weight,decoder.block.11.layer.0.SelfAttention.v.weight,decoder.block.11.layer.0.SelfAttention.o.weight,decoder.block.11.layer.0.layer_norm.weight,decoder.block.11.layer.1.EncDecAttention.q.weight,decoder.block.11.layer.1.EncDecAttention.k.weight,decoder.block.11.layer.1.EncDecAttention.v.weight,decoder.block.11.layer.1.EncDecAttention.o.weight,decoder.block.11.layer.1.layer_norm.weight,decoder.block.11.layer.2.DenseReluDense.wi_0.weight,decoder.block.11.layer.2.DenseReluDense.wi_1.weight,decoder.block.11.layer.2.DenseReluDense.wo.weight,decoder.block.11.layer.2.layer_norm.weight,decoder.final_layer_norm.weight,lm_head.weight].
All weights are initialized.

构建训练器并训练

extra_para = {'pretrained_model_name_or_path':args.pretrained_model_name_or_path}
evaluator = SequenceGenerationEvaluator(valid_dataset=valid_dataset, user_defined_parameters=user_defined_parameters, **extra_para)
trainer = Trainer(model=model, train_dataset=train_dataset, user_defined_parameters=user_defined_parameters,
                evaluator=evaluator)
trainer.train()
[2022-12-08 17:49:54,705 INFO] ========== Initializing Tensorboard ==========
[2022-12-08 17:49:54,729 INFO] ========== Training Start ==========
[2022-12-08 17:49:54,731 INFO]   Num of GPUs (all)       = 1
[2022-12-08 17:49:54,732 INFO]   Num of CPUs per worker  = 1
[2022-12-08 17:49:54,732 INFO]   Num dataset examples    = 1000
[2022-12-08 17:49:54,733 INFO]   Num training examples   = 1000
[2022-12-08 17:49:54,733 INFO]   Num validation examples = 500
[2022-12-08 17:49:54,734 INFO]   Train. batch size       = 8
[2022-12-08 17:49:54,734 INFO]   Train. micro batch size = 8
[2022-12-08 17:49:54,734 INFO]   Train. batch no.        = 125
[2022-12-08 17:49:54,736 INFO]   Evaluation batch size   = 8
[2022-12-08 17:49:54,737 INFO]   Total training steps    = 125
[2022-12-08 17:49:54,737 INFO]   Sequence length         = 512
[2022-12-08 17:49:54,738 INFO]   Saving steps            = 150
[2022-12-08 17:49:54,739 INFO]   Distributed_backend     = nccl
[2022-12-08 17:49:54,739 INFO]   Worker Count            = 1
[2022-12-08 17:49:54,739 INFO]   Worker CPU              = -1
[2022-12-08 17:49:54,740 INFO]   Worker data threads     = 10
[2022-12-08 17:49:54,743 INFO]   num model params        = 275,029,248
[2022-12-08 17:49:54,744 INFO]   num trainable params    = 275,029,248
[2022-12-08 17:49:54,744 INFO] 
[2022-12-08 17:49:54,744 INFO] ========== Model Config ==========
[2022-12-08 17:49:54,745 INFO] {
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "easynlp_version": "0.0.3",
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "mt5",
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "tokenizer_class": "T5Tokenizer",
  "use_cache": true,
  "vocab_size": 50000
}
optimizer type: AdamW
/home/pai/lib/python3.6/site-packages/pai_easynlp-0.0.9-py3.6.egg/easynlp/core/optimizers.py:441: UserWarning: This overload of add_ is deprecated:
  add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
  add_(Tensor other, *, Number alpha) (Triggered internally at  /workspace/artifacts/paipytorch1.8/dist/ubuntu18.04-py3.6-cuda10.1/build/src/torch/csrc/utils/python_arg_parser.cpp:1005.)
  exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
/home/pai/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:247: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  warnings.warn("To get the last learning rate computed by the scheduler, "
[2022-12-08 17:50:14,705 INFO] Epoch [ 1/ 1], step [100/125], lr 0.000007, 19.96 s
[2022-12-08 17:50:14,707 INFO]   loss      : 3.6508 
Training Time: 25.2710120677948, rank 0, gsteps 125
100%|██████████| 500/500 [03:19<00:00,  2.51it/s]
[2022-12-08 17:53:39,253 INFO] Saving best model to ./finetuned_zh_model/pytorch_model.bin...
Rouge 1/2/L: 35.91/21.85/32.54
[2022-12-08 17:53:57,904 INFO] Best score: 32.535939404637126
[2022-12-08 17:53:57,907 INFO] Training Time: 243.49750781059265

模型评估

训练过程结束后,模型被我们保存在一开始指定好的checkpoint_dir中,本地路径为”./finetuned_zh_model/”。我们可以对训练好的模型进行效果评估。我们使用EasyNLP中的SequenceGenerationEvaluator来初始化evaluator,并将模型迁移至GPU机器,进行模型评估。

args.tables = "cn_dev.tsv"
extra_para = {'pretrained_model_name_or_path':args.pretrained_model_name_or_path}
evaluator = SequenceGenerationEvaluator(valid_dataset=valid_dataset, user_defined_parameters=user_defined_parameters, **extra_para)
if args.n_gpu > 0:
    model.to(torch.cuda.current_device())
else:
    model.to("cpu")
evaluator.evaluate(model=model)
/root/.easynlp/modelzoo/alibaba-pai/mt5-title-generation-zh
100%|██████████| 500/500 [02:50<00:00,  2.93it/s]
Rouge 1/2/L: 33.51/18.77/30.09
[('rouge-l', 30.08704702646614),
 ('rouge-1', 33.50990449925101),
 ('rouge-2', 18.76897464541499)]

模型预测

我们同样可以使用训练好的模型进行新闻标题生成。我们首先创建一个predictor,并据此实例化一个PredictorManager实例。我们指定预测好的结果输出在cn.preds.txt

args.tables = "cn_dev.tsv"
args.outputs = "cn.preds.txt"
args.input_schema = "title_tokens:str:1,content_tokens:str:1"
args.output_schema = "predictions,beams"
args.append_cols="title_tokens,content_tokens"
args.micro_batch_size = 32
extra_para = {'pretrained_model_name_or_path':args.pretrained_model_name_or_path, 'max_encoder_length':args.sequence_length}
predictor = SequenceGenerationPredictor(model_dir=args.checkpoint_dir, model_cls=SequenceGeneration,
                                      first_sequence=args.first_sequence, user_defined_parameters=user_defined_parameters, **extra_para)
predictor_manager = PredictorManager(
    predictor=predictor,
    input_file=args.tables.split(",")[0],
    input_schema=args.input_schema,
    output_file=args.outputs,
    output_schema=args.output_schema,
    append_cols=args.append_cols,
    batch_size=args.micro_batch_size
)
predictor_manager.run()
**language** parameter is not provided in user defined parameters, using zh as default.
 Loaded weights of the model:
 [shared.weight,encoder.embed_tokens.weight,encoder.block.0.layer.0.SelfAttention.q.weight,encoder.block.0.layer.0.SelfAttention.k.weight,encoder.block.0.layer.0.SelfAttention.v.weight,encoder.block.0.layer.0.SelfAttention.o.weight,encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight,encoder.block.0.layer.0.layer_norm.weight,encoder.block.0.layer.1.DenseReluDense.wi_0.weight,encoder.block.0.layer.1.DenseReluDense.wi_1.weight,encoder.block.0.layer.1.DenseReluDense.wo.weight,encoder.block.0.layer.1.layer_norm.weight,encoder.block.1.layer.0.SelfAttention.q.weight,encoder.block.1.layer.0.SelfAttention.k.weight,encoder.block.1.layer.0.SelfAttention.v.weight,encoder.block.1.layer.0.SelfAttention.o.weight,encoder.block.1.layer.0.layer_norm.weight,encoder.block.1.layer.1.DenseReluDense.wi_0.weight,encoder.block.1.layer.1.DenseReluDense.wi_1.weight,encoder.block.1.layer.1.DenseReluDense.wo.weight,encoder.block.1.layer.1.layer_norm.weight,encoder.block.2.layer.0.SelfAttention.q.weight,encoder.block.2.layer.0.SelfAttention.k.weight,encoder.block.2.layer.0.SelfAttention.v.weight,encoder.block.2.layer.0.SelfAttention.o.weight,encoder.block.2.layer.0.layer_norm.weight,encoder.block.2.layer.1.DenseReluDense.wi_0.weight,encoder.block.2.layer.1.DenseReluDense.wi_1.weight,encoder.block.2.layer.1.DenseReluDense.wo.weight,encoder.block.2.layer.1.layer_norm.weight,encoder.block.3.layer.0.SelfAttention.q.weight,encoder.block.3.layer.0.SelfAttention.k.weight,encoder.block.3.layer.0.SelfAttention.v.weight,encoder.block.3.layer.0.SelfAttention.o.weight,encoder.block.3.layer.0.layer_norm.weight,encoder.block.3.layer.1.DenseReluDense.wi_0.weight,encoder.block.3.layer.1.DenseReluDense.wi_1.weight,encoder.block.3.layer.1.DenseReluDense.wo.weight,encoder.block.3.layer.1.layer_norm.weight,encoder.block.4.layer.0.SelfAttention.q.weight,encoder.block.4.layer.0.SelfAttention.k.weight,encoder.block.4.layer.0.SelfAttention.v.weight,encoder.block.4.layer.0.SelfAttention.o.weight,encoder.block.4.layer.0.layer_norm.weight,encoder.block.4.layer.1.DenseReluDense.wi_0.weight,encoder.block.4.layer.1.DenseReluDense.wi_1.weight,encoder.block.4.layer.1.DenseReluDense.wo.weight,encoder.block.4.layer.1.layer_norm.weight,encoder.block.5.layer.0.SelfAttention.q.weight,encoder.block.5.layer.0.SelfAttention.k.weight,encoder.block.5.layer.0.SelfAttention.v.weight,encoder.block.5.layer.0.SelfAttention.o.weight,encoder.block.5.layer.0.layer_norm.weight,encoder.block.5.layer.1.DenseReluDense.wi_0.weight,encoder.block.5.layer.1.DenseReluDense.wi_1.weight,encoder.block.5.layer.1.DenseReluDense.wo.weight,encoder.block.5.layer.1.layer_norm.weight,encoder.block.6.layer.0.SelfAttention.q.weight,encoder.block.6.layer.0.SelfAttention.k.weight,encoder.block.6.layer.0.SelfAttention.v.weight,encoder.block.6.layer.0.SelfAttention.o.weight,encoder.block.6.layer.0.layer_norm.weight,encoder.block.6.layer.1.DenseReluDense.wi_0.weight,encoder.block.6.layer.1.DenseReluDense.wi_1.weight,encoder.block.6.layer.1.DenseReluDense.wo.weight,encoder.block.6.layer.1.layer_norm.weight,encoder.block.7.layer.0.SelfAttention.q.weight,encoder.block.7.layer.0.SelfAttention.k.weight,encoder.block.7.layer.0.SelfAttention.v.weight,encoder.block.7.layer.0.SelfAttention.o.weight,encoder.block.7.layer.0.layer_norm.weight,encoder.block.7.layer.1.DenseReluDense.wi_0.weight,encoder.block.7.layer.1.DenseReluDense.wi_1.weight,encoder.block.7.layer.1.DenseReluDense.wo.weight,encoder.block.7.layer.1.layer_norm.weight,encoder.block.8.layer.0.SelfAttention.q.weight,encoder.block.8.layer.0.SelfAttention.k.weight,encoder.block.8.layer.0.SelfAttention.v.weight,encoder.block.8.layer.0.SelfAttention.o.weight,encoder.block.8.layer.0.layer_norm.weight,encoder.block.8.layer.1.DenseReluDense.wi_0.weight,encoder.block.8.layer.1.DenseReluDense.wi_1.weight,encoder.block.8.layer.1.DenseReluDense.wo.weight,encoder.block.8.layer.1.layer_norm.weight,encoder.block.9.layer.0.SelfAttention.q.weight,encoder.block.9.layer.0.SelfAttention.k.weight,encoder.block.9.layer.0.SelfAttention.v.weight,encoder.block.9.layer.0.SelfAttention.o.weight,encoder.block.9.layer.0.layer_norm.weight,encoder.block.9.layer.1.DenseReluDense.wi_0.weight,encoder.block.9.layer.1.DenseReluDense.wi_1.weight,encoder.block.9.layer.1.DenseReluDense.wo.weight,encoder.block.9.layer.1.layer_norm.weight,encoder.block.10.layer.0.SelfAttention.q.weight,encoder.block.10.layer.0.SelfAttention.k.weight,encoder.block.10.layer.0.SelfAttention.v.weight,encoder.block.10.layer.0.SelfAttention.o.weight,encoder.block.10.layer.0.layer_norm.weight,encoder.block.10.layer.1.DenseReluDense.wi_0.weight,encoder.block.10.layer.1.DenseReluDense.wi_1.weight,encoder.block.10.layer.1.DenseReluDense.wo.weight,encoder.block.10.layer.1.layer_norm.weight,encoder.block.11.layer.0.SelfAttention.q.weight,encoder.block.11.layer.0.SelfAttention.k.weight,encoder.block.11.layer.0.SelfAttention.v.weight,encoder.block.11.layer.0.SelfAttention.o.weight,encoder.block.11.layer.0.layer_norm.weight,encoder.block.11.layer.1.DenseReluDense.wi_0.weight,encoder.block.11.layer.1.DenseReluDense.wi_1.weight,encoder.block.11.layer.1.DenseReluDense.wo.weight,encoder.block.11.layer.1.layer_norm.weight,encoder.final_layer_norm.weight,decoder.embed_tokens.weight,decoder.block.0.layer.0.SelfAttention.q.weight,decoder.block.0.layer.0.SelfAttention.k.weight,decoder.block.0.layer.0.SelfAttention.v.weight,decoder.block.0.layer.0.SelfAttention.o.weight,decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight,decoder.block.0.layer.0.layer_norm.weight,decoder.block.0.layer.1.EncDecAttention.q.weight,decoder.block.0.layer.1.EncDecAttention.k.weight,decoder.block.0.layer.1.EncDecAttention.v.weight,decoder.block.0.layer.1.EncDecAttention.o.weight,decoder.block.0.layer.1.layer_norm.weight,decoder.block.0.layer.2.DenseReluDense.wi_0.weight,decoder.block.0.layer.2.DenseReluDense.wi_1.weight,decoder.block.0.layer.2.DenseReluDense.wo.weight,decoder.block.0.layer.2.layer_norm.weight,decoder.block.1.layer.0.SelfAttention.q.weight,decoder.block.1.layer.0.SelfAttention.k.weight,decoder.block.1.layer.0.SelfAttention.v.weight,decoder.block.1.layer.0.SelfAttention.o.weight,decoder.block.1.layer.0.layer_norm.weight,decoder.block.1.layer.1.EncDecAttention.q.weight,decoder.block.1.layer.1.EncDecAttention.k.weight,decoder.block.1.layer.1.EncDecAttention.v.weight,decoder.block.1.layer.1.EncDecAttention.o.weight,decoder.block.1.layer.1.layer_norm.weight,decoder.block.1.layer.2.DenseReluDense.wi_0.weight,decoder.block.1.layer.2.DenseReluDense.wi_1.weight,decoder.block.1.layer.2.DenseReluDense.wo.weight,decoder.block.1.layer.2.layer_norm.weight,decoder.block.2.layer.0.SelfAttention.q.weight,decoder.block.2.layer.0.SelfAttention.k.weight,decoder.block.2.layer.0.SelfAttention.v.weight,decoder.block.2.layer.0.SelfAttention.o.weight,decoder.block.2.layer.0.layer_norm.weight,decoder.block.2.layer.1.EncDecAttention.q.weight,decoder.block.2.layer.1.EncDecAttention.k.weight,decoder.block.2.layer.1.EncDecAttention.v.weight,decoder.block.2.layer.1.EncDecAttention.o.weight,decoder.block.2.layer.1.layer_norm.weight,decoder.block.2.layer.2.DenseReluDense.wi_0.weight,decoder.block.2.layer.2.DenseReluDense.wi_1.weight,decoder.block.2.layer.2.DenseReluDense.wo.weight,decoder.block.2.layer.2.layer_norm.weight,decoder.block.3.layer.0.SelfAttention.q.weight,decoder.block.3.layer.0.SelfAttention.k.weight,decoder.block.3.layer.0.SelfAttention.v.weight,decoder.block.3.layer.0.SelfAttention.o.weight,decoder.block.3.layer.0.layer_norm.weight,decoder.block.3.layer.1.EncDecAttention.q.weight,decoder.block.3.layer.1.EncDecAttention.k.weight,decoder.block.3.layer.1.EncDecAttention.v.weight,decoder.block.3.layer.1.EncDecAttention.o.weight,decoder.block.3.layer.1.layer_norm.weight,decoder.block.3.layer.2.DenseReluDense.wi_0.weight,decoder.block.3.layer.2.DenseReluDense.wi_1.weight,decoder.block.3.layer.2.DenseReluDense.wo.weight,decoder.block.3.layer.2.layer_norm.weight,decoder.block.4.layer.0.SelfAttention.q.weight,decoder.block.4.layer.0.SelfAttention.k.weight,decoder.block.4.layer.0.SelfAttention.v.weight,decoder.block.4.layer.0.SelfAttention.o.weight,decoder.block.4.layer.0.layer_norm.weight,decoder.block.4.layer.1.EncDecAttention.q.weight,decoder.block.4.layer.1.EncDecAttention.k.weight,decoder.block.4.layer.1.EncDecAttention.v.weight,decoder.block.4.layer.1.EncDecAttention.o.weight,decoder.block.4.layer.1.layer_norm.weight,decoder.block.4.layer.2.DenseReluDense.wi_0.weight,decoder.block.4.layer.2.DenseReluDense.wi_1.weight,decoder.block.4.layer.2.DenseReluDense.wo.weight,decoder.block.4.layer.2.layer_norm.weight,decoder.block.5.layer.0.SelfAttention.q.weight,decoder.block.5.layer.0.SelfAttention.k.weight,decoder.block.5.layer.0.SelfAttention.v.weight,decoder.block.5.layer.0.SelfAttention.o.weight,decoder.block.5.layer.0.layer_norm.weight,decoder.block.5.layer.1.EncDecAttention.q.weight,decoder.block.5.layer.1.EncDecAttention.k.weight,decoder.block.5.layer.1.EncDecAttention.v.weight,decoder.block.5.layer.1.EncDecAttention.o.weight,decoder.block.5.layer.1.layer_norm.weight,decoder.block.5.layer.2.DenseReluDense.wi_0.weight,decoder.block.5.layer.2.DenseReluDense.wi_1.weight,decoder.block.5.layer.2.DenseReluDense.wo.weight,decoder.block.5.layer.2.layer_norm.weight,decoder.block.6.layer.0.SelfAttention.q.weight,decoder.block.6.layer.0.SelfAttention.k.weight,decoder.block.6.layer.0.SelfAttention.v.weight,decoder.block.6.layer.0.SelfAttention.o.weight,decoder.block.6.layer.0.layer_norm.weight,decoder.block.6.layer.1.EncDecAttention.q.weight,decoder.block.6.layer.1.EncDecAttention.k.weight,decoder.block.6.layer.1.EncDecAttention.v.weight,decoder.block.6.layer.1.EncDecAttention.o.weight,decoder.block.6.layer.1.layer_norm.weight,decoder.block.6.layer.2.DenseReluDense.wi_0.weight,decoder.block.6.layer.2.DenseReluDense.wi_1.weight,decoder.block.6.layer.2.DenseReluDense.wo.weight,decoder.block.6.layer.2.layer_norm.weight,decoder.block.7.layer.0.SelfAttention.q.weight,decoder.block.7.layer.0.SelfAttention.k.weight,decoder.block.7.layer.0.SelfAttention.v.weight,decoder.block.7.layer.0.SelfAttention.o.weight,decoder.block.7.layer.0.layer_norm.weight,decoder.block.7.layer.1.EncDecAttention.q.weight,decoder.block.7.layer.1.EncDecAttention.k.weight,decoder.block.7.layer.1.EncDecAttention.v.weight,decoder.block.7.layer.1.EncDecAttention.o.weight,decoder.block.7.layer.1.layer_norm.weight,decoder.block.7.layer.2.DenseReluDense.wi_0.weight,decoder.block.7.layer.2.DenseReluDense.wi_1.weight,decoder.block.7.layer.2.DenseReluDense.wo.weight,decoder.block.7.layer.2.layer_norm.weight,decoder.block.8.layer.0.SelfAttention.q.weight,decoder.block.8.layer.0.SelfAttention.k.weight,decoder.block.8.layer.0.SelfAttention.v.weight,decoder.block.8.layer.0.SelfAttention.o.weight,decoder.block.8.layer.0.layer_norm.weight,decoder.block.8.layer.1.EncDecAttention.q.weight,decoder.block.8.layer.1.EncDecAttention.k.weight,decoder.block.8.layer.1.EncDecAttention.v.weight,decoder.block.8.layer.1.EncDecAttention.o.weight,decoder.block.8.layer.1.layer_norm.weight,decoder.block.8.layer.2.DenseReluDense.wi_0.weight,decoder.block.8.layer.2.DenseReluDense.wi_1.weight,decoder.block.8.layer.2.DenseReluDense.wo.weight,decoder.block.8.layer.2.layer_norm.weight,decoder.block.9.layer.0.SelfAttention.q.weight,decoder.block.9.layer.0.SelfAttention.k.weight,decoder.block.9.layer.0.SelfAttention.v.weight,decoder.block.9.layer.0.SelfAttention.o.weight,decoder.block.9.layer.0.layer_norm.weight,decoder.block.9.layer.1.EncDecAttention.q.weight,decoder.block.9.layer.1.EncDecAttention.k.weight,decoder.block.9.layer.1.EncDecAttention.v.weight,decoder.block.9.layer.1.EncDecAttention.o.weight,decoder.block.9.layer.1.layer_norm.weight,decoder.block.9.layer.2.DenseReluDense.wi_0.weight,decoder.block.9.layer.2.DenseReluDense.wi_1.weight,decoder.block.9.layer.2.DenseReluDense.wo.weight,decoder.block.9.layer.2.layer_norm.weight,decoder.block.10.layer.0.SelfAttention.q.weight,decoder.block.10.layer.0.SelfAttention.k.weight,decoder.block.10.layer.0.SelfAttention.v.weight,decoder.block.10.layer.0.SelfAttention.o.weight,decoder.block.10.layer.0.layer_norm.weight,decoder.block.10.layer.1.EncDecAttention.q.weight,decoder.block.10.layer.1.EncDecAttention.k.weight,decoder.block.10.layer.1.EncDecAttention.v.weight,decoder.block.10.layer.1.EncDecAttention.o.weight,decoder.block.10.layer.1.layer_norm.weight,decoder.block.10.layer.2.DenseReluDense.wi_0.weight,decoder.block.10.layer.2.DenseReluDense.wi_1.weight,decoder.block.10.layer.2.DenseReluDense.wo.weight,decoder.block.10.layer.2.layer_norm.weight,decoder.block.11.layer.0.SelfAttention.q.weight,decoder.block.11.layer.0.SelfAttention.k.weight,decoder.block.11.layer.0.SelfAttention.v.weight,decoder.block.11.layer.0.SelfAttention.o.weight,decoder.block.11.layer.0.layer_norm.weight,decoder.block.11.layer.1.EncDecAttention.q.weight,decoder.block.11.layer.1.EncDecAttention.k.weight,decoder.block.11.layer.1.EncDecAttention.v.weight,decoder.block.11.layer.1.EncDecAttention.o.weight,decoder.block.11.layer.1.layer_norm.weight,decoder.block.11.layer.2.DenseReluDense.wi_0.weight,decoder.block.11.layer.2.DenseReluDense.wi_1.weight,decoder.block.11.layer.2.DenseReluDense.wo.weight,decoder.block.11.layer.2.layer_norm.weight,decoder.final_layer_norm.weight,lm_head.weight].
All weights are initialized.
[2022-12-08 17:54:07,569 INFO] Using SimplePredict to predict...
16it [04:16, 16.06s/it]
print('Labeled samples:')
! tail -n 3 cn_dev.tsv
print('Predicted results:')
! tail -n 3 cn.preds.txt
Labeled samples:
上海嘉定实验小学门口今晨7车连撞1人受伤  今晨7时45分左右,位于平城路上的嘉定区实验小学门口发生7车连撞事故。事发时正值学生上学高峰,路上车流量大,车速慢。据悉,先是两车相撞,随后引发7车连环撞。事故造成1人受伤。警方正对涉事驾驶员进行验血调查。
王中军:投资就像追女孩并购银汉幕后军师是腾讯  “从投资到现在为止没有像我想的那么顺利在接受专访时,华谊兄弟董事长王中军还打趣说,投资谈判遇到阻碍很正常,“就像我要追一个女孩,不是说我今年把她追到手就真能追到手,人家可能三年后才同意,完全有这个可能。
营销不宜过度“快消化” 对于快消品来说,其在选择数字媒体平台时需要考虑:1、确保能够接触到主流受众的规模;2、有效地提高受众接触度并运用整合性的媒体投放;3、迎合目标受众的需求。如果营销过度“快消化”,则不利于长品牌的长期价值构建和与消费者之间的沟通。
Predicted results:
嘉定一小学门口发生7车连撞事故1人受伤 嘉定一小学门口发生7车连撞事故1人受伤||嘉定一小学门口发生7车连撞事故1人受伤||嘉定实验小学门口发生7车连撞事故1人受伤||嘉定实验小学7车连撞事故1人受伤||嘉定实验小学门口发生7车连撞事故  上海嘉定实验小学门口今晨7车连撞1人受伤  今晨7时45分左右,位于平城路上的嘉定区实验小学门口发生7车连撞事故。事发时正值学生上学高峰,路上车流量大,车速慢。据悉,先是两车相撞,随后引发7车连环撞。事故造成1人受伤。警方正对涉事驾驶员进行验血调查。
王中军:从投资到现在为止没有像我想的那么顺利  王中军:从投资到现在为止没有像我想的那么顺利||华谊董事长:从投资到现在为止没有像我想的那么顺利||华谊董事长王中军:从投资到现在为止没有像想的那么顺利||“从投资到现在为止没有像我想的那么顺利”||“从投资到现在为止没有像我想的那么顺利 王中军:投资就像追女孩并购银汉幕后军师是腾讯  “从投资到现在为止没有像我想的那么顺利在接受专访时,华谊兄弟董事长王中军还打趣说,投资谈判遇到阻碍很正常,“就像我要追一个女孩,不是说我今年把她追到手就真能追到手,人家可能三年后才同意,完全有这个可能。
选择数字媒体平台需考虑 选择数字媒体平台需考虑||选择数字媒体平台需谨慎||快消品选择数字媒体平台||快消品选择数字媒体需考虑||快消品选择数字媒体平台需考虑 营销不宜过度“快消化” 对于快消品来说,其在选择数字媒体平台时需要考虑:1、确保能够接触到主流受众的规模;2、有效地提高受众接触度并运用整合性的媒体投放;3、迎合目标受众的需求。如果营销过度“快消化”,则不利于长品牌的长期价值构建和与消费者之间的沟通。

上面展示了数据集中的1条数据以及经过训练以后模型的预测结果。预测的标题为第一列,第二列为集束搜索(beam search)的5条结果,以||隔开。预测结果同时还包含原始新闻标题、新闻原文,相互之间以\t隔开。

一步执行

值得一提的是,上述所有训练/评估/预测代码,都已经被集成在EasyNLP/examples/appzoo_tutorials/sequence_generation/main.py中,此外,我们也预先编写好了多种可供直接执行的脚本。用户可以通过带参数运行上述main.py文件,或者直接执行脚本文件run_user_defined_local_zh.sh的方式,一步执行上述所有训练/评估/预测操作。

mian.py文件一步执行

用户通过以下代码带参数执行main.py中的指令,可直接对模型进行训练/评估/预测操作。 训练代码指令如下。具体的参数解释可见上文,此处不再赘述。

模型训练代码如下:

! python EasyNLP/examples/appzoo_tutorials/sequence_generation/main.py \
    --mode train \
    --app_name=sequence_generation \
    --worker_gpu=1 \
    --tables=./cn_train.tsv,./cn_dev.tsv  \
    --input_schema=target:str:1,source:str:1 \
    --first_sequence=source \
    --second_sequence=target \
    --label_name=target \
    --checkpoint_dir=./finetuned_zh_model \
    --micro_batch_size=8 \
    --sequence_length=512 \
    --epoch_num=1  \
    --save_checkpoint_steps=150 \
    --export_tf_checkpoint_type none \
    --user_defined_parameters 'pretrain_model_name_or_path=hfl/randeng-summary-generation-base-zh language=zh copy=false max_encoder_length=512 min_decoder_length=12 max_decoder_length=32 no_repeat_ngram_size=2 num_beams=5 num_return_sequences=5'

模型评估代码如下:

! python EasyNLP/examples/appzoo_tutorials/sequence_generation/main.py \
    --mode=evaluate \
    --app_name=sequence_generation \
    --worker_gpu=1 \
    --tables=./cn_dev.tsv  \
    --input_schema=target:str:1,source:str:1 \
    --first_sequence=source \
    --second_sequence=target \
    --label_name=target \
    --checkpoint_dir=./finetuned_zh_model \
    --micro_batch_size=8 \
    --sequence_length=512 \
    --epoch_num=1  \
    --save_checkpoint_steps=150 \
    --export_tf_checkpoint_type none \
    --user_defined_parameters 'language=zh copy=false max_encoder_length=512 min_decoder_length=12 max_decoder_length=32 no_repeat_ngram_size=2 num_beams=5 num_return_sequences=5'

模型预测代码如下:

! python EasyNLP/examples/appzoo_tutorials/sequence_generation/main.py \
    --mode=predict \
    --app_name=sequence_generation \
    --worker_gpu=1 \
    --tables=./cn_dev.tsv  \
    --outputs=./cn.preds.txt \
    --input_schema=target:str:1,source:str:1 \
    --output_schema=predictions,beams \
    --append_cols=target,source \
    --first_sequence=source \
    --checkpoint_dir=./finetuned_zh_model \
    --micro_batch_size=32 \
    --sequence_length=512 \
    --user_defined_parameters 'language=zh copy=false max_encoder_length=512 min_decoder_length=12 max_decoder_length=32 no_repeat_ngram_size=2 num_beams=5 num_return_sequences=5'

利用bash文件命令行执行

我们在EasyNLP/examples/appzoo_tutorials/sequence_generation文件夹下封装好了多种可直接执行的脚本,用户同样可以通过带参数执行脚本文件的方式来一步完成模型的训练/评估/预测。以下以run_user_defined_local_zh.sh脚本为例。该脚本文件需要传入两个参数,第一个参数为运行程序的GPU编号,一般为0;第二个参数代表模型的训练/评估/预测。

注意:训练、评估、预测的相关参数在文件run_user_defined_local_zh中,如需使用自有数据和参数则应修改其中的参数。

模型训练:

! cd EasyNLP/examples/appzoo_tutorials/sequence_generation && bash run_user_defined_local_zh.sh 0 train

模型评估:

! bash EasyNLP/examples/appzoo_tutorials/sequence_generation/run_user_defined_local_zh.sh 0 evaluate

模型预测:

! cd EasyNLP/examples/appzoo_tutorials/sequence_generation && bash run_user_defined_local_zh.sh 0 predict
相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
自然语言处理 数据格式
【DSW Gallery】基于ModelScope的中文GPT-3模型(1.3B)的微调训练
本文基于ModelScope,以GPT-3(1.3B)为例介绍如何使用ModelScope-GPT3进行续写训练与输入输出形式的训练,训练方式不需要额外指定,训练数据集仅包含 src_txt 时会进行续写训练,同时包含 src_txt 和 tgt_txt 时会进行输入输出形式的训练。
【DSW Gallery】基于ModelScope的中文GPT-3模型(1.3B)的微调训练
|
4月前
|
机器学习/深度学习 IDE 开发工具
ARTIST的中文文图生成模型问题之什么是PAI-DSW
ARTIST的中文文图生成模型问题之什么是PAI-DSW
|
7月前
|
存储 人工智能 算法
基于向量检索服务与ModelScope模型搭建文本搜图片---魏红斌版
【1月更文挑战第9天】综合产品理解和实操经验,总结向量检索服务的综合水平
98974 4
基于向量检索服务与ModelScope模型搭建文本搜图片---魏红斌版
|
人工智能 编解码 自然语言处理
Midjourney|文心一格 Prompt:完整参数列表、风格汇总、文生图词典合集
Midjourney|文心一格 Prompt:完整参数列表、风格汇总、文生图词典合集
Midjourney|文心一格 Prompt:完整参数列表、风格汇总、文生图词典合集
|
缓存 自然语言处理 Shell
【DSW Gallery】基于CK-BERT的中文序列标注
EasyNLP提供多种模型的训练及预测功能,旨在帮助自然语言开发者方便快捷地构建模型并应用于生产。本文以序列标注(命名实体识别)为例,为您介绍如何在PAI-DSW中使用EasyNLP。
【DSW Gallery】基于CK-BERT的中文序列标注
|
缓存 自然语言处理 算法
【DSW Gallery】基于EasyNLP Transformer模型的中文文图生成
EasyNLP提供多种模型的训练及预测功能,旨在帮助自然语言开发者方便快捷地构建模型并应用于生产。本文简要介绍文图生成的技术,以及如何在PAI-DSW中基于EasyNLP轻松实现文图生成,带你秒变艺术家。
【DSW Gallery】基于EasyNLP Transformer模型的中文文图生成
|
机器学习/深度学习 并行计算 数据可视化
【DSW Gallery】EasyCV-基于关键点的视频分类示例
EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以基于关键点的视频分类为例,为您介绍如何在PAI-DSW中使用EasyCV。
【DSW Gallery】EasyCV-基于关键点的视频分类示例
|
算法 PyTorch 算法框架/工具
【DSW Gallery】基于EasyCV的视频分类示例
EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以视频分类为例,为您介绍如何在PAI-DSW中使用EasyCV。
【DSW Gallery】基于EasyCV的视频分类示例
|
人工智能 并行计算 算法
【DSW Gallery】基于MOCOV2的自监督学习示例
EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以自监督学习-MOCO为例,为您介绍如何在PAI-DSW中使用EasyCV。
【DSW Gallery】基于MOCOV2的自监督学习示例
|
机器学习/深度学习 算法
【DSW Gallery】如何使用EasyRec训练DeepFM模型
本文基于EasyRec 0.4.7 展示了如何使用EasyRec快速的训练一个DeepFM模型
【DSW Gallery】如何使用EasyRec训练DeepFM模型

热门文章

最新文章