Datawhale AI夏令营第四期魔搭-AIGC文生图方向Task1笔记

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,5000CU*H 3个月
模型训练 PAI-DLC,5000CU*H 3个月
简介: 这段内容介绍了一个使用Stable Diffusion与LoRA技术创建定制化二次元图像生成模型的全流程。首先,通过安装必要的软件包如Data-Juicer和DiffSynth-Studio准备开发环境。接着,下载并处理二次元图像数据集,利用Data-Juicer进行数据清洗和筛选,确保图像质量和尺寸的一致性。随后,训练一个针对二次元风格优化的LoRA模型,并调整参数以控制模型复杂度。完成训练后,加载模型并通过精心设计的提示词(prompt)生成一系列高质量的二次元图像,展示模型对细节和艺术风格的理解与再现能力。整个过程展示了从数据准备到模型训练及结果生成的完整步骤,为定制化图像提供了方向。

output.png

提示词
提示词很重要,一般写法:主体描述,细节描述,修饰词,艺术风格,艺术家

举个例子

【promts】Beautiful and cute girl, smiling, 16 years old, denim jacket, gradient background, soft colors, soft lighting, cinematic edge lighting, light and dark contrast, anime, super detail, 8k

【负向prompts】(lowres, low quality, worst quality:1.2), (text:1.2), deformed, black and white,disfigured, low contrast, cropped, missing fingers

Lora
Stable Diffusion中的Lora(LoRA)模型是一种轻量级的微调方法,它代表了“Low-Rank Adaptation”,即低秩适应。Lora不是指单一的具体模型,而是指一类通过特定微调技术应用于基础模型的扩展应用。在Stable Diffusion这一文本到图像合成模型的框架下,Lora被用来对预训练好的大模型进行针对性优化,以实现对特定主题、风格或任务的精细化控制。

ComfyUI
ComfyUI 是一个工作流工具,主要用于简化和优化 AI 模型的配置和训练过程。通过直观的界面和集成的功能,用户可以轻松地进行模型微调、数据预处理、图像生成等任务,从而提高工作效率和生成效果。
a4c3848b-68a8-4b38-b040-c79c91b90c95.png

在ComfyUI平台的前端页面上,用户可以基于节点/流程图的界面设计并执行AIGC文生图或者文生视频的pipeline。

参考图控制

ControlNet是一种用于精确控制图像生成过程的技术组件。它是一个附加到预训练的扩散模型(如Stable Diffusion模型)上的可训练神经网络模块。扩散模型通常用于从随机噪声逐渐生成图像的过程,而ControlNet的作用在于引入额外的控制信号,使得用户能够更具体地指导图像生成的各个方面(如姿势关键点、分割图、深度图、颜色等)。

第一步:安装

安装 Data-Juicer 和 DiffSynth-Studio

pip install simple-aesthetics-predictor

pip install -v -e data-juicer

pip uninstall pytorch-lightning -y
pip install peft lightning pandas torchvision

pip install -e DiffSynth-Studio

请在这里手动重启 Notebook kernel

第二步:下载数据集

from modelscope.msdatasets import MsDataset

ds = MsDataset.load(
'AI-ModelScope/lowres_anime',
subset_name='default',
split='train',
cache_dir="/mnt/workspace/kolors/data"
)


保存数据集中的图片及元数据

import json, os
from data_juicer.utils.mm_utils import SpecialTokens
from tqdm import tqdm

os.makedirs("./data/lora_dataset/train", exist_ok=True)
os.makedirs("./data/data-juicer/input", exist_ok=True)
with open("./data/data-juicer/input/metadata.jsonl", "w") as f:
for data_id, data in enumerate(tqdm(ds)):
image = data["image"].convert("RGB")
image.save(f"/mnt/workspace/kolors/data/lora_dataset/train/{data_id}.jpg")
metadata = {"text": "二次元", "image": [f"/mnt/workspace/kolors/data/lora_dataset/train/{data_id}.jpg"]}
f.write(json.dumps(metadata))
f.write("\n")


第三步:数据处理

使用 data-juicer 处理数据

data_juicer_config = """

global parameters

project_name: 'data-process'
dataset_path: './data/data-juicer/input/metadata.jsonl' # path to your dataset directory or file
np: 4 # number of subprocess to process your dataset

text_keys: 'text'
image_key: 'image'
image_special_token: '<__dj__image>'

export_path: './data/data-juicer/output/result.jsonl'

process schedule

a list of several process operators with their arguments

process:

- image_shape_filter:
    min_width: 1024
    min_height: 1024
    any_or_all: any
- image_aspect_ratio_filter:
    min_ratio: 0.5
    max_ratio: 2.0
    any_or_all: any

"""
with open("data/data-juicer/data_juicer_config.yaml", "w") as file:
file.write(data_juicer_config.strip())

!dj-process --config data/data-juicer/data_juicer_config.yaml


保存处理好的数据

import pandas as pd
import os, json
from PIL import Image
from tqdm import tqdm

texts, file_names = [], []
os.makedirs("./data/lora_dataset_processed/train", exist_ok=True)
with open("./data/data-juicer/output/result.jsonl", "r") as file:
for data_id, data in enumerate(tqdm(file.readlines())):
data = json.loads(data)
text = data["text"]
texts.append(text)
image = Image.open(data["image"][0])
image_path = f"./data/lora_dataset_processed/train/{data_id}.jpg"
image.save(image_path)
file_names.append(f"{data_id}.jpg")
data_frame = pd.DataFrame()
data_frame["file_name"] = file_names
data_frame["text"] = texts
data_frame.to_csv("./data/lora_dataset_processed/train/metadata.csv", index=False, encoding="utf-8-sig")
data_frame


第四步:训练模型

下载模型

from diffsynth import download_models

download_models(["Kolors", "SDXL-vae-fp16-fix"])

查看训练脚本的输入参数

!python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py -h

开始训练

提示:

在训练命令中填入 --modelscope_model_id xxxxx 以及 --modelscope_access_token xxxxx 后,训练程序会在结束时自动上传模型到 ModelScope
部分参数可根据实际需求调整,例如 lora_rank 可以控制 LoRA 模型的参数量

import os

cmd = """
python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py \
--pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
--pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
--pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
--lora_rank 16 \
--lora_alpha 4.0 \
--dataset_path data/lora_dataset_processed \
--output_path ./models \
--max_epochs 1 \
--center_crop \
--use_gradient_checkpointing \
--precision "16-mixed"
""".strip()

os.system(cmd)


加载模型

from diffsynth import ModelManager, SDXLImagePipeline
from peft import LoraConfig, inject_adapter_in_model
import torch

def load_lora(model, lora_rank, lora_alpha, lora_path):
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights="gaussian",
target_modules=["to_q", "to_k", "to_v", "to_out"],
)
model = inject_adapter_in_model(lora_config, model)
state_dict = torch.load(lora_path, map_location="cpu")
model.load_state_dict(state_dict, strict=False)
return model

Load models

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
file_path_list=[
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
])
pipe = SDXLImagePipeline.from_model_manager(model_manager)

Load LoRA

pipe.unet = load_lora(
pipe.unet,
lora_rank=16, # This parameter should be consistent with that in your training script.
lora_alpha=2.0, # lora_alpha can control the weight of LoRA.
lora_path="models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt"

)

生成图像

torch.manual_seed(0)
image = pipe(
prompt="二次元,一个紫色短发小女孩,在床上甜美的睡着了,全身,粉色连衣裙,梦到了美丽的风景",
negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)

image.save("1.jpg")

torch.manual_seed(1)
image = pipe(
prompt="二次元,日系动漫,博物馆的门口,一个紫色短发小女孩穿着粉色吊带漏肩连衣裙在去博物馆的门口",
negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)

image.save("2.jpg")

orch.manual_seed(2)
image = pipe(
prompt="二次元,日系动漫,博物馆的室内,一个紫色短发小女孩穿着粉色吊带漏肩连衣裙站着,看墙上的画,露出憧憬的神情",
negative_prompt="丑陋、变形、嘈杂、模糊、低对比度,色情擦边",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)

image.save("3.jpg")

torch.manual_seed(5)
image = pipe(
prompt="二次元,一个紫色短发小女孩穿着粉色吊带漏肩连衣裙,对着流星许愿,闭着眼睛,十指交叉,侧面",
negative_prompt="丑陋、变形、嘈杂、模糊、低对比度,扭曲的手指,多余的手指",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)

image.save("4.jpg")

torch.manual_seed(0)
image = pipe(
prompt="二次元,一个紫色中等长度头发小女孩穿着粉色吊带漏肩连衣裙,在画室练习画画",
negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)

image.save("5.jpg")

orch.manual_seed(1)
image = pipe(
prompt="二次元,紫色长发少女,穿着粉色吊带漏肩连衣裙看着墙上,墙上有幅风景画",
negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)

image.save("6.jpg")

torch.manual_seed(7)
image = pipe(
prompt="二次元,紫色长发少女很高兴,在人山人海的人群看画",
negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)

image.save("7.jpg")

torch.manual_seed(0)
image = pipe(
prompt="二次元,紫色长发少女,穿着黑色礼服,手里拿着一个奖杯,背后是她画好的画",
negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)

image.save("8.jpg")

import numpy as np
from PIL import Image

images = [np.array(Image.open(f"{i}.jpg")) for i in range(1, 9)]
image = np.concatenate([
np.concatenate(images[0:2], axis=1),
np.concatenate(images[2:4], axis=1),
np.concatenate(images[4:6], axis=1),
np.concatenate(images[6:8], axis=1),
], axis=0)
image = Image.fromarray(image).resize((1024, 2048))
image

相关文章
|
29天前
|
人工智能
Datawhale X 魔搭 AI夏令营task 2笔记
Datawhale X 魔搭 AI夏令营task 2笔记
44 1
Datawhale X 魔搭 AI夏令营task 2笔记
|
26天前
|
数据采集 人工智能 物联网
Datawhale X 魔搭 AI夏令营task 3笔记
Datawhale X 魔搭 AI夏令营task 3笔记
60 2
|
27天前
|
人工智能 物联网
Datawhale X 魔搭 AI夏令营task 3笔记
Datawhale X 魔搭 AI夏令营task 3笔记
|
27天前
|
机器学习/深度学习 人工智能 中间件
解读顺网算力与AI,破局AIGC落地“最后一公里”
解读顺网算力与AI,破局AIGC落地“最后一公里”
|
30天前
|
人工智能
Datawhale X 魔搭 AI夏令营task 2笔记
Datawhale X 魔搭 AI夏令营task 2笔记
|
22天前
|
存储 人工智能 异构计算
就AI 基础设施的演进与挑战问题之通讯墙在AIGC中挑战的问题如何解决
就AI 基础设施的演进与挑战问题之通讯墙在AIGC中挑战的问题如何解决
|
22天前
|
人工智能 弹性计算 芯片
就AI 基础设施的演进与挑战问题之AIGC场景下训练和推理的成本的问题如何解决
就AI 基础设施的演进与挑战问题之AIGC场景下训练和推理的成本的问题如何解决
|
25天前
|
算法 物联网 Serverless
一键打造你的定制化AIGC文生图工具
【8月更文挑战第2天】一键打造你的定制化AIGC文生图工具
68 0
|
27天前
|
机器学习/深度学习 人工智能 编解码
AI文生图模型
8月更文挑战第16天
|
28天前
|
人工智能 编解码 自然语言处理