【DSW Gallery】基于EasyNLP Transformer模型的中文文图生成

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: EasyNLP提供多种模型的训练及预测功能,旨在帮助自然语言开发者方便快捷地构建模型并应用于生产。本文简要介绍文图生成的技术,以及如何在PAI-DSW中基于EasyNLP轻松实现文图生成,带你秒变艺术家。

直接使用

请打开基于EasyNLP Transformer模型的中文文图生成,并点击右上角 “ 在DSW中打开” 。

image.png

基于EasyNLP Transformer模型的中文文图生成

EasyNLP是阿里云机器学习PAI算法团队基于PyTorch开发的易用且丰富的NLP算法框架( https://github.com/alibaba/EasyNLP ),支持常用的中文预训练模型和大模型落地技术,并且提供了从训练到部署的一站式NLP开发体验。EasyNLP提供了简洁的接口供用户开发NLP模型,包括NLP应用AppZoo和预训练ModelZoo,同时提供技术帮助用户高效地落地超大预训练模型到业务,旨在帮助自然语言开发者方便快捷地构建模型并应用于生产。由于跨模态理解需求的不断增加,EasyNLP也将支持各种跨模态模型,特别是中文领域的跨模态模型,希望能够服务更多的NLP和多模态算法开发者和研究者。 本文简要介绍文图生成的技术,以及如何在PAI-DSW中基于EasyNLP轻松实现文图生成,带你秒变艺术家。

ARTITST模型详解

ARTITST模型的构建基于Transformer模型 ,将文图生成任务分为两个阶段进行,第一阶段是通过VQGAN模型对图像进行矢量量化,对于输入的图像,通过编码器将图像编码为定长的离散序列,解码阶段是以离散序列作为输入,输出重构图。第二阶段是将文本序列和编码后的图像序列作为输入,利用GPT模型学习以文本序列为条件的图像序列生成。为了增强模型先验,我们将知识图谱中的的实体知识引入模型,辅助图像中对应实体的生成,从而使得生成的图像的实体信息更加精准。

运行环境要求

建议用户使用:Python 3.6,Pytorch 1.8镜像,GPU机型 P100 or V100,内存至少为 32G

EasyNLP安装

建议从GitHub下载EasyNLP源代码进行安装,命令如下:

! git clone https://github.com/alibaba/EasyNLP.git
Cloning into 'EasyNLP'...
remote: Enumerating objects: 3359, done.
remote: Counting objects: 100% (3359/3359), done.
remote: Compressing objects: 100% (1171/1171), done.
remote: Total 3359 (delta 2137), reused 3298 (delta 2114), pack-reused 0
Receiving objects: 100% (3359/3359), 1.56 MiB | 1.80 MiB/s, done.
Resolving deltas: 100% (2137/2137), done.
! echo y | pip uninstall pai-easynlp easynlp
! pip install -r EasyNLP/requirements.txt -i http://mirrors.aliyun.com/pypi/simple
! cd EasyNLP && python setup.py install

安装完成easynlp之后,建议重启notebook,防止环境存在缓存,未更新

您可以使用如下命令验证是否安装成功:

import easynlp
easynlp.__file__
/home/pai/bin/easynlp

数据准备

首先,您需要进入指定目录,下载用于本示例的训练数据与验证数据,以及用于提取向量进行向量检索的单列测试数据。命令如下:

! wget  -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_train_text_imgbase64.tsv
! wget  -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_val_text_imgbase64.tsv
! wget  -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_test.text.tsv

训练数据和验证数据都为.tsv文件。每行为一个数据,以制表符\t分隔为三列(idx, text, imgbase64),第一列是文本编号,第二列是文本,第三列是对应图片的base64编码。

将输入数据与lattice、entity位置信息拼接到一起:输出格式为以制表符\t分隔的几列(idx, text, lex_ids, pos_s, pos_e, seq_len, [Optional] imgbase64)

初始化

在Python 3.6环境下,我们首先从刚刚安装好的EasyNLP中引入模型运行需要的各种库,并做一些初始化。在本教程中,我们使用的文图生成模型为pai-painter-base-zh。EasyNLP中集成了丰富的预训练模型库,如果想尝试其他预训练模型,也可以在user_defined_parameters中进行相应修改,具体的模型名称可见模型列表。

# 为了避免EasyNLP中的args与Jupyter系统的冲突,需要手动设置,否则无法进行初始化。
# 在命令行或py文件中运行文中代码则可忽略下述代码。
import sys
import imp
import os
sys.argv = ['main.py']
sys.path.append('./')
import torch.cuda
from easynlp.core import Trainer
# from easynlp.appzoo import get_application_evaluator
from easynlp.appzoo.sequence_classification.data import ClassificationDataset
from easynlp.appzoo import TextImageDataset
from easynlp.appzoo import TextImageGeneration
from easynlp.appzoo import TextImageGenerationEvaluator
from easynlp.appzoo import TextImageGenerationPredictor
from easynlp.utils import initialize_easynlp, get_args
from easynlp.utils.global_vars import parse_user_defined_parameters
from easynlp.core import PredictorManager
from easynlp.utils import get_pretrain_model_path
initialize_easynlp()
args = get_args()
user_defined_parameters = parse_user_defined_parameters('pretrain_model_name_or_path=alibaba-pai/pai-painter-base-zh')
args.checkpoint_dir = "./painter_model/"
args.pretrained_model_name_or_path = "alibaba-pai/pai-painter-base-zh"
pretrained_model_name_or_path = get_pretrain_model_path(args.pretrained_model_name_or_path)
[2022-10-19 06:30:59,381.381 dsw-150674-58485769d8-67rvk:5526 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:8: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
  from cryptography import utils, x509
Please ignore the following import error if you are using tunnel table io.
No module named '_common_io'
No module named 'easy_predict'
------------------------ arguments ------------------------
  app_name ........................................ text_classify
  append_cols ..................................... None
  buckets ......................................... None
  checkpoint_dir .................................. None
  chief_hosts ..................................... 
  data_threads .................................... 10
  distributed_backend ............................. nccl
  do_lower_case ................................... False
  epoch_num ....................................... 3.0
  export_tf_checkpoint_type ....................... easytransfer
  first_sequence .................................. None
  gradient_accumulation_steps ..................... 1
  input_schema .................................... None
  is_chief ........................................ 
  is_master_node .................................. True
  job_name ........................................ None
  label_enumerate_values .......................... None
  label_name ...................................... None
  learning_rate ................................... 5e-05
  local_rank ...................................... None
  logging_steps ................................... 100
  master_port ..................................... 23456
  max_grad_norm ................................... 1.0
  micro_batch_size ................................ 2
  mode ............................................ train
  modelzoo_base_dir ............................... 
  n_cpu ........................................... 1
  n_gpu ........................................... 1
  odps_config ..................................... None
  optimizer_type .................................. AdamW
  output_schema ................................... 
  outputs ......................................... None
  predict_queue_size .............................. 1024
  predict_slice_size .............................. 4096
  predict_table_read_thread_num ................... 16
  predict_thread_num .............................. 2
  ps_hosts ........................................ 
  random_seed ..................................... 1234
  rank ............................................ 0
  read_odps ....................................... False
  restore_works_dir ............................... ./.easynlp_predict_restore_works_dir
  resume_from_checkpoint .......................... None
  save_all_checkpoints ............................ False
  save_checkpoint_steps ........................... None
  second_sequence ................................. None
  sequence_length ................................. 16
  skip_first_line ................................. False
  tables .......................................... None
  task_count ...................................... 1
  task_index ...................................... 0
  use_amp ......................................... False
  use_torchacc .................................... False
  user_defined_parameters ......................... None
  user_entry_file ................................. None
  user_script ..................................... None
  warmup_proportion ............................... 0.1
  weight_decay .................................... 0.0001
  worker_count .................................... 1
  worker_cpu ...................................... -1
  worker_gpu ...................................... -1
  worker_hosts .................................... None
  world_size ...................................... 1
-------------------- end of arguments ---------------------
> initializing torch distributed ...
[2022-10-19 06:31:01,724.724 dsw-150674-58485769d8-67rvk:5526 INFO distributed_c10d.py:195] Added key: store_based_barrier_key:1 to store for rank: 0
Init dist done. World size: 1, rank 0, l_rank 0
> setting random seeds to 1234 ...
Downloading `alibaba-pai/pai-painter-base-zh` to /root/.easynlp/modelzoo/alibaba-pai/pai-painter-base-zh.tgz

注意:上述代码如果出现“Address already in use”错误,则需要运行以下代码清理端口上正在执行的程序。 netstat -tunlp|grep 6000 kill -9 PID (需要替换成上一行代码执行结果中对应的程序ID)

载入数据

我们使用EasyNLP中自带的TextImageDataset,对训练和测试数据进行载入。主要参数如下: pretrained_model_name_or_path:预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称"pai-painter-base-zh"以得到其路径,并自动下载模型 max_seq_length:文本最大长度,超过将截断,不足将padding input_schema:输入tsv数据的格式,逗号分隔的每一项对应数据文件中每行以\t分隔的一项,每项开头为其字段标识,如label、sent1等 first_sequence、second_sequence:用于说明input_schema中哪些字段作为第一/第二列输入数据 is_training:是否为训练过程,train_dataset为True,valid_dataset为False

user_defined_parameters={"pretrain_model_name_or_path":"alibaba-pai/pai-painter-base-zh",
                         "size":256, "text_len":32, "img_len":256 ,"img_vocab_size":16384}
train_dataset = TextImageDataset(
        pretrained_model_name_or_path=get_pretrain_model_path("alibaba-pai/pai-painter-base-zh"),
        data_file="./tmp/MUGE_train_text_imgbase64.tsv",
        max_seq_length=32,
        input_schema="idx:str:1,text:str:1,imgbase64:str:1",
        first_sequence="text",
        second_sequence="imgbase64",
        user_defined_parameters=user_defined_parameters,
        is_training=True)
user_defined_parameters={"size":256, "text_len":32, "img_len":256, "img_vocab_size":16384,"max_generated_num":1}
valid_dataset = TextImageDataset(
        pretrained_model_name_or_path=get_pretrain_model_path("alibaba-pai/pai-painter-base-zh"),
        data_file="./tmp/MUGE_val_text_imgbase64.tsv",
        max_seq_length=32,
        input_schema="idx:str:1,text:str:1,imgbase64:str:1",
        first_sequence="text",
        second_sequence="imgbase64",
        user_defined_parameters=user_defined_parameters,
        is_training=False)
`/root/.easynlp/modelzoo/alibaba-pai/pai-painter-base-zh.tgz` already exists
****./tmp/MUGE_train_text_imgbase64.tsv
`/root/.easynlp/modelzoo/alibaba-pai/pai-painter-base-zh.tgz` already exists
****./tmp/MUGE_val_text_imgbase64.tsv

模型训练

处理好数据与模型载入后,我们开始训练模型。 我们使用EasyNLP中封装好的get_application_model函数进行训练时的模型构建,其参数如下: app_name:任务名称,这里选择文本分类"clip" pretrained_model_name_or_path:预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称"pai-painter-base-zh"以得到其路径,并自动下载模型 user_defined_parameters:用户自定义参数,直接填入刚刚处理好的自定义参数user_defined_parameters

model = TextImageGeneration(
                pretrained_model_name_or_path=get_pretrain_model_path("alibaba-pai/pai-painter-base-zh"), 
                user_defined_parameters=user_defined_parameters)
`/root/.easynlp/modelzoo/alibaba-pai/pai-painter-base-zh.tgz` already exists
Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Restored from /root/.easynlp/modelzoo/alibaba-pai/pai-painter-base-zh

从日志中可以看出,我们对预训练模型的参数进行了载入。下一步我们使用EasyNLP中的Train类创建训练实例,并进行训练。

evaluator = TextImageGenerationEvaluator(valid_dataset=valid_dataset,
                                         user_defined_parameters=user_defined_parameters)
trainer = Trainer(model=model, 
                  train_dataset=train_dataset, 
                  user_defined_parameters=user_defined_parameters,
                  evaluator=evaluator)
trainer.train()

模型预测

我们可以使用训练好的模型进行预测,也就是文本和图片的特征向量提取。我们首先创建一个predictor,并据此实例化一个PredictorManager实例。以文本特征向量提取为例,我们指定输入为MUGE_test.text.tsv,预测好的结果输出在"T2I_outputs.tsv",并指定输出格式为"output_schema"。

predictor = TextImageGenerationPredictor(model_dir="./painter_model/", 
                                         model_cls=TextImageGeneration,
                                         first_sequence="text", 
                                         user_defined_parameters=user_defined_parameters)
predictor_manager = PredictorManager(
            predictor=predictor,
            input_file="./temp/MUGE_test.text.tsv",
            input_schema="idx:str:1,text:str:1",
            output_file="T2I_outputs.tsv",
            output_schema="idx,text,gen_imgbase64",
            append_cols=args.append_cols,
            batch_size=16
        )
predictor_manager.run()
exit()

一步执行

值得一提的是,上述所有训练/评估/预测代码,都已经被集成在EasyNLP/examples/text2image_generation/main.py中,此外,我们也预先编写好了多种可供直接执行的脚本。用户可以通过带参数运行main.py中指令,或者直接使用bash文件命令行执行的方式,一步执行上述所有训练/评估/预测操作。

main文件一步执行

用户通过以下代码带参数执行main.py中的指令,可直接对模型进行训练/评估/预测操作。 训练代码指令如下。参数中,tables指定了训练集和验证集tsv文件的路径,input_schema表示tsv的数据格式,first_sequence、second_sequence用于说明input_schema中哪些字段用于作为第一/第二列数据。模型存储的路径位于checkpoint_dir,learning_rate、epoch_num、random_seed、save_checkpoint_steps、sequence_length、train_batch_size等为训练的超参数。在本示例中,预训练模型指定为pai-painter-base-zh。

! python -m torch.distributed.launch $DISTRIBUTED_ARGS EasyNLP/examples/text2image_generation/main.py \
    --mode=train \
    --worker_gpu=1 \
    --tables=./tmp/MUGE_train_text_imgbase64.tsv,./tmp/MUGE_val_text_imgbase64.tsv \
    --input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
    --first_sequence=text \
    --second_sequence=imgbase64 \
    --checkpoint_dir=./tmp/finetune_model \
    --learning_rate=4e-5 \
    --epoch_num=40 \
    --random_seed=42 \
    --logging_steps=100 \
    --save_checkpoint_steps=1000 \
    --sequence_length=288 \
    --micro_batch_size=16 \
    --app_name=text2image_generation \
    --user_defined_parameters="\
        pretrain_model_name_or_path=alibaba-pai/pai-painter-base-zh\
        size=256\
        text_len=32\
        img_len=256\
        img_vocab_size=16384\
      "

预测代码如下

! python -m torch.distributed.launch $DISTRIBUTED_ARGS EasyNLP/examples/text2image_generation/main.py \
    --mode=predict \
    --worker_gpu=1 \
    --tables=./tmp/MUGE_test.text.tsv \
    --input_schema=idx:str:1,text:str:1 \
    --first_sequence=text \
    --outputs=./tmp/T2I_outputs.tsv \
    --output_schema=idx,text,gen_imgbase64 \
    --checkpoint_dir=./tmp/finetune_model \
    --sequence_length=288 \
    --micro_batch_size=16 \
    --app_name=text2image_generation \
    --user_defined_parameters="\
        size=256\
        text_len=32\
        img_len=256\
        img_vocab_size=16384\
        max_generated_num=\
      "

利用bash文件命令行执行

我们在EasyNLP/examples/text2image_generation/文件夹下封装好了多种可直接执行的bash脚本,用户同样可以通过直接使用bash文件命令行执行的方式来一步完成模型的训练/评估/预测。以下以run_user_defined_local.sh脚本为例。该bash文件需要传入两个参数,第一个参数为运行程序的GPU编号,一般为0;第二个参数代表模型的训练/预测。

模型训练:

! cd EasyNLP/examples/text2image_generation && bash run_user_defined_local.sh 0 finetune

模型预测:

! cd EasyNLP/examples/text2image_generation && bash run_user_defined_local.sh 0 predict


相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
自然语言处理 数据格式
【DSW Gallery】基于ModelScope的中文GPT-3模型(1.3B)的微调训练
本文基于ModelScope,以GPT-3(1.3B)为例介绍如何使用ModelScope-GPT3进行续写训练与输入输出形式的训练,训练方式不需要额外指定,训练数据集仅包含 src_txt 时会进行续写训练,同时包含 src_txt 和 tgt_txt 时会进行输入输出形式的训练。
【DSW Gallery】基于ModelScope的中文GPT-3模型(1.3B)的微调训练
|
4月前
|
机器学习/深度学习 IDE 开发工具
ARTIST的中文文图生成模型问题之什么是PAI-DSW
ARTIST的中文文图生成模型问题之什么是PAI-DSW
|
5月前
|
自然语言处理 API 开发工具
初识langchain:LLM大模型+Langchain实战[qwen2.1、GLM-4]+Prompt工程
【7月更文挑战第6天】初识langchain:LLM大模型+Langchain实战[qwen2.1、GLM-4]+Prompt工程
初识langchain:LLM大模型+Langchain实战[qwen2.1、GLM-4]+Prompt工程
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
ARTIST的中文文图生成模型问题之在EasyNLP中使用ARTIST模型的问题如何解决
ARTIST的中文文图生成模型问题之在EasyNLP中使用ARTIST模型的问题如何解决
|
4月前
|
存储 人工智能 自然语言处理
【AI大模型】Transformers大模型库(十四):Datasets Viewer
【AI大模型】Transformers大模型库(十四):Datasets Viewer
33 0
|
机器学习/深度学习 人工智能 编解码
【DSW Gallery】基于EasyNLP-Diffusion模型的中文文图生成
EasyNLP提供多种模型的训练及预测功能,旨在帮助自然语言开发者方便快捷地构建模型并应用于生产。本文简要介绍文图生成的技术,以及如何在PAI-DSW中基于EasyNLP使用diffusion model进行finetune和预测评估。
【DSW Gallery】基于EasyNLP-Diffusion模型的中文文图生成
|
机器学习/深度学习 算法
【DSW Gallery】如何使用EasyRec训练DeepFM模型
本文基于EasyRec 0.4.7 展示了如何使用EasyRec快速的训练一个DeepFM模型
【DSW Gallery】如何使用EasyRec训练DeepFM模型
|
缓存 自然语言处理 Shell
【DSW Gallery】基于CK-BERT的中文序列标注
EasyNLP提供多种模型的训练及预测功能,旨在帮助自然语言开发者方便快捷地构建模型并应用于生产。本文以序列标注(命名实体识别)为例,为您介绍如何在PAI-DSW中使用EasyNLP。
【DSW Gallery】基于CK-BERT的中文序列标注
|
算法 PyTorch 算法框架/工具
【DSW Gallery】基于EasyCV的视频分类示例
EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以视频分类为例,为您介绍如何在PAI-DSW中使用EasyCV。
【DSW Gallery】基于EasyCV的视频分类示例
|
机器学习/深度学习 并行计算 数据可视化
【DSW Gallery】EasyCV-基于关键点的视频分类示例
EasyCV是基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具,并包含图像分类,度量学习,目标检测,姿态识别等视觉任务的SOTA算法。本文以基于关键点的视频分类为例,为您介绍如何在PAI-DSW中使用EasyCV。
【DSW Gallery】EasyCV-基于关键点的视频分类示例

热门文章

最新文章