EasyNLP中文文图生成模型带你秒变艺术家

本文涉及的产品
模型训练 PAI-DLC,100CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
阿里云百炼推荐规格 ADB PostgreSQL,4核16GB 100GB 1个月
简介: 本文简要介绍文图生成的技术,以及如何在EasyNLP框架中如何轻松实现文图生成,带你秒变艺术家。本文开头的展示图片即为我们模型创作的作品。

image.png

作者 | 汪诚愚、刘婷婷
来源 | 阿里开发者公众号

导读

image.png

宣物莫大于言,存形莫善于画。
--【晋】陆机

多模态数据(文本、图像、声音)是人类认识、理解和表达世间万物的重要载体。近年来,多模态数据的爆炸性增长促进了内容互联网的繁荣,也带来了大量多模态内容理解和生成的需求。与常见的跨模态理解任务不同,文到图的生成任务是流行的跨模态生成任务,旨在生成与给定文本对应的图像。这一文图生成的任务,极大地释放了AI的想象力,也激发了人类的创意。典型的模型例如OpenAI开发的DALL-E和DALL-E2。近期,业界也训练出了更大、更新的文图生成模型,例如Google提出的Parti和Imagen。

然而,上述模型一般不能用于处理中文的需求,而且上述模型的参数量庞大,很难被开源社区的广大用户直接用来Fine-tune和推理。本次,EasyNLP开源框架再次迎来大升级,集成了先进的文图生成架构Transformer+VQGAN,同时,向开源社区免费开放不同参数量的中文文图生成模型的Checkpoint,以及相应Fine-tune和推理接口。用户可以在我们开放的Checkpoint基础上进行少量领域相关的微调,在不消耗大量计算资源的情况下,就能一键进行各种艺术创作。

EasyNLP是阿里云机器学习PAI 团队基于 PyTorch 开发的易用且丰富的中文NLP算法框架,并且提供了从训练到部署的一站式 NLP 开发体验。EasyNLP 提供了简洁的接口供用户开发 NLP 模型,包括NLP应用 AppZoo 、预训练模型 ModelZoo、数据仓库DataHub等特性。由于跨模态理解和生成需求的不断增加,EasyNLP也支持各种跨模态模型,特别是中文领域的跨模态模型,推向开源社区。例如,在先前的工作中,EasyNLP已经对中文图文检索CLIP模型进行了支持[11]。我们希望能够服务更多的 NLP 和多模态算法开发者和研究者,也希望和社区一起推动 NLP /多模态技术的发展和模型落地。本文简要介绍文图生成的技术,以及如何在EasyNLP框架中如何轻松实现文图生成,带你秒变艺术家。本文开头的展示图片即为我们模型创作的作品。

文图生成模型简述

下面以几个经典的基于Transformer的工作为例,简单介绍文图生成模型的技术。DALL-E由OpenAI提出,采取两阶段的方法生成图像。在第一阶段,训练一个dVAE(discrete variational autoencoder)的模型将256×256的RGB图片转化为32×32的image token,这一步骤将图片进行信息压缩和离散化,方便进行文本到图像的生成。第二阶段,DALL-E训练一个自回归的Transformer模型,将文本输入转化为上述1024个image token。

由清华大学等单位提出的CogView模型对上述两阶段文图生成的过程进行了进一步的优化。在下图中,CogView采用了sentence piece作为text tokenizer使得输入文本的空间表达更加丰富,并且在模型的Fine-tune过程中采用了多种技术,例如图像的超分、风格迁移等。

image.png

ERNIE-ViLG模型考虑进一步考虑了Transformer模型学习知识的可迁移性,同时学习了从文本生成图像和从图像生成文本这两种任务。其架构图如下所示:

image.png

随着文图生成技术的不断发展,新的模型和技术不断涌现。举例来说,OFA将多种跨模态的生成任务统一在同一个模型架构中。DALL-E 2同样由OpenAI提出,是DALL-E模型的升级版,考虑了层次化的图像生成技术,模型利用CLIP encoder作为编码器,更好地融入了CLIP预训练的跨模态表征。Google进一步提出了Diffusion Model的架构,能有效生成高清大图,如下所示:

image.png

在本文中,我们不再对这些细节进行赘述。感兴趣的读者可以进一步查阅参考文献。

EasyNLP文图生成模型

由于前述模型的规模往往在数十亿、百亿参数级别,庞大的模型虽然能生成质量较大的图片,然后对计算资源和预训练数据的要求使得这些模型很难在开源社区广泛应用,尤其在需要面向垂直领域的情况下。在本节中,我们详细介绍EasyNLP提供的中文文图生成模型,它在较小参数量的情况下,依然具有良好的文图生成效果。

模型架构

模型框架图如下图所示:

image.png

考虑到Transformer模型复杂度随序列长度呈二次方增长,文图生成模型的训练一般以图像矢量量化和自回归训练两阶段结合的方式进行。

图像矢量量化是指将图像进行离散化编码,如将256×256的RGB图像进行16倍降采样,得到16×16的离散化序列,序列中的每个image token对应于codebook中的表示。常见的图像矢量量化方法包括:VQVAE、VQVAE-2和VQGAN等。我们采用VQGAN在ImageNet上训练的f16_16384(16倍降采样,词表大小为16384)的模型权重来生成图像的离散化序列。

自回归训练是指将文本序列和图像序列作为输入,在图像部分,每个image token仅与文本序列的tokens和其之前的image tokens进行attention计算。我们采用GPT作为backbone,能够适应不同模型规模的生成任务。在模型预测阶段,输入文本序列,模型以自回归的方式逐步生成定长的图像序列,再通过VQGAN decoder重构为图像。

开源模型参数设置

在EasyNLP中,我们提供两个版本的中文文图生成模型,模型参数配置如下表:

image.png

模型实现

在EasyNLP框架中,我们在模型层构建基于minGPT的backbone构建模型,核心部分如下所示:

self.first_stage_model = VQModel(ckpt_path=vqgan_ckpt_path).eval()
self.transformer = GPT(self.config)

VQModel的Encoding阶段过程为:

# in easynlp/appzoo/text2image_generation/model.py

@torch.no_grad()
def encode_to_z(self, x):
    quant_z, _, info = self.first_stage_model.encode(x)
    indices = info[2].view(quant_z.shape[0], -1)
    return quant_z, indices

x = inputs['image']
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
# one step to produce the logits
_, z_indices = self.encode_to_z(x)  # z_indice: torch.Size([batch_size, 256]) 

VQModel的Decoding阶段过程为:

# in easynlp/appzoo/text2image_generation/model.py

@torch.no_grad()
def decode_to_img(self, index, zshape):
    bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
    quant_z = self.first_stage_model.quantize.get_codebook_entry(
        index.reshape(-1), shape=bhwc)
    x = self.first_stage_model.decode(quant_z)
    return x

# sample为训练阶段的结果生成,与预测阶段的generate类似,详解见下文generate
index_sample = self.sample(z_start_indices, c_indices,
                           steps=z_indices.shape[1],
                           ...)
x_sample = self.decode_to_img(index_sample, quant_z.shape)

Transformer采用minGPT进行构建,输入图像的离散编码,输出文本token。前向传播过程为:

image.png

在预测阶段,输入为文本token, 输出为256*256的图像。首先,将输入文本预处理为token序列:

# in easynlp/appzoo/text2image_generation/predictor.py

def preprocess(self, in_data):
    if not in_data:
        raise RuntimeError("Input data should not be None.")

    if not isinstance(in_data, list):
        in_data = [in_data]
    rst = {"idx": [], "input_ids": []}
    max_seq_length = -1
    for record in in_data:
        if "sequence_length" not in record:
            break
        max_seq_length = max(max_seq_length, record["sequence_length"])
    max_seq_length = self.sequence_length if (max_seq_length == -1) else max_seq_length

    for record in in_data:
        text= record[self.first_sequence]
        try:
            self.MUTEX.acquire()
            text_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
            text_ids = text_ids[: self.text_len]
            n_pad = self.text_len - len(text_ids)
            text_ids += [self.pad_id] * n_pad
            text_ids = np.array(text_ids) + self.img_vocab_size

        finally:
            self.MUTEX.release()

        rst["idx"].append(record["idx"]) 
        rst["input_ids"].append(text_ids)
    return rst

逐步生成长度为16*16的图像离散token序列:

# in easynlp/appzoo/text2image_generation/model.py

def generate(self, inputs, top_k=100, temperature=1.0):
    cidx = inputs
    sample = True
    steps = 256
    for k in range(steps):
        x_cond = cidx
        logits, _ = self.transformer(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = self.top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        cidx = torch.cat((cidx, ix), dim=1)
    img_idx = cidx[:, 32:]
    return img_idx

最后,我们调用VQModel的Decoding过程将这些图像离散token序列转换为图像。

模型效果

我们在四个中文的公开数据集COCO-CN、MUGE、Flickr8k-CN、Flickr30k-CN上验证了EasyNLP框架中文图生成模型的效果。同时,我们对比了这个模型和CogView、DALL-E的效果,如下所示:

image.png

其中:
1)MUGE是天池平台公布的电商场景的中文大规模多模态评测基准[12]。为了方便计算指标,MUGE我们采用valid数据集的结果,其他数据集采用test数据集的结果。

2)CogView源自[13]

3)DALL-E模型没有公开的官方代码。已经公开的部分只包含VQVAE的代码,不包括Transformer部分。我们基于广受关注的[14]版本的代码和该版本推荐的checkpoits进行复现,checkpoints为2.09亿参数,为OpenAI的DALL-E模型参数量的1/100。(OpenAI版本DALL-E为120亿参数,其中CLIP为4亿参数)。

经典案例

我们分别在自然风景数据集COCO-CN上Fine-tune了base和large级别的模型,如下展示了模型的效果:

示例1:一只俏皮的狗正跑过草地

image.png

示例2:一片水域的景色以日落为背景

image.png

我们也积累了阿里集团的海量电商商品数据,微调得到了面向电商商品的文图生成模型。效果如下:

示例3:女童套头毛衣打底衫秋冬针织衫童装儿童内搭上衣

image.png

示例4:春夏真皮工作鞋女深色软皮久站舒适上班面试职业皮鞋

image.png

除了支持特定领域的应用,文图生成也极大地辅助了人类的艺术创作。使用训练得到的模型,我们可以秒变“中国国画艺术大师”,示例如下所示:

image.png

更多的示例请欣赏:

image.png

image.png

使用教程

欣赏了模型生成的作品之后,如果我们想DIY,训练自己的文图生成模型,应该如何进行呢?以下我们简要介绍在EasyNLP框架对预训练的文图生成模型进行Fine-tune和推理。

安装EasyNLP

用户可以直接参考链接[15]的说明安装EasyNLP算法框架。

数据准备

首先准备训练数据与验证数据,为tsv文件。这一文件包含以制表符\t分隔的两列,第一列为索引号,第二列为文本,第三列为图片的base64编码。用于测试的输入文件为两列,仅包含索引号和文本。

为了方便开发者,我们也提供了转换图片到base64编码的示例代码:

import base64
from io import BytesIO
from PIL import Image

img = Image.open(fn)
img_buffer = BytesIO()
img.save(img_buffer, format=img.format)
byte_data = img_buffer.getvalue()
base64_str = base64.b64encode(byte_data) # bytes

下列文件已经完成预处理,可用于测试:

# train
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_train_text_imgbase64.tsv

# valid
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_val_text_imgbase64.tsv

# test
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_test.text.tsv

模型训练

我们采用以下命令对模型进行fine-tune:

easynlp \
    --mode=train \
    --worker_gpu=1 \
    --tables=MUGE_val_text_imgbase64.tsv,MUGE_val_text_imgbase64.tsv \
    --input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
    --first_sequence=text \
    --second_sequence=imgbase64 \
    --checkpoint_dir=./finetuned_model/ \
    --learning_rate=4e-5 \
    --epoch_num=1 \
    --random_seed=42 \
    --logging_steps=100 \
    --save_checkpoint_steps=1000 \
    --sequence_length=288 \
    --micro_batch_size=16 \
    --app_name=text2image_generation \
    --user_defined_parameters='
        pretrain_model_name_or_path=alibaba-pai/pai-painter-large-zh
        size=256
        text_len=32
        img_len=256
        img_vocab_size=16384
    ' 

我们提供base和large两个版本的预训练模型,pretrain_model_name_or_path分别为alibaba-pai/pai-painter-base-zh和alibaba-pai/pai-painter-large-zh。
训练完成后模型被保存到./finetuned_model/。

模型批量推理

模型训练完毕后,我们可以将其用于图像生成,其示例如下:

easynlp \
    --mode=predict \
    --worker_gpu=1 \
    --tables=MUGE_test.text.tsv \
    --input_schema=idx:str:1,text:str:1 \
    --first_sequence=text \
    --outputs=./T2I_outputs.tsv \
    --output_schema=idx,text,gen_imgbase64 \
    --checkpoint_dir=./finetuned_model/ \
    --sequence_length=288 \
    --micro_batch_size=8 \
    --app_name=text2image_generation \
    --user_defined_parameters='
        size=256
        text_len=32
        img_len=256
        img_vocab_size=16384

结果存储在一个tsv文件中,每行对应输入中的一个文本,输出的图像以base64编码。

使用Pipeline接口快速体验文图生成效果

为了进一步方便开发者使用,我们在EasyNLP框架内也实现了Inference Pipeline功能。用户可以使用如下命令调用Fine-tune过的电商场景下的文图生成模型:

# 直接构建pipeline
default_ecommercial_pipeline = pipeline("pai-painter-commercial-base-zh")

# 模型预测
data = ["宽松T恤"]
results = default_ecommercial_pipeline(data)  # results的每一条是生成图像的base64编码

# base64转换为图像
def base64_to_image(imgbase64_str):
    image = Image.open(BytesIO(base64.urlsafe_b64decode(imgbase64_str)))
    return image

# 保存以文本命名的图像
for text, result in zip(data, results):
    imgpath = '{}.png'.format(text)
    imgbase64_str = result['gen_imgbase64']
    image = base64_to_image(imgbase64_str)
    image.save(imgpath)
    print('text: {}, save generated image: {}'.format(text, imgpath))

除了电商场景,我们还提供了以下场景的模型:

  • 自然风光场景:“pai-painter-scenery-base-zh”
  • 中国山水画场景:“pai-painter-painting-base-zh”

在上面的代码当中替换“pai-painter-commercial-base-zh”,就可以直接体验,欢迎试用。

对于用户Fine-tune的文图生成模型,我们也开放了自定义模型加载的Pipeline接口:

# 加载模型,构建pipeline
local_model_path = ...
text_to_image_pipeline = pipeline("text2image_generation", local_model_path)

# 模型预测
data = ["xxxx"]
results = text_to_image_pipeline(data)  # results的每一条是生成图像的base64编码

未来展望

在这一期的工作中,我们在EasyNLP框架中集成了中文文图生成功能,同时开放了模型的Checkpoint,方便开源社区用户在资源有限情况下进行少量领域相关的微调,进行各种艺术创作。在未来,我们计划在EasyNLP框架中推出更多相关模型,敬请期待。我们也将在EasyNLP框架中集成更多SOTA模型(特别是中文模型),来支持各种NLP和多模态任务。此外,阿里云机器学习PAI团队也在持续推进中文多模态模型的自研工作,欢迎用户持续关注我们,也欢迎加入我们的开源社区,共建中文NLP和多模态算法库!

Github地址:https://github.com/alibaba/EasyNLP

Reference

1、Chengyu Wang, Minghui Qiu, Taolin Zhang, Tingting Liu, Lei Li, Jianing Wang, Ming Wang, Jun Huang, Wei Lin. EasyNLP: A Comprehensive and Easy-to-use Toolkit for Natural Language Processing. arXiv

2、Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, Ilya Sutskever. Zero-Shot Text-to-Image Generation. ICML 2021: 8821-8831

3、Ming Ding, Zhuoyi Yang, Wenyi Hong, Wendi Zheng, Chang Zhou, Da Yin, Junyang Lin, Xu Zou, Zhou Shao, Hongxia Yang, Jie Tang. CogView: Mastering Text-to-Image Generation via Transformers. NeurIPS 2021: 19822-19835

4、Han Zhang, Weichong Yin, Yewei Fang, Lanxin Li, Boqiang Duan, Zhihua Wu, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang. ERNIE-ViLG: Unified Generative Pre-training for Bidirectional Vision-Language Generation. arXiv

5、Peng Wang, An Yang, Rui Men, Junyang Lin, Shuai Bai, Zhikang Li, Jianxin Ma, Chang Zhou, Jingren Zhou, Hongxia Yang. Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework. ICML 2022

6、Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. Hierarchical Text-Conditional Image Generation with CLIP Latents. arXiv

7、Van Den Oord A, Vinyals O. Neural discrete representation learning. NIPS 2017

8、Esser P, Rombach R, Ommer B. Taming transformers for high-resolution image synthesis. CVPR 2021: 12873-12883.

9、Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho, David J. Fleet, Mohammad Norouzi: Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. arXiv

10、Jiahui Yu, Yuanzhong Xu, Jing Yu Koh, Thang Luong, Gunjan Baid, Zirui Wang, Vijay Vasudevan, Alexander Ku, Yinfei Yang, Burcu Karagol Ayan, Ben Hutchinson, Wei Han, Zarana Parekh, Xin Li, Han Zhang, Jason Baldridge, Yonghui Wu. Scaling Autoregressive Models for Content-Rich Text-to-Image Generation. arXiv

11、https://zhuanlan.zhihu.com/p/528476134

12、http://tianchi.aliyun.com/muge

13、https://github.com/THUDM/CogView

14、https://github.com/lucidrains/DALLE-pytorch

15、https://github.com/alibaba/EasyNLP


大数据知识图谱—基于DataWorks搭建新零售数据中台

本篇文章向大家分享新零售企业如何基于DataWorks搭建数据中台,从商业模式及业务的设计,到数据中台的架构设计与产品选型,再到数据中台搭建的最佳实践,最后利用数据中台去反哺业务,辅助人工与智能的决策。 内容贡献:李启平(首义),盒马从初创至今的数据研发负责人,有非常资深的数仓及数据中台建设的经验,原阿里巴巴国际业务数仓负责人。

相关实践学习
AnalyticDB PostgreSQL 企业智能数据中台:一站式管理数据服务资产
企业在数据仓库之上可构建丰富的数据服务用以支持数据应用及业务场景;ADB PG推出全新企业智能数据平台,用以帮助用户一站式的管理企业数据服务资产,包括创建, 管理,探索, 监控等; 助力企业在现有平台之上快速构建起数据服务资产体系
相关文章
|
安全 Python Windows
python - http请求带Authorization
# 背景 接入公司的一个数据统计平台,该平台的接口是带上了Authorization验证方式来保证验签计算安全   # 方法 其实很简单,就是在header中加入key=Authorization,value是协商好的协议即可; 如,我们这边是base64.
5210 0
|
编解码 人工智能 调度
Meissonic:高效高分辨率文生图重大革新
Meissonic的新模型,仅1b参数可实现高质量图像生成,能在普通电脑上运行,未来有望支持无线端文本到图像的生成。
|
5月前
|
JSON API 数据安全/隐私保护
《揭秘:抖音电商 API 接口,让直播带货数据精准掌控!》
在数字营销时代,抖音电商API为直播带货提供数据支持。通过API可实时获取销售、用户互动等关键数据,助力商家优化策略,实现自动化分析与精准营销,提升效率并驱动业务增长。
476 0
|
9月前
|
JavaScript NoSQL 关系型数据库
当下弹幕互动游戏源码开发教程及功能逻辑分析
当下很多游戏开发者或者想学习游戏开发的人,想要了解如何制作弹幕互动游戏,比如直播平台上常见的那种,观众通过发送弹幕来影响游戏进程。需要涵盖教程的步骤和功能逻辑的分析。
|
11月前
|
存储 人工智能 缓存
【AI系统】Ascend C 语法扩展
Ascend C 是基于标准 C++ 扩展的编程语言,专为华为昇腾处理器设计。本文介绍了 Ascend C 的基础语法扩展、API(基础与高阶)、关键编程对象(数据存储、任务间通信与同步、资源管理及临时变量),以及如何利用这些特性高效开发。通过华为自研的毕昇编译器,Ascend C 实现了主机与设备侧的独立执行能力,支持不同地址空间的访问。API 包括计算、数据搬运、内存管理和任务同步等功能,旨在帮助开发者构建高性能的 AI 应用。
331 2
【AI系统】Ascend C 语法扩展
|
分布式计算 DataWorks 数据处理
"DataWorks高级技巧揭秘:手把手教你如何在PyODPS节点中将模型一键写入OSS,实现数据处理的完美闭环!"
【10月更文挑战第23天】DataWorks是企业级的云数据开发管理平台,支持强大的数据处理和分析功能。通过PyODPS节点,用户可以编写Python代码执行ODPS任务。本文介绍了如何在DataWorks中训练模型并将其保存到OSS的详细步骤和示例代码,包括初始化ODPS和OSS服务、读取数据、训练模型、保存模型到OSS等关键步骤。
631 3
|
缓存 测试技术 API
电商平台 API 接入技术要点深度剖析
本文介绍了高效使用电商平台API的关键步骤。首先,深入理解API文档,明确功能权限与参数格式要求;其次,选择合适的接入方式,如HTTP/HTTPS协议和RESTful API;接着,实施身份验证与授权机制,确保数据安全传输;此外,还需关注性能优化、安全防护、监控与日志记录,以提升系统稳定性和响应速度;最后,进行充分测试与调试,并关注API版本更新,确保长期兼容性。
|
Web App开发 JavaScript 前端开发
WebRTC 和 RTC 有什么区别?
【10月更文挑战第25天】WebRTC是RTC的一种具体实现方式,侧重于网页端的实时通信,具有便捷性和跨平台性等特点;而RTC则是一个更广泛的概念,包括了各种不同平台和技术实现的实时通信方式,应用场景更加丰富多样。在实际应用中,需要根据具体的需求和场景选择合适的实时通信技术。
|
机器学习/深度学习 人工智能 算法
手把手教你强化学习 (一) 什么是强化学习?与机器学习有什么区别?
手把手教你强化学习 (一) 什么是强化学习?与机器学习有什么区别?
570 3
|
机器学习/深度学习 数据采集 算法
利用机器学习进行用户行为预测的技术解析
【5月更文挑战第17天】本文探讨了利用机器学习预测用户行为的技术,包括数据收集与处理、特征工程、模型选择与训练、评估预测。通过理解用户数据、提取有效特征,使用如RNN、LSTM等深度学习模型进行训练,评估模型性能后,可实现用户行为预测,助力企业决策,如个性化推荐和精准营销。随着技术发展,机器学习在该领域的应用将更加广泛。
1143 1