打造社交APP人物动漫化:通义万相wan2.x训练优化指南

简介: 本项目基于通义万相AIGC模型,为社交APP打造“真人变身跳舞动漫仙女”特效视频生成功能。通过LoRA微调与全量训练结合,并引入Sage Attention、TeaCache、xDIT并行等优化技术,实现高质量、高效率的动漫风格视频生成,兼顾视觉效果与落地成本,最终优选性价比最高的wan2.1 lora模型用于生产部署。(239字)

1. 需求场景:AI特效生成

本项目旨在为社交类APP集成AIGC驱动的个人宣传视频生成功能,通过AI技术将用户上传的真人图像,转化为具有动漫风格的个性化短视频,尤其聚焦于“真人变身跳舞动漫仙女”的特定场景。项目采用通义万相系列AIGC模型,结合定制化训练与推理优化,打造高效、高质量、可商业落地的视频生成解决方案。


项目需求

客户希望在社交APP中新增一个功能模块:用户上传真人图片,系统自动生成一段具有动漫风格的跳舞短视频,用于个人形象展示、社交传播等场景。该功能需具备以下核心能力:

  • 人物动漫形象化转换:将真人形象转化为动漫风格角色;
  • 动态动作生成:支持舞蹈等复杂动作序列;
  • 高质量视频输出:支持720p分辨率,帧率稳定,画面细腻;
  • 风格一致性控制:确保生成视频在风格、色彩、动作上保持统一;


痛点问题

虽然当前市面上已有多个主流AIGC视频生成模型(如Stable Diffusion 3、Runway ML、Pika Labs等),但在本项目场景下存在以下关键痛点:

  • 动态动作生成不稳定:现有模型在复杂动作(如舞蹈)生成中容易出现动作不连贯、肢体穿透、帧抖动等问题;
  • 动漫风格控制能力弱:难以精准实现“仙女”类动漫风格,风格一致性差;
  • 视频质量在低分辨率下下降明显:纹理丢失严重,细节表现不佳;
  • 推理速度慢,不满足生产部署需求:无法在主流消费级显卡上实现高效推理;
  • 缺乏个性化定制能力:无法针对“真人→动漫仙女”的特定场景进行模型强化;

针对以上痛点问题,本项目决定:

1. 进行模型选型,选择在动漫领域生成表现效果好的AIGC模型做对比,选择合适的模型;

2. 同时针对这一特定场景,专门做模型后训练、采用多种训练策略来强化该场景下的表现效果;

3. 更进一步为了提升训练效率节省成本,对训练过程和推理过程做性能和显存的优化,推动方案实际生产落地;

后续按照该顺序介绍整体的流程;


2. 模型、算力选型+对比验证

2.1 模型选型

生产级主流的视频生成模型选择wan2.1、wan2.2,Wan2.2-I2V-A14B、Wan2.1-I2V-14B-720P;


这两个模型在模型尺寸上是相匹配的,都是14b尺寸的大模型,在功能上是专门用于图片生成5s视频,符合客户实际的场景需求;5b模型过小,生成效果不理想;文生视频的模型不符合客户需求;


wan2.1该版本已经能够生成多种艺术风格的图像,如写实、卡通、水墨、油画、赛博朋克等,细节生成能力增强人物面部、光影效果、纹理细节等方面有了显著提升,能够生成更逼真、更具艺术感的图像;支持用户通过关键词、风格标签、构图控制等方式更精细地控制生成结果,提升创作的可操作性


wan2.2 支持更高分辨率图像的生成,同时在细节刻画、色彩表现、光影渲染等方面更加自然,接近专业艺术作品水平,该版本引入了更先进的风格编码机制,能够实现多种风格的融合与创新,用户可以自由组合不同风格元素,创造出电影级别的视觉效果,WAN2.2 在理解文本描述方面更加精准,能够根据复杂语义生成符合逻辑的场景构图,包括多对象布局、空间关系、动态动作等;支持图像局部编辑、修复、重绘等功能,用户可以在生成图像的基础上进行再创作,提升创作灵活性和实用性


根据实际场景的特点,选定这两个模型用于对比效果训练,选择效果更好的模型用于实际场景;


2.2 算力选型

根据wan2.1、wan2.2模型的算力需求和显存占用,综合分析单卡和多卡训练和推理场景的算力。

推理场景

训练场景

机型1和机型2因显存不符合推理和训练的要求,已排除,机型3的显存足够,但整个机器的算力成本非常高,导致训练的性价比较低;

结合算力成本,综合分析决定采用机型4作为训练和部署推理的机器;因为训练视频文件需要的显存非常大,和帧数、fps相关,需要预留足够多的显存。


2.3 数据集构建

构建多模态训练集,包含:50组小样本数据集和5000组全量数据集,数据集由提示词文本、首帧图片、控制视频、vace视频;小样本用于本地效果验证,大样本用于生产级模型训练;同时对数据集随机切分,按9:1的形式分割训练集和测试集,用测试集评估生成质量;

以下是数据集的构建方法:


以下是数据集的组织格式:

本数据集中使用提示词、训练视频首帧、训练视频传入wan模型做训练;

微调训练方法:

采用LoRA+全参训练的对比训练方法:

在相同数据集上,对比训练wan2.1和wan2.2,lora训练模块选择注意力机制的qkv和前馈神经网络的前两层,全训练模块选择全部;

3. lora微调+全量训练

3.1 训练过程

基于PAI DSW进行小样本数据集的验证测试;


lora微调,在实际测试中发现,因wan2.1、wan2.2本身对显存的占用不同,wan2.1的显存占用42GB,wan2.2占用达到51GB,因此能够被训练的视频文件中的帧数是不一样的,本数据集中全部采用5s视频,fps15,总帧数75,wan2.1能够参与训练的最大帧数是前60帧,也就是视频的前4s,wan2.2能够参与训练的最大帧数是前45帧,也就是视频的前3s;wan2.1、wan2.2采用相同的优化策略;但由于训练帧数因显存原因无法对齐,后续在优化策略中会解决这一问题;而全参数训练占用的显存更加大,wan2.1只能训练到前41帧,wan2.2 25帧;

镜像选择DSW官方镜像:

dsw-registry-vpc.cn-wulanchabu.cr.aliyuncs.com/pai/modelscope:1.29.0-pytorch2.6.0-gpu-py311-cu124-ubuntu22.04

以下是训练的环境依赖,实际测试中dsw的官方镜像支持直接进行训练和推理:

torch>=2.0.0
torchvision
transformers
imageio
imageio[ffmpeg]
safetensors
einops
sentencepiece
protobuf
modelscope
ftfy
pynvml
pandas
accelerate
peft

以下是训练命令,用deepspeed训练框架:

wan2.1 lora

accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path train_dataset \
  --dataset_metadata_path train_dataset/metadata.csv \
  --height 720 \
  --width 1280 \
  --num_frames 60 \
  --dataset_repeat 10 \
  --model_paths '[
    [
        "wan21/diffusion_pytorch_model-00001-of-00007.safetensors",
        "wan21/diffusion_pytorch_model-00002-of-00007.safetensors",
        "wan21/diffusion_pytorch_model-00003-of-00007.safetensors",
        "wan21/diffusion_pytorch_model-00004-of-00007.safetensors",
        "wan21/diffusion_pytorch_model-00005-of-00007.safetensors",
        "wan21/diffusion_pytorch_model-00006-of-00007.safetensors",
        "wan21/diffusion_pytorch_model-00007-of-00007.safetensors"
    ],
    "wan21/models_t5_umt5-xxl-enc-bf16.pth",
    "wan21/Wan2.1_VAE.pth",
    "wan21/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
    ]' \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image" \
  --use_gradient_checkpointing_offload

wan2.1 full

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 720 \
  --width 1280 \
  --num_frames 41 \
  --dataset_repeat 10 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.1-I2V-14B-720P_full" \
  --trainable_models "dit" \
  --extra_inputs "input_image" \
  --use_gradient_checkpointing_offload

wan2.2 lora

accelerate launch DiffSynth-Studio/examples/wanvideo/model_training/train.py \
  --dataset_base_path train_dataset \
  --dataset_metadata_path train_dataset/metadata.csv \
  --height 480 \
  --width 832 \
  --num_frames 45 \
  --dataset_repeat 10 \
  --model_paths '[
    [
      "wan22/high_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
      "wan22/high_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
      "wan22/high_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
      "wan22/high_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
      "wan22/high_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
      "wan22/high_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors"
    ],
    [
      "wan22/low_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
      "wan22/low_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
      "wan22/low_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
      "wan22/low_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
      "wan22/low_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
      "wan22/low_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors"
    ],
    "wan22/models_t5_umt5-xxl-enc-bf16.pth",
    "wan22/Wan2.1_VAE.pth"
  ]' \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image" \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]


accelerate launch DiffSynth-Studio/examples/wanvideo/model_training/train.py \
  --dataset_base_path train_dataset \
  --dataset_metadata_path train_dataset/metadata.csv \
  --height 480 \
  --width 832 \
  --num_frames 25 \
  --dataset_repeat 10 \
  --model_paths '[
    [
    "wan22/high_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
    "wan22/high_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
    "wan22/high_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
    "wan22/high_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
    "wan22/high_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
    "wan22/high_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors"
    ],
    [
    "wan22/low_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
    "wan22/low_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
    "wan22/low_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
    "wan22/low_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
    "wan22/low_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
    "wan22/low_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors"
    ],
    "wan22/models_t5_umt5-xxl-enc-bf16.pth",
    "wan22/Wan2.1_VAE.pth"
    ]' \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image" \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358 \
# boundary corresponds to timesteps [0, 900

wan2.2 full

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 480 \
  --width 832 \
  --num_frames 25 \
  --dataset_repeat 10 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \
  --trainable_models "dit" \
  --extra_inputs "input_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 480 \
  --width 832 \
  --num_frames 49 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \
  --trainable_models "dit" \
  --extra_inputs "input_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358
# boundary corresponds to timesteps [0, 900)


3.2 训练性能数据

在总体的小样本数据集上,50组训练视频,各自取一定的帧数,再乘重复训练次数(epoch),就是总的训练步长,例如wan2.1 lora里取前60帧,那么总步长=60*50*10=30000;总步长决定训练的效率,每次梯度迭代会经历一定的步长;

以下是训练wan2.1的过程:

训练后推理的时长:


3.3 训练后测试推理的命令

以wan2.1 lora为例,以下是加载lora权重和测试推理的代码,注意num_frames参数是控制生成的帧数,如果和训练时指定的帧数相同,则最大化训练的效果,但本训练数据中因显卡最大显存限制,无法训练全部的75帧,所以只取前一部分帧数做训练,num_frames可以不指定具体的值,默认取5s视频的帧数;

import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(path=[
    "wan21/diffusion_pytorch_model-00001-of-00007.safetensors",
    "wan21/diffusion_pytorch_model-00002-of-00007.safetensors",
    "wan21/diffusion_pytorch_model-00003-of-00007.safetensors",
    "wan21/diffusion_pytorch_model-00004-of-00007.safetensors",
    "wan21/diffusion_pytorch_model-00005-of-00007.safetensors",
    "wan21/diffusion_pytorch_model-00006-of-00007.safetensors",
    "wan21/diffusion_pytorch_model-00007-of-00007.safetensors",]),
        ModelConfig(path="wan21/models_t5_umt5-xxl-enc-bf16.pth"),
        ModelConfig(path="wan21/Wan2.1_VAE.pth"),
        ModelConfig(path="wan21/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
    ],
    use_usp=True,
)

pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-720P_lora/epoch-2.safetensors", alpha=1)
pipe.enable_vram_management()


image = Image.open("1.jpg")
prompt='''
Flower-Fairy. A radiant metamorphosis unfolds as the character, encircled by shimmering butterflies, rises from a whirling vortex of whimsical 2D halos. The initial attire transforms into a resplendent emerald green tunic adorned with golden embellishments and a multi-layered tulle skirt, reminiscent of Tinker Bell's iconic outfit, while preserving the original hairstyle. Luminous halos sparkle beneath her feet as the backdrop bursts into a vivid, enchanted forest teeming with bioluminescent plants and flickering fireflies.
'''
# Image-to-video
video = pipe(
    prompt=prompt,
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    input_image=image,
    seed=0, tiled=True,
    height=720, width=1280,
    #num_frames=25
)
save_video(video, "video_lora.mp4", fps=15, quality=5)

wan2.2的效果更加写实风格,wan2.1偏向动漫风格,更适合本场景需求;


4. 训练优化加速

优化加速主要是为了让模型的训练和推理能够提升速度,同时尽量保持原有的性能水平,尽量减少使用算力的时间成本,使训练和推理能够真正的在消费级显卡上使用,成为一个可落地的解决方案;

一般来说,优化加速可以从以下方面去考虑:

  • 训练速度提升;
  • 推理速度提升;
  • 训练时显存占用减少;

训练速度提升是在模型训练过程中,通过高效注意力计算机制、多卡并行训练策略、模型量化等方法,加速模型反向传播计算的时间,同时尽量减少对最优参数的选择的影响,从而提升训练效率;

推理速度提升是在模型训练结束后的推理验证中,通过多卡并行推理策略、缓存机制、高效注意力计算机制等数学方法,加速前向传播的计算时间,同时尽量减少对最优参数的选择的影响,从而提升推理效率;

显存占用减少是模型训练或者推理过程中,通过分块组计算、非核心参数卸载等方法,减少实际在GPU中参与计算的参数数量,很大程度上减少GPU的显存占用,从而在算力不变的情况下能够训练更大的模型、更多的数据;

以下是本项目中实际采用的优化方法:

4.1 Sage Attention - 27% Train & Infer Speed-Up

Sage Attention 是一种基于int8量化的高效注意力计算方法,加速Transformer模型的推理过程,同时保持模型精度。其优势包括:

  • K矩阵平滑:通过减去token间的平均值缓解K矩阵的通道级异常值,并且不影响softmax分数的计算;
  • 混合精度计算:Q和K使用INT8量化,P和V保持FP16并采用FP16累加器;
  • 自适应量化策略:根据层对精度的敏感度动态选择量化粒度(per-token或per-block);
  • 硬件优化:基于Triton实现的高效内核,利用NVIDIA Tensor Core的INT8和FP16指令加速计算;

sage计算方法如上,K矩阵的平滑计算,特别适合对扩散模型diffusion的视频帧计算,能够很大程度减少相同连续区块的重复计算,因此对AIGC类模型加速效果比flash attn更好,在消费级显卡上,sage attention能实现比flash attention高2-3倍的加速比;但目前语言类模型和视觉理解类模型的加速效果还是采用flash attention效果更好;

以下是使用方法:

基础环境需求

从源码编译安装sage

git clone https://github.com/thu-ml/SageAttention.git
cd SageAttention 
export EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32
python setup.py install

对一般的Diffusion模型,使用方法是修改模型transformer架构中attention矩阵,用sage替换torch自带的dot-product-attention

from sageattention import sageattn
attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
F.scaled_dot_product_attention = attn_output

但在wan2.1、wan2.2模型已经原生支持sage加速,只需要启动训练命令,自动会按照sage方法计算attention;本项目中,开启sage和不开启的效果对比,wan2.1 lora为例

4.2 TeaCache - 28% Infer Speed-Up

在视频生成中,很多帧之间是高度相似的,比如:摄像机缓慢移动,背景不变、只有前景物体移动,人物说话时背景不变;在这种情况下,如果每一帧都从头开始重新生成,会浪费大量计算资源;TeaCache 是一种 用于视频生成任务的缓存加速机制,通过缓存帧之间的相似内容,减少重复计算,从而提升视频生成速度并降低计算资源消耗。


TeaCache 主要包含以下几个步骤:


1. 在生成第 t 帧时,模型会与前一帧 t−1 做对比,判断哪些区域变化较小。使用 L1 距离(像素级差异) 或 特征空间差异(latent)来衡量两帧之间的差异。如果某个区域的差异小于阈值 ,则认为该区域“变化不大”,可以使用缓存。

2. 缓存的内容通常是 Latent 空间中的中间表示。缓存区域的坐标(bounding box)、缓存区域的 latent 特征、缓存区域的时间戳(用于判断是否过期)

3. 在生成下一帧时:对于变化较小的区域,直接从缓存中提取之前计算好的 latent 特征,避免重复计算。只对变化较大的区域进行完整的扩散模型计算。

4. 在每一帧生成完成后,会更新缓存:更新缓存区域的 latent 特征、移除过期的缓存条目(例如超过一定帧数)、添加新生成的缓存区域;


tea_cache_l1_thresh是teacache的关键参数,判断帧间相似度的阈值,0.02-0.1取值,值越小,缓存帧数越多,视频生成越快,质量损失大,值越大,缓存帧数越少,视频生成越慢,质量损失小;


在wan模型中如何集成:


在 WanVideoPipeline 的推理流程中,TeaCache 模块被嵌入到每一帧的生成过程中:

for frame_idx in range(total_frames):
    text_emb = encode_text(prompt)
    noise = get_initial_noise()

    if frame_idx > 0:
        cache_mask = tea_cache.get_cache_mask(prev_latent, current_latent)

    latent = diffusion_model(noise, text_emb, cache_mask=cache_mask)

    tea_cache.update_cache(latent, frame_idx)

    frame = vae.decode(latent)


以下是使用方法:

基础环境需求

accelerate>0.17.0
bs4
click
colossalai==0.4.0
diffusers==0.30.0
einops
fabric
ftfy
imageio
imageio-ffmpeg
matplotlib
ninja
numpy<2.0.0
omegaconf
packaging
psutil
pydantic
ray
rich
safetensors
sentencepiece
timm
torch>=1.13
tqdm
peft==0.13.2
transformers==4.39.3

从源码编译安装

git clone https://github.com/ali-vilab/TeaCache.git
cd TeaCache
python setup.py install

代码使用,已内置在wanvideo pipeline中:

video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
    # TeaCache parameters
    tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
    tea_cache_model_id="Wan2.1-I2V-14B-720P",
)

效果对比

4.3 XDIT Sequence Parallelism - 400% Train & Infer Speed-Up

xDiT是一种面向扩散模型的大规模并行推理框架,其核心目标是通过创新的并行化技术和编译优化技术,实现高效的大规模扩散模型训练与推理。

xDiT的核心原理

1. PipeFusion:分块流水线并行,将图像分割为多个patch,用pipeline在不同卡上并行处理,最后整合结果,有效减少单卡上的显存占用;

2. USP:统一序列并行,针对扩散模型中的长序列生成任务,将序列的不同维度拆分到不同卡,实现多卡的序列并行计算,最后整合向量。

3. CFG Parallel:在分类器自由指导过程中,将正向和负向样本的计算分配到不同卡,降低单次推理的计算负载。

4. DistVAE:分布式VAE模块,对扩散模型中的VAE模块进行分块并行处理,避免显存溢出。

5. 编译优化技术:xDiT通过编译器技术优化GPU执行效率,主要依赖:Torch.compile:PyTorch 2.0的JIT编译器,通过融合

算子、消除冗余计算提升性能;OneDiff:针对扩散模型的专用编译优化工具,支持内核融合、内存复用等高级特性。


以下是使用方法:


从源码编译

git clone https://github.com/xdit-project/xDiT.git
cd xDiT
pip install -e .
# Or optionally, with flash attention
pip install -e ".[flash-attn]"

使用方法:

多卡并行训练

accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 720 \
  --width 1280 \
  --num_frames 49 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image" \
  --use_gradient_checkpointing_offload

多卡并行推理

## 首先pipeline启用usp
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    use_usp=True,
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

## 然后shell启动多卡推理
torchrun --standalone --nproc_per_node=8 test.py

效果对比


4.4 Gradient Checkpointing Offload - 13% GPU Memory Reduce

梯度检查点卸载(Gradient Checkpointing Offload)是一种以计算换显存的优化技术。在前向传播时,模型不会存储所有中间激活值到GPU显存,并且把激活值卸载到内存中,而是仅保留部分关键层的激活值,并在需要时重新计算中间值。这可以显著减少显存占用,但会增加少量计算时间,还引入额外的 GPU-CPU 数据传输开销。


当GPU的显存比较紧张时,为了实现训练大模型,可以通过这种方法显著减少显存占用,实现训练目标;


实现步骤:

1. 前向传播阶段

计算激活值:模型在前向传播时会计算每一层的输出(即激活值),模型会将部分激活值 从 GPU 显存移动到 CPU 内存。对于未启用梯度检查点的层,激活值通常会保留到显存中。


2. 反向传播阶段

在反向传播计算梯度时,模型需要中间激活值来计算梯度。则需要把激活值从内存加载回 GPU 显存,优化卸载大幅减少显存占用(降低20%左右显存占用)。

实现代码,以下是前向传播时传递的关键参数:

# 在前向传播时传递卸载参数

inputs_shared = {

"use_gradient_checkpointing": self.use_gradient_checkpointing,

"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,

...

}

import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath
os.environ["TOKENIZERS_PARALLELISM"] = "false"



class WanTrainingModule(DiffusionTrainingModule):
    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None, audio_processor_config=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        max_timestep_boundary=1.0,
        min_timestep_boundary=0.0,
    ):
        super().__init__()
        # Load models
        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
        if audio_processor_config is not None:
            audio_processor_config = ModelConfig(model_id=audio_processor_config.split(":")[0], origin_file_pattern=audio_processor_config.split(":")[1])
        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, audio_processor_config=audio_processor_config)
        
        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
            enable_fp8_training=False,
        )
        
        # Store other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.max_timestep_boundary = max_timestep_boundary
        self.min_timestep_boundary = min_timestep_boundary
        
        
    def forward_preprocess(self, data):
        # CFG-sensitive parameters
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}
        
        # CFG-unsensitive parameters
        inputs_shared = {
            # Assume you are usingthis pipeline for inference,
            # please fill in the input parameters.
            "input_video": data["video"],
            "height": data["video"][0].size[1],
            "width": data["video"][0].size[0],
            "num_frames": len(data["video"]),
            # Please donot modify the following parameters
            # unless you clearly know what this will cause.
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "cfg_merge": False,
            "vace_scale": 1,
            "max_timestep_boundary": self.max_timestep_boundary,
            "min_timestep_boundary": self.min_timestep_boundary,
        }
        
        # Extra inputs
        for extra_input in self.extra_inputs:
            if extra_input == "input_image":
                inputs_shared["input_image"] = data["video"][0]
            elif extra_input == "end_image":
                inputs_shared["end_image"] = data["video"][-1]
            elif extra_input == "reference_image"or extra_input == "vace_reference_image":
                inputs_shared[extra_input] = data[extra_input][0]
            else:
                inputs_shared[extra_input] = data[extra_input]
        
        # Pipeline units will automatically process the input parameters.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        return {**inputs_shared, **inputs_posi}
    
    
    def forward(self, data, inputs=None):
        if inputs is None: inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(**models, **inputs)
        return loss


if __name__ == "__main__":
    parser = wan_parser()
    args = parser.parse_args()
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=args.data_file_keys.split(","),
        main_data_operator=UnifiedDataset.default_video_operator(
            base_path=args.dataset_base_path,
            max_pixels=args.max_pixels,
            height=args.height,
            width=args.width,
            height_division_factor=16,
            width_division_factor=16,
            num_frames=args.num_frames,
            time_division_factor=4,
            time_division_remainder=1,
        ),
        special_operator_map={
            "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
            "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
        }
    )
    model = WanTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        audio_processor_config=args.audio_processor_config,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        max_timestep_boundary=args.max_timestep_boundary,
        min_timestep_boundary=args.min_timestep_boundary,
    )
    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
    )
    launch_training_task(dataset, model, model_logger, args=args)

在训练时启用该参数gradient_checkpointing_offload

accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 720 \
  --width 1280 \
  --num_frames 49 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image" \
  --use_gradient_checkpointing_offload

效果对比

gpu显存占用下降了13%,相对应的训练时间延长了9%,牺牲了少量时间,来降低显存占用,从而让更多的帧数可以参与训练,是合理的优化。

4.5 Tiled VAE - 11% GPU Memory Reduce

视频通常由多帧连续图像组成,直接处理高分辨率视频序列会占用大量显存。分块编码解码技术通过逐块处理单帧图像,降低单次计算的显存需求,能有效减少显存占用,但会略微降低视频生成质量;

视频VAE的latent空间通常为5D张量([B, C, T, H, W],B=batch size, C=通道数, T=帧数, H/W=高度/宽度),先对单帧图像进行空间编码/解码,再通过时间模块(Time-Series Transformer)建模帧间关系。


分块技术的实现步骤:

1. 编码阶段(视频到latent空间)

对输入视频的每一帧独立分块。例如,若单帧分辨率为[H, W],则用kernel划分为多个tile_size大小的向量块;通过tile_stride确保相邻块部分重叠,避免块边界处的语义断裂。

2. 解码阶段(latent空间到视频)

逐帧解码:对潜在空间的每帧独立解码,解码向量块到对应图像块;对重叠区域的像素值进行加权平均,减少块边界伪影。对解码后的多帧图像通过时间滤波增强帧间连续性。

3. 时间一致性优化策略

在解码前,通过时间模块(Time-Series Transformer)对分块后的潜在特征进行全局时间一致性约束;在分块时对齐相邻帧的块位置,确保同一物体在不同帧中的重建区域一致。

4. 显存优化

根据显存限制,动态调整tile_size和tile_stride;用Sequence Parallel同时处理多个块;

以下是使用方法:

#tiled: Whether to enable tiled VAE inference, default is False. Setting to True significantly reduces VRAM usage during VAE encoding/decoding but introduces small errors and slightly increases inference time.
#tile_size: Tile size during VAE encoding/decoding, defaultis(30, 52), only effective when tiled=True.
#tile_stride: Stride of tiles during VAE encoding/decoding, defaultis(15, 26), only effective when tiled=True. Must be less than or equal to tile_size.

video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
    tile_size=[30,52],tile_stride=[15,26]
)

效果对比

Tiled VAE有效减少了显存占用,同时速度基本不影响;

4.6 Quantization - 33% Train Speed-Up

在训练模型初期,发现wan2.1、wan2.2模型本身架构不同,wan2.2有高噪声和低噪声两套架构,导致模型大小比wan2.1大一倍左右,因此两个模型的显存占用存在差异,在相同数据集下,导致了能够参与训练的视频的帧数两者不一样,例如lora训练中,wan2.1能训练到前60帧,而wan2.2只有前45帧,理论上来说,帧数越少,那么模型能遵循的视频效果越差,为了再对比训练中尽量保持条件一致,并且要保持wan2.2-I2V-A14B模型的基本架构,最好的办法就是模型量化,用精度更低的wan2.2模型,降低大量显存占用,提高训练帧数,还能加速训练时间。


后面的结果不再采用原始的wan2.2模型,后续的结果全部基于量化wan2.2和wan2.1的对比效果。


wan2.2系列模型族由众多的量化模型,AWQ、GGUF、FP8、FP16、INT8,分别代表不同的参数精度,根据量化模型的尺寸,最终选择wan2.2-fp8模型,模型显存占用基本和wan2.1达到相同,训练帧数也和wan2.1最大帧数相同。

链接

以下是原始的两个模型对比:

模型量化是深度学习模型优化的重要技术,在通过降低模型参数和计算的精度来减少模型大小、内存占用和计算开销,同时尽量保持模型性能。


核心方法是将高精度浮点数(如FP32)转换为低精度表示(如INT8或FP16),从而减少存储和计算需求,有效降低显存占用和计算资源消耗。


量化类型 权重量化:仅对模型权重进行量化,激活值保持高精度。 激活量化:对输入/输出激活值进行量化。 全量化:同时量化权重和激活值。 混合量化:对不同层使用不同精度。


量化方法


对称量化:数值范围对称分布(如-127~127),适用于激活值接近零的情况。 非对称量化:数值范围非对称(如0~255),适用于偏移较大的数据。 动态量化:推理时根据输入动态调整量化参数。 静态量化:训练后通过校准数据集确定量化参数。 量化感知训练:在训练阶段模拟量化误差,提升量化后模型精度。

以下是常见的量化精度:

效果对比

5. 结果

总结训练的全流程链路,一般的模型后训练都可以采用这一套完整的流程,实现系统化、可复用、可扩展的实践:

模型的训练实际上是一个持续迭代优化的过程,没有一蹴而就的完美训练,更多的是要持续的调整训练参数、采用新的训练方法、对比实验等等来不断提升模型的准确率和性能,这需要反复迭代,才能达到理想中的最优效果;

在评价AIGC生成内容的质量时,有一些常用的量化指标,在本场景中,动态表现、相机控制、帧质量、准确性是衡量生成质量的重要参考因素,决定采用客观指标+主观打分的方法来综合评价模型训练的效果;客观指标决定采用行业通用的视频质量评价参数:

峰值信噪比(PSNR)

计算方式:基于均方误差(MSE)的对数转换

评价维度:像素级保真度,反映视频帧与参考帧的噪声差异;

典型阈值:10-50,优质视频≥32

结构相似性指数(SSIM)

计算方式:8*8 kernel的滑动窗口计算

评价维度:动态场景中空间结构与时间连续性的保持能力;

标准范围:0(完全失真)-1(完全匹配),优质生成需≥0.85;

这是设计的评测流程:


PSNR的计算代码:

import cv2
import numpy as np

def calculate_psnr(video_ref_path, video_gen_path):
    cap_ref = cv2.VideoCapture(video_ref_path)
    cap_gen = cv2.VideoCapture(video_gen_path)
    
    psnr_list = []
    while True:
        ret_ref, frame_ref = cap_ref.read()
        ret_gen, frame_gen = cap_gen.read()
        
        ifnot ret_ref ornot ret_gen:
            break
        
        # 确保分辨率一致
        if frame_ref.shape != frame_gen.shape:
            frame_gen = cv2.resize(frame_gen, (frame_ref.shape[1], frame_ref.shape[0]))
        
        # 计算MSE
        mse = np.mean((frame_ref - frame_gen) ** 2)
        if mse == 0: 
            psnr = float('inf')
        else:
            psnr = 20 * np.log10(255.0 / np.sqrt(mse))
        psnr_list.append(psnr)
    
    return np.mean(psnr_list)


psnr_value = calculate_psnr("reference_video.mp4", "generated_video.mp4")
print(f"PSNR: {psnr_value:.2f} dB")


SSIM计算代码:

from skimage.metrics import structural_similarity as ssim
import cv2
import numpy as np

def calculate_ssim(video_ref_path, video_gen_path):
    cap_ref = cv2.VideoCapture(video_ref_path)
    cap_gen = cv2.VideoCapture(video_gen_path)
    
    ssim_list = []
    while True:
        ret_ref, frame_ref = cap_ref.read()
        ret_gen, frame_gen = cap_gen.read()
        
        ifnot ret_ref ornot ret_gen:
            break
        
        if frame_ref.shape != frame_gen.shape:
            frame_gen = cv2.resize(frame_gen, (frame_ref.shape[1], frame_ref.shape[0]))
        
        # 转换为灰度图
        gray_ref = cv2.cvtColor(frame_ref, cv2.COLOR_BGR2GRAY)
        gray_gen = cv2.cvtColor(frame_gen, cv2.COLOR_BGR2GRAY)
        

        score, _ = ssim(gray_ref, gray_gen, full=True)
        ssim_list.append(score)
    
    return np.mean(ssim_list)


ssim_value = calculate_ssim("reference_video.mp4", "generated_video.mp4")
print(f"SSIM: {ssim_value:.4f}")


主观评价由客户划定几个评价指标:动态表现、相机控制、帧质量、目标准确性四个方面分别做人工打分,评分1-5分;

最后综合客观得分+主观得分,综合确定视频生成的效果;计算方式是加权分,总分=主观分/20 * 50% + PSNR/50 * 25% + SSIM/1 *25%;

50组样本里,随机划分5个测试数据,对这5个用于测试的未参与训练的素材图片,测试四个模型的生成效果,再经过自动评分+人工评分,得到如下的得分结果:

最终根据生成视频的总体得分情况,wan2.1 full的生成效果最好,但整体的全量训练所需算力成本比lora大很多,综合比较wan2.1、wan2.2 的训练成本,包括训练时间、所需算力,综合选择性价比最高的wan2.1 lora模型作为上线生产环境的主力模型。


来源  |  阿里云开发者公众号

作者  |  李德

作者介绍
目录