消费级显卡微调可图Kolors最佳实践!

简介: 近期,快手开源了一种名为Kolors(可图)的文本到图像生成模型,该模型具有对英语和汉语的深刻理解,并能够生成高质量、逼真的图像。

近期,快手开源了一种名为Kolors(可图)的文本到图像生成模型,该模型具有对英语和汉语的深刻理解,并能够生成高质量、逼真的图像。

魔搭社区在DiffSynth-Studio中提供了可图Kolors微调脚本。

代码开源链接:

https://github.com/Kwai-Kolors/Kolors

模型开源链接:

https://modelscope.cn/models/Kwai-Kolors/Kolors

技术报告链接:

https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf

微调脚本链接:

https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/train/kolors

微调最佳实践

下载模型权重

下载可图Kolors模型

modelscope download --model=Kwai-Kolors/Kolors --local_dir models/kolors/Kolors

image.gif

下载额外的VAE模型(https://modelscope.cn/models/AI-ModelScope/sdxl-vae-fp16-fix

modelscope download --model=AI-ModelScope/sdxl-vae-fp16-fix --local_dir models/kolors/sdxl-vae-fp16-fix diffusion_pytorch_model.safetensors

image.gif

模型文件结构:

models
├── kolors
│   └── Kolors
│       ├── text_encoder
│       │   ├── config.json
│       │   ├── pytorch_model-00001-of-00007.bin
│       │   ├── pytorch_model-00002-of-00007.bin
│       │   ├── pytorch_model-00003-of-00007.bin
│       │   ├── pytorch_model-00004-of-00007.bin
│       │   ├── pytorch_model-00005-of-00007.bin
│       │   ├── pytorch_model-00006-of-00007.bin
│       │   ├── pytorch_model-00007-of-00007.bin
│       │   └── pytorch_model.bin.index.json
│       ├── unet
│       │   └── diffusion_pytorch_model.safetensors
│       └── vae
│           └── diffusion_pytorch_model.safetensors
└── sdxl-vae-fp16-fix
    └── diffusion_pytorch_model.safetensors

image.gif

微调:

安装依赖:

pip install peft lightning pandas torchvision

image.gif

数据准备:

我们准备了一些开源数据集:

柯基小狗数据集:

https://modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune

文生图风格定制数据集(metadata做了汉化):

https://modelscope.cn/datasets/iic/style_custom_dataset

数据集按照如下格式:

data/dog/
└── train
    ├── 00.jpg
    ├── 01.jpg
    ├── 02.jpg
    ├── 03.jpg
    ├── 04.jpg
    └── metadata.csv

image.gif

metadata.csv:

file_name,text
00.jpg,一只小狗
01.jpg,一只小狗
02.jpg,一只小狗
03.jpg,一只小狗
04.jpg,一只小狗

image.gif

训练lora模型:

我们提供了训练脚本 train_kolors_lora.py,在运行该训练脚本之前,需要先clone本项目

https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio

image.gif

采用以下设置,需要22GB VRAM

CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
  --pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
  --pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
  --pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
  --dataset_path data/dog \
  --output_path ./models \
  --max_epochs 10 \
  --center_crop \
  --use_gradient_checkpointing \
  --precision "16-mixed"

image.gif

可选参数:

-h, --help            show this help message and exit
  --pretrained_unet_path PRETRAINED_UNET_PATH
                        Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.
  --pretrained_text_encoder_path PRETRAINED_TEXT_ENCODER_PATH
                        Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.
  --pretrained_fp16_vae_path PRETRAINED_FP16_VAE_PATH
                        Path to pretrained model (VAE). For example, `models/kolors/Kolors/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.
  --dataset_path DATASET_PATH
                        The path of the Dataset.
  --output_path OUTPUT_PATH
                        Path to save the model.
  --steps_per_epoch STEPS_PER_EPOCH
                        Number of steps per epoch.
  --height HEIGHT       Image height.
  --width WIDTH         Image width.
  --center_crop         Whether to center crop the input images to the resolution. If not set, the images will be randomly cropped. The images will be resized to the resolution first before cropping.
  --random_flip         Whether to randomly flip images horizontally
  --batch_size BATCH_SIZE
                        Batch size (per device) for the training dataloader.
  --dataloader_num_workers DATALOADER_NUM_WORKERS
                        Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
  --precision {32,16,16-mixed}
                        Training precision
  --learning_rate LEARNING_RATE
                        Learning rate.
  --lora_rank LORA_RANK
                        The dimension of the LoRA update matrices.
  --lora_alpha LORA_ALPHA
                        The weight of the LoRA update matrices.
  --use_gradient_checkpointing
                        Whether to use gradient checkpointing.
  --accumulate_grad_batches ACCUMULATE_GRAD_BATCHES
                        The number of batches in gradient accumulation.
  --training_strategy {auto,deepspeed_stage_1,deepspeed_stage_2,deepspeed_stage_3}
                        Training strategy
  --max_epochs MAX_EPOCHS
                        Number of epochs.

image.gif

训练后推理

训练完成后,可以使用自己训练的LoRA来生成新图像。以下是一些示例:

from diffsynth import ModelManager, KolorsImagePipeline
from peft import LoraConfig, inject_adapter_in_model
import torch
def load_lora(model, lora_rank, lora_alpha, lora_path):
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out"],
    )
    model = inject_adapter_in_model(lora_config, model)
    state_dict = torch.load(lora_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    return model
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
                             file_path_list=[
                                 "models/kolors/Kolors/text_encoder",
                                 "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
                                 "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
                             ])
pipe = KolorsImagePipeline.from_model_manager(model_manager)
# Generate an image with lora
pipe.unet = load_lora(
    pipe.unet,
    lora_rank=4, lora_alpha=4.0, # The two parameters should be consistent with those in your training script.
    lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt"
)
torch.manual_seed(0)
image = pipe(
    prompt="一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉",
    negative_prompt="",
    cfg_scale=4,
    num_inference_steps=50, height=1024, width=1024,
)
image.save("image_with_lora.jpg")

image.gif

柯基lora:

Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉

image.gif

3D风格lora:

Prompt:一只小狗和一只小猫3D

image.gif

点击链接👇直达链接

https://modelscope.cn/models/Kwai-Kolors/Kolors?from=alizishequ__text

相关文章
|
6月前
|
物联网 测试技术 API
用消费级显卡微调属于自己的Agent
本文为魔搭社区轻量级训练推理工具SWIFT微调实战教程系列
|
6月前
|
存储 缓存 算法
使用Mixtral-offloading在消费级硬件上运行Mixtral-8x7B
Mixtral-8x7B是最好的开放大型语言模型(LLM)之一,但它是一个具有46.7B参数的庞大模型。即使量化为4位,该模型也无法在消费级GPU上完全加载(例如,24 GB VRAM是不够的)。
221 4
|
数据可视化 物联网 PyTorch
双卡3090消费级显卡 SFT OpenBuddy-LLaMA1-65B 最佳实践
OpenBuddy继接连开源OpenBuddy-LLaMA1-13B、OpenBuddy-LLaMA1-30B后,8月10日,一鼓作气发布了650亿参数的大型跨语言对话模型 OpenBuddy-LLaMA1-65B。
|
机器学习/深度学习 人工智能 算法
阿里公开自研AI集群细节:64个GPU,百万分类训练速度提升4倍
从节点架构到网络架构,再到通信算法,阿里巴巴把自研的高性能AI集群技术细节写成了论文,并对外公布。
阿里公开自研AI集群细节:64个GPU,百万分类训练速度提升4倍
|
3月前
|
小程序 API 调度
消费级显卡,17G显存,玩转图像生成模型FLUX.1!
近期stable diffusion的部分核心开发同学,推出了全新的图像生成模型FLUX.1。
|
6月前
|
人工智能 自动驾驶 算法
只要千元级,人人可用百亿级多模态大模型!国产“AI模盒”秒级训练推理
云天励飞,中国AI独角兽,发布“AI模盒”,以千元成本实现多模态大模型的秒级训练推理,降低AI应用门槛。该产品凸显了公司在技术创新与普及中的努力,旨在构建智能城市并重塑日常生活,同时也面临数据安全、隐私保护及人才挑战。
89 3
只要千元级,人人可用百亿级多模态大模型!国产“AI模盒”秒级训练推理
|
6月前
|
人工智能 数据挖掘 大数据
随着AI算力需求不断增强,800G光模块的需求不断增大
随着AI算力需求增长和硅光技术进步,光模块产业正经历快速发展,尤其在400G、800G及1.6T领域。到2024年,硅光方案将广泛应用于高带宽光模块,推动技术更新速度加快。800G光模块因高速、高密度和低功耗特性,市场需求日益增长,将在2025年成为市场主流,预计市场规模将达到16亿美元。光模块厂家需关注技术创新、产品多样化和产能提升以适应竞争。
411 1
|
6月前
|
自然语言处理 JavaScript 前端开发
MFTCoder 重磅升级 v0.3.0 发布,支持 Mixtral 等更多模型,支持收敛均衡,支持 FSDP
今天,我们对MFTCoder进行重磅升级,比如对Mixtral这个开源MoE的SOTA的多任务微调的支持;再比如我们提供了之前论文中提到的收敛均衡技术:Self-Paced Loss。 MFTCoder已适配支持了更多的主流开源LLMs,如Mixtral、Mistral、Deepseek、 Llama、CodeLlama、Qwen、CodeGeeX2、StarCoder、Baichuan2、ChatGLM2/3、GPT-Neox等。以Deepseek-coder-33b-base为底座,使用MFTCoder微调得到的CodeFuse-Deepseek-33B在HumaneEval测试中pass
132 0
|
算法 数据库 异构计算
Milvus 2.3.功能全面升级,核心组件再升级,超低延迟、高准确度、MMap一触开启数据处理量翻倍、支持GPU使用!
Milvus 2.3.功能全面升级,核心组件再升级,超低延迟、高准确度、MMap一触开启数据处理量翻倍、支持GPU使用!
Milvus 2.3.功能全面升级,核心组件再升级,超低延迟、高准确度、MMap一触开启数据处理量翻倍、支持GPU使用!
|
6月前
|
存储 机器人 PyTorch
使用 ExLlamaV2 在消费级 GPU 上运行 Llama 2 70B
使用 ExLlamaV2 在消费级 GPU 上运行 Llama 2 70B
535 0
下一篇
无影云桌面