1. 需求场景:AI特效生成
本项目旨在为社交类APP集成AIGC驱动的个人宣传视频生成功能,通过AI技术将用户上传的真人图像,转化为具有动漫风格的个性化短视频,尤其聚焦于“真人变身跳舞动漫仙女”的特定场景。项目采用通义万相系列AIGC模型,结合定制化训练与推理优化,打造高效、高质量、可商业落地的视频生成解决方案。
项目需求
客户希望在社交APP中新增一个功能模块:用户上传真人图片,系统自动生成一段具有动漫风格的跳舞短视频,用于个人形象展示、社交传播等场景。该功能需具备以下核心能力:
- 人物动漫形象化转换:将真人形象转化为动漫风格角色;
- 动态动作生成:支持舞蹈等复杂动作序列;
- 高质量视频输出:支持720p分辨率,帧率稳定,画面细腻;
- 风格一致性控制:确保生成视频在风格、色彩、动作上保持统一;
痛点问题
虽然当前市面上已有多个主流AIGC视频生成模型(如Stable Diffusion 3、Runway ML、Pika Labs等),但在本项目场景下存在以下关键痛点:
- 动态动作生成不稳定:现有模型在复杂动作(如舞蹈)生成中容易出现动作不连贯、肢体穿透、帧抖动等问题;
- 动漫风格控制能力弱:难以精准实现“仙女”类动漫风格,风格一致性差;
- 视频质量在低分辨率下下降明显:纹理丢失严重,细节表现不佳;
- 推理速度慢,不满足生产部署需求:无法在主流消费级显卡上实现高效推理;
- 缺乏个性化定制能力:无法针对“真人→动漫仙女”的特定场景进行模型强化;
针对以上痛点问题,本项目决定:
1. 进行模型选型,选择在动漫领域生成表现效果好的AIGC模型做对比,选择合适的模型;
2. 同时针对这一特定场景,专门做模型后训练、采用多种训练策略来强化该场景下的表现效果;
3. 更进一步为了提升训练效率节省成本,对训练过程和推理过程做性能和显存的优化,推动方案实际生产落地;
后续按照该顺序介绍整体的流程;
2. 模型、算力选型+对比验证
2.1 模型选型
生产级主流的视频生成模型选择wan2.1、wan2.2,Wan2.2-I2V-A14B、Wan2.1-I2V-14B-720P;
这两个模型在模型尺寸上是相匹配的,都是14b尺寸的大模型,在功能上是专门用于图片生成5s视频,符合客户实际的场景需求;5b模型过小,生成效果不理想;文生视频的模型不符合客户需求;
wan2.1该版本已经能够生成多种艺术风格的图像,如写实、卡通、水墨、油画、赛博朋克等,细节生成能力增强人物面部、光影效果、纹理细节等方面有了显著提升,能够生成更逼真、更具艺术感的图像;支持用户通过关键词、风格标签、构图控制等方式更精细地控制生成结果,提升创作的可操作性
wan2.2 支持更高分辨率图像的生成,同时在细节刻画、色彩表现、光影渲染等方面更加自然,接近专业艺术作品水平,该版本引入了更先进的风格编码机制,能够实现多种风格的融合与创新,用户可以自由组合不同风格元素,创造出电影级别的视觉效果,WAN2.2 在理解文本描述方面更加精准,能够根据复杂语义生成符合逻辑的场景构图,包括多对象布局、空间关系、动态动作等;支持图像局部编辑、修复、重绘等功能,用户可以在生成图像的基础上进行再创作,提升创作灵活性和实用性
根据实际场景的特点,选定这两个模型用于对比效果训练,选择效果更好的模型用于实际场景;
2.2 算力选型
根据wan2.1、wan2.2模型的算力需求和显存占用,综合分析单卡和多卡训练和推理场景的算力。
推理场景
训练场景
机型1和机型2因显存不符合推理和训练的要求,已排除,机型3的显存足够,但整个机器的算力成本非常高,导致训练的性价比较低;
结合算力成本,综合分析决定采用机型4作为训练和部署推理的机器;因为训练视频文件需要的显存非常大,和帧数、fps相关,需要预留足够多的显存。
2.3 数据集构建
构建多模态训练集,包含:50组小样本数据集和5000组全量数据集,数据集由提示词文本、首帧图片、控制视频、vace视频;小样本用于本地效果验证,大样本用于生产级模型训练;同时对数据集随机切分,按9:1的形式分割训练集和测试集,用测试集评估生成质量;
以下是数据集的构建方法:
以下是数据集的组织格式:
本数据集中使用提示词、训练视频首帧、训练视频传入wan模型做训练;
微调训练方法:
采用LoRA+全参训练的对比训练方法:
在相同数据集上,对比训练wan2.1和wan2.2,lora训练模块选择注意力机制的qkv和前馈神经网络的前两层,全训练模块选择全部;
3. lora微调+全量训练
3.1 训练过程
基于PAI DSW进行小样本数据集的验证测试;
lora微调,在实际测试中发现,因wan2.1、wan2.2本身对显存的占用不同,wan2.1的显存占用42GB,wan2.2占用达到51GB,因此能够被训练的视频文件中的帧数是不一样的,本数据集中全部采用5s视频,fps15,总帧数75,wan2.1能够参与训练的最大帧数是前60帧,也就是视频的前4s,wan2.2能够参与训练的最大帧数是前45帧,也就是视频的前3s;wan2.1、wan2.2采用相同的优化策略;但由于训练帧数因显存原因无法对齐,后续在优化策略中会解决这一问题;而全参数训练占用的显存更加大,wan2.1只能训练到前41帧,wan2.2 25帧;
镜像选择DSW官方镜像:
dsw-registry-vpc.cn-wulanchabu.cr.aliyuncs.com/pai/modelscope:1.29.0-pytorch2.6.0-gpu-py311-cu124-ubuntu22.04
以下是训练的环境依赖,实际测试中dsw的官方镜像支持直接进行训练和推理:
torch>=2.0.0 torchvision transformers imageio imageio[ffmpeg] safetensors einops sentencepiece protobuf modelscope ftfy pynvml pandas accelerate peft
以下是训练命令,用deepspeed训练框架:
wan2.1 lora
accelerate launch examples/wanvideo/model_training/train.py \ --dataset_base_path train_dataset \ --dataset_metadata_path train_dataset/metadata.csv \ --height 720 \ --width 1280 \ --num_frames 60 \ --dataset_repeat 10 \ --model_paths '[ [ "wan21/diffusion_pytorch_model-00001-of-00007.safetensors", "wan21/diffusion_pytorch_model-00002-of-00007.safetensors", "wan21/diffusion_pytorch_model-00003-of-00007.safetensors", "wan21/diffusion_pytorch_model-00004-of-00007.safetensors", "wan21/diffusion_pytorch_model-00005-of-00007.safetensors", "wan21/diffusion_pytorch_model-00006-of-00007.safetensors", "wan21/diffusion_pytorch_model-00007-of-00007.safetensors" ], "wan21/models_t5_umt5-xxl-enc-bf16.pth", "wan21/Wan2.1_VAE.pth", "wan21/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" ]' \ --learning_rate 1e-4 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \ --lora_base_model "dit" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "input_image" \ --use_gradient_checkpointing_offload
wan2.1 full
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ --dataset_base_path data/example_video_dataset \ --dataset_metadata_path data/example_video_dataset/metadata.csv \ --height 720 \ --width 1280 \ --num_frames 41 \ --dataset_repeat 10 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ --learning_rate 1e-5 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/Wan2.1-I2V-14B-720P_full" \ --trainable_models "dit" \ --extra_inputs "input_image" \ --use_gradient_checkpointing_offload
wan2.2 lora
accelerate launch DiffSynth-Studio/examples/wanvideo/model_training/train.py \ --dataset_base_path train_dataset \ --dataset_metadata_path train_dataset/metadata.csv \ --height 480 \ --width 832 \ --num_frames 45 \ --dataset_repeat 10 \ --model_paths '[ [ "wan22/high_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors" ], [ "wan22/low_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors" ], "wan22/models_t5_umt5-xxl-enc-bf16.pth", "wan22/Wan2.1_VAE.pth" ]' \ --learning_rate 1e-4 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \ --lora_base_model "dit" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "input_image" \ --max_timestep_boundary 0.358 \ --min_timestep_boundary 0 # boundary corresponds to timesteps [900, 1000] accelerate launch DiffSynth-Studio/examples/wanvideo/model_training/train.py \ --dataset_base_path train_dataset \ --dataset_metadata_path train_dataset/metadata.csv \ --height 480 \ --width 832 \ --num_frames 25 \ --dataset_repeat 10 \ --model_paths '[ [ "wan22/high_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors", "wan22/high_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors" ], [ "wan22/low_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors", "wan22/low_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors" ], "wan22/models_t5_umt5-xxl-enc-bf16.pth", "wan22/Wan2.1_VAE.pth" ]' \ --learning_rate 1e-4 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \ --lora_base_model "dit" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "input_image" \ --max_timestep_boundary 1 \ --min_timestep_boundary 0.358 \ # boundary corresponds to timesteps [0, 900
wan2.2 full
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ --dataset_base_path data/example_video_dataset \ --dataset_metadata_path data/example_video_dataset/metadata.csv \ --height 480 \ --width 832 \ --num_frames 25 \ --dataset_repeat 10 \ --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ --learning_rate 1e-5 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \ --trainable_models "dit" \ --extra_inputs "input_image" \ --use_gradient_checkpointing_offload \ --max_timestep_boundary 0.358 \ --min_timestep_boundary 0 # boundary corresponds to timesteps [900, 1000] accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ --dataset_base_path data/example_video_dataset \ --dataset_metadata_path data/example_video_dataset/metadata.csv \ --height 480 \ --width 832 \ --num_frames 49 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ --learning_rate 1e-5 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \ --trainable_models "dit" \ --extra_inputs "input_image" \ --use_gradient_checkpointing_offload \ --max_timestep_boundary 1 \ --min_timestep_boundary 0.358 # boundary corresponds to timesteps [0, 900)
3.2 训练性能数据
在总体的小样本数据集上,50组训练视频,各自取一定的帧数,再乘重复训练次数(epoch),就是总的训练步长,例如wan2.1 lora里取前60帧,那么总步长=60*50*10=30000;总步长决定训练的效率,每次梯度迭代会经历一定的步长;
以下是训练wan2.1的过程:
训练后推理的时长:
3.3 训练后测试推理的命令
以wan2.1 lora为例,以下是加载lora权重和测试推理的代码,注意num_frames参数是控制生成的帧数,如果和训练时指定的帧数相同,则最大化训练的效果,但本训练数据中因显卡最大显存限制,无法训练全部的75帧,所以只取前一部分帧数做训练,num_frames可以不指定具体的值,默认取5s视频的帧数;
import torch from PIL import Image from diffsynth import save_video, VideoData from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ ModelConfig(path=[ "wan21/diffusion_pytorch_model-00001-of-00007.safetensors", "wan21/diffusion_pytorch_model-00002-of-00007.safetensors", "wan21/diffusion_pytorch_model-00003-of-00007.safetensors", "wan21/diffusion_pytorch_model-00004-of-00007.safetensors", "wan21/diffusion_pytorch_model-00005-of-00007.safetensors", "wan21/diffusion_pytorch_model-00006-of-00007.safetensors", "wan21/diffusion_pytorch_model-00007-of-00007.safetensors",]), ModelConfig(path="wan21/models_t5_umt5-xxl-enc-bf16.pth"), ModelConfig(path="wan21/Wan2.1_VAE.pth"), ModelConfig(path="wan21/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), ], use_usp=True, ) pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-720P_lora/epoch-2.safetensors", alpha=1) pipe.enable_vram_management() image = Image.open("1.jpg") prompt=''' Flower-Fairy. A radiant metamorphosis unfolds as the character, encircled by shimmering butterflies, rises from a whirling vortex of whimsical 2D halos. The initial attire transforms into a resplendent emerald green tunic adorned with golden embellishments and a multi-layered tulle skirt, reminiscent of Tinker Bell's iconic outfit, while preserving the original hairstyle. Luminous halos sparkle beneath her feet as the backdrop bursts into a vivid, enchanted forest teeming with bioluminescent plants and flickering fireflies. ''' # Image-to-video video = pipe( prompt=prompt, negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", input_image=image, seed=0, tiled=True, height=720, width=1280, #num_frames=25 ) save_video(video, "video_lora.mp4", fps=15, quality=5)
wan2.2的效果更加写实风格,wan2.1偏向动漫风格,更适合本场景需求;
4. 训练优化加速
优化加速主要是为了让模型的训练和推理能够提升速度,同时尽量保持原有的性能水平,尽量减少使用算力的时间成本,使训练和推理能够真正的在消费级显卡上使用,成为一个可落地的解决方案;
一般来说,优化加速可以从以下方面去考虑:
- 训练速度提升;
- 推理速度提升;
- 训练时显存占用减少;
训练速度提升是在模型训练过程中,通过高效注意力计算机制、多卡并行训练策略、模型量化等方法,加速模型反向传播计算的时间,同时尽量减少对最优参数的选择的影响,从而提升训练效率;
推理速度提升是在模型训练结束后的推理验证中,通过多卡并行推理策略、缓存机制、高效注意力计算机制等数学方法,加速前向传播的计算时间,同时尽量减少对最优参数的选择的影响,从而提升推理效率;
显存占用减少是模型训练或者推理过程中,通过分块组计算、非核心参数卸载等方法,减少实际在GPU中参与计算的参数数量,很大程度上减少GPU的显存占用,从而在算力不变的情况下能够训练更大的模型、更多的数据;
以下是本项目中实际采用的优化方法:
4.1 Sage Attention - 27% Train & Infer Speed-Up
Sage Attention 是一种基于int8量化的高效注意力计算方法,加速Transformer模型的推理过程,同时保持模型精度。其优势包括:
- K矩阵平滑:通过减去token间的平均值缓解K矩阵的通道级异常值,并且不影响softmax分数的计算;
- 混合精度计算:Q和K使用INT8量化,P和V保持FP16并采用FP16累加器;
- 自适应量化策略:根据层对精度的敏感度动态选择量化粒度(per-token或per-block);
- 硬件优化:基于Triton实现的高效内核,利用NVIDIA Tensor Core的INT8和FP16指令加速计算;
sage计算方法如上,K矩阵的平滑计算,特别适合对扩散模型diffusion的视频帧计算,能够很大程度减少相同连续区块的重复计算,因此对AIGC类模型加速效果比flash attn更好,在消费级显卡上,sage attention能实现比flash attention高2-3倍的加速比;但目前语言类模型和视觉理解类模型的加速效果还是采用flash attention效果更好;
以下是使用方法:
基础环境需求
从源码编译安装sage
git clone https://github.com/thu-ml/SageAttention.git cd SageAttention export EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 python setup.py install
对一般的Diffusion模型,使用方法是修改模型transformer架构中attention矩阵,用sage替换torch自带的dot-product-attention
from sageattention import sageattn attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False) F.scaled_dot_product_attention = attn_output
但在wan2.1、wan2.2模型已经原生支持sage加速,只需要启动训练命令,自动会按照sage方法计算attention;本项目中,开启sage和不开启的效果对比,wan2.1 lora为例
4.2 TeaCache - 28% Infer Speed-Up
在视频生成中,很多帧之间是高度相似的,比如:摄像机缓慢移动,背景不变、只有前景物体移动,人物说话时背景不变;在这种情况下,如果每一帧都从头开始重新生成,会浪费大量计算资源;TeaCache 是一种 用于视频生成任务的缓存加速机制,通过缓存帧之间的相似内容,减少重复计算,从而提升视频生成速度并降低计算资源消耗。
TeaCache 主要包含以下几个步骤:
1. 在生成第 t 帧时,模型会与前一帧 t−1 做对比,判断哪些区域变化较小。使用 L1 距离(像素级差异) 或 特征空间差异(latent)来衡量两帧之间的差异。如果某个区域的差异小于阈值 ,则认为该区域“变化不大”,可以使用缓存。
2. 缓存的内容通常是 Latent 空间中的中间表示。缓存区域的坐标(bounding box)、缓存区域的 latent 特征、缓存区域的时间戳(用于判断是否过期)
3. 在生成下一帧时:对于变化较小的区域,直接从缓存中提取之前计算好的 latent 特征,避免重复计算。只对变化较大的区域进行完整的扩散模型计算。
4. 在每一帧生成完成后,会更新缓存:更新缓存区域的 latent 特征、移除过期的缓存条目(例如超过一定帧数)、添加新生成的缓存区域;
tea_cache_l1_thresh是teacache的关键参数,判断帧间相似度的阈值,0.02-0.1取值,值越小,缓存帧数越多,视频生成越快,质量损失大,值越大,缓存帧数越少,视频生成越慢,质量损失小;
在wan模型中如何集成:
在 WanVideoPipeline 的推理流程中,TeaCache 模块被嵌入到每一帧的生成过程中:
for frame_idx in range(total_frames): text_emb = encode_text(prompt) noise = get_initial_noise() if frame_idx > 0: cache_mask = tea_cache.get_cache_mask(prev_latent, current_latent) latent = diffusion_model(noise, text_emb, cache_mask=cache_mask) tea_cache.update_cache(latent, frame_idx) frame = vae.decode(latent)
以下是使用方法:
基础环境需求
accelerate>0.17.0 bs4 click colossalai==0.4.0 diffusers==0.30.0 einops fabric ftfy imageio imageio-ffmpeg matplotlib ninja numpy<2.0.0 omegaconf packaging psutil pydantic ray rich safetensors sentencepiece timm torch>=1.13 tqdm peft==0.13.2 transformers==4.39.3
从源码编译安装
git clone https://github.com/ali-vilab/TeaCache.git cd TeaCache python setup.py install
代码使用,已内置在wanvideo pipeline中:
video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", seed=0, tiled=True, # TeaCache parameters tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality. tea_cache_model_id="Wan2.1-I2V-14B-720P", )
效果对比
4.3 XDIT Sequence Parallelism - 400% Train & Infer Speed-Up
xDiT是一种面向扩散模型的大规模并行推理框架,其核心目标是通过创新的并行化技术和编译优化技术,实现高效的大规模扩散模型训练与推理。
xDiT的核心原理
1. PipeFusion:分块流水线并行,将图像分割为多个patch,用pipeline在不同卡上并行处理,最后整合结果,有效减少单卡上的显存占用;
2. USP:统一序列并行,针对扩散模型中的长序列生成任务,将序列的不同维度拆分到不同卡,实现多卡的序列并行计算,最后整合向量。
3. CFG Parallel:在分类器自由指导过程中,将正向和负向样本的计算分配到不同卡,降低单次推理的计算负载。
4. DistVAE:分布式VAE模块,对扩散模型中的VAE模块进行分块并行处理,避免显存溢出。
5. 编译优化技术:xDiT通过编译器技术优化GPU执行效率,主要依赖:Torch.compile:PyTorch 2.0的JIT编译器,通过融合
算子、消除冗余计算提升性能;OneDiff:针对扩散模型的专用编译优化工具,支持内核融合、内存复用等高级特性。
以下是使用方法:
从源码编译
git clone https://github.com/xdit-project/xDiT.git cd xDiT pip install -e . # Or optionally, with flash attention pip install -e ".[flash-attn]"
使用方法:
多卡并行训练
accelerate launch examples/wanvideo/model_training/train.py \ --dataset_base_path data/example_video_dataset \ --dataset_metadata_path data/example_video_dataset/metadata.csv \ --height 720 \ --width 1280 \ --num_frames 49 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \ --lora_base_model "dit" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "input_image" \ --use_gradient_checkpointing_offload
多卡并行推理
## 首先pipeline启用usp pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", use_usp=True, model_configs=[ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), ], ) pipe.enable_vram_management() ## 然后shell启动多卡推理 torchrun --standalone --nproc_per_node=8 test.py
效果对比
4.4 Gradient Checkpointing Offload - 13% GPU Memory Reduce
梯度检查点卸载(Gradient Checkpointing Offload)是一种以计算换显存的优化技术。在前向传播时,模型不会存储所有中间激活值到GPU显存,并且把激活值卸载到内存中,而是仅保留部分关键层的激活值,并在需要时重新计算中间值。这可以显著减少显存占用,但会增加少量计算时间,还引入额外的 GPU-CPU 数据传输开销。
当GPU的显存比较紧张时,为了实现训练大模型,可以通过这种方法显著减少显存占用,实现训练目标;
实现步骤:
1. 前向传播阶段
计算激活值:模型在前向传播时会计算每一层的输出(即激活值),模型会将部分激活值 从 GPU 显存移动到 CPU 内存。对于未启用梯度检查点的层,激活值通常会保留到显存中。
2. 反向传播阶段
在反向传播计算梯度时,模型需要中间激活值来计算梯度。则需要把激活值从内存加载回 GPU 显存,优化卸载大幅减少显存占用(降低20%左右显存占用)。
实现代码,以下是前向传播时传递的关键参数:
# 在前向传播时传递卸载参数
inputs_shared = {
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
...
}
import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath os.environ["TOKENIZERS_PARALLELISM"] = "false" class WanTrainingModule(DiffusionTrainingModule): def __init__( self, model_paths=None, model_id_with_origin_paths=None, audio_processor_config=None, trainable_models=None, lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None, use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, max_timestep_boundary=1.0, min_timestep_boundary=0.0, ): super().__init__() # Load models model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False) if audio_processor_config is not None: audio_processor_config = ModelConfig(model_id=audio_processor_config.split(":")[0], origin_file_pattern=audio_processor_config.split(":")[1]) self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, audio_processor_config=audio_processor_config) # Training mode self.switch_pipe_to_training_mode( self.pipe, trainable_models, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, enable_fp8_training=False, ) # Store other configs self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] self.max_timestep_boundary = max_timestep_boundary self.min_timestep_boundary = min_timestep_boundary def forward_preprocess(self, data): # CFG-sensitive parameters inputs_posi = {"prompt": data["prompt"]} inputs_nega = {} # CFG-unsensitive parameters inputs_shared = { # Assume you are usingthis pipeline for inference, # please fill in the input parameters. "input_video": data["video"], "height": data["video"][0].size[1], "width": data["video"][0].size[0], "num_frames": len(data["video"]), # Please donot modify the following parameters # unless you clearly know what this will cause. "cfg_scale": 1, "tiled": False, "rand_device": self.pipe.device, "use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, "cfg_merge": False, "vace_scale": 1, "max_timestep_boundary": self.max_timestep_boundary, "min_timestep_boundary": self.min_timestep_boundary, } # Extra inputs for extra_input in self.extra_inputs: if extra_input == "input_image": inputs_shared["input_image"] = data["video"][0] elif extra_input == "end_image": inputs_shared["end_image"] = data["video"][-1] elif extra_input == "reference_image"or extra_input == "vace_reference_image": inputs_shared[extra_input] = data[extra_input][0] else: inputs_shared[extra_input] = data[extra_input] # Pipeline units will automatically process the input parameters. for unit in self.pipe.units: inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) return {**inputs_shared, **inputs_posi} def forward(self, data, inputs=None): if inputs is None: inputs = self.forward_preprocess(data) models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} loss = self.pipe.training_loss(**models, **inputs) return loss if __name__ == "__main__": parser = wan_parser() args = parser.parse_args() dataset = UnifiedDataset( base_path=args.dataset_base_path, metadata_path=args.dataset_metadata_path, repeat=args.dataset_repeat, data_file_keys=args.data_file_keys.split(","), main_data_operator=UnifiedDataset.default_video_operator( base_path=args.dataset_base_path, max_pixels=args.max_pixels, height=args.height, width=args.width, height_division_factor=16, width_division_factor=16, num_frames=args.num_frames, time_division_factor=4, time_division_remainder=1, ), special_operator_map={ "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)), "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000), } ) model = WanTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, audio_processor_config=args.audio_processor_config, trainable_models=args.trainable_models, lora_base_model=args.lora_base_model, lora_target_modules=args.lora_target_modules, lora_rank=args.lora_rank, lora_checkpoint=args.lora_checkpoint, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, max_timestep_boundary=args.max_timestep_boundary, min_timestep_boundary=args.min_timestep_boundary, ) model_logger = ModelLogger( args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt ) launch_training_task(dataset, model, model_logger, args=args)
在训练时启用该参数gradient_checkpointing_offload
accelerate launch examples/wanvideo/model_training/train.py \ --dataset_base_path data/example_video_dataset \ --dataset_metadata_path data/example_video_dataset/metadata.csv \ --height 720 \ --width 1280 \ --num_frames 49 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \ --lora_base_model "dit" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "input_image" \ --use_gradient_checkpointing_offload
效果对比
gpu显存占用下降了13%,相对应的训练时间延长了9%,牺牲了少量时间,来降低显存占用,从而让更多的帧数可以参与训练,是合理的优化。
4.5 Tiled VAE - 11% GPU Memory Reduce
视频通常由多帧连续图像组成,直接处理高分辨率视频序列会占用大量显存。分块编码解码技术通过逐块处理单帧图像,降低单次计算的显存需求,能有效减少显存占用,但会略微降低视频生成质量;
视频VAE的latent空间通常为5D张量([B, C, T, H, W],B=batch size, C=通道数, T=帧数, H/W=高度/宽度),先对单帧图像进行空间编码/解码,再通过时间模块(Time-Series Transformer)建模帧间关系。
分块技术的实现步骤:
1. 编码阶段(视频到latent空间)
对输入视频的每一帧独立分块。例如,若单帧分辨率为[H, W],则用kernel划分为多个tile_size大小的向量块;通过tile_stride确保相邻块部分重叠,避免块边界处的语义断裂。
2. 解码阶段(latent空间到视频)
逐帧解码:对潜在空间的每帧独立解码,解码向量块到对应图像块;对重叠区域的像素值进行加权平均,减少块边界伪影。对解码后的多帧图像通过时间滤波增强帧间连续性。
3. 时间一致性优化策略
在解码前,通过时间模块(Time-Series Transformer)对分块后的潜在特征进行全局时间一致性约束;在分块时对齐相邻帧的块位置,确保同一物体在不同帧中的重建区域一致。
4. 显存优化
根据显存限制,动态调整tile_size和tile_stride;用Sequence Parallel同时处理多个块;
以下是使用方法:
#tiled: Whether to enable tiled VAE inference, default is False. Setting to True significantly reduces VRAM usage during VAE encoding/decoding but introduces small errors and slightly increases inference time. #tile_size: Tile size during VAE encoding/decoding, defaultis(30, 52), only effective when tiled=True. #tile_stride: Stride of tiles during VAE encoding/decoding, defaultis(15, 26), only effective when tiled=True. Must be less than or equal to tile_size. video = pipe( prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", seed=0, tiled=True, tile_size=[30,52],tile_stride=[15,26] )
效果对比
Tiled VAE有效减少了显存占用,同时速度基本不影响;
4.6 Quantization - 33% Train Speed-Up
在训练模型初期,发现wan2.1、wan2.2模型本身架构不同,wan2.2有高噪声和低噪声两套架构,导致模型大小比wan2.1大一倍左右,因此两个模型的显存占用存在差异,在相同数据集下,导致了能够参与训练的视频的帧数两者不一样,例如lora训练中,wan2.1能训练到前60帧,而wan2.2只有前45帧,理论上来说,帧数越少,那么模型能遵循的视频效果越差,为了再对比训练中尽量保持条件一致,并且要保持wan2.2-I2V-A14B模型的基本架构,最好的办法就是模型量化,用精度更低的wan2.2模型,降低大量显存占用,提高训练帧数,还能加速训练时间。
后面的结果不再采用原始的wan2.2模型,后续的结果全部基于量化wan2.2和wan2.1的对比效果。
wan2.2系列模型族由众多的量化模型,AWQ、GGUF、FP8、FP16、INT8,分别代表不同的参数精度,根据量化模型的尺寸,最终选择wan2.2-fp8模型,模型显存占用基本和wan2.1达到相同,训练帧数也和wan2.1最大帧数相同。
以下是原始的两个模型对比:
模型量化是深度学习模型优化的重要技术,在通过降低模型参数和计算的精度来减少模型大小、内存占用和计算开销,同时尽量保持模型性能。
核心方法是将高精度浮点数(如FP32)转换为低精度表示(如INT8或FP16),从而减少存储和计算需求,有效降低显存占用和计算资源消耗。
量化类型 权重量化:仅对模型权重进行量化,激活值保持高精度。 激活量化:对输入/输出激活值进行量化。 全量化:同时量化权重和激活值。 混合量化:对不同层使用不同精度。
量化方法
对称量化:数值范围对称分布(如-127~127),适用于激活值接近零的情况。 非对称量化:数值范围非对称(如0~255),适用于偏移较大的数据。 动态量化:推理时根据输入动态调整量化参数。 静态量化:训练后通过校准数据集确定量化参数。 量化感知训练:在训练阶段模拟量化误差,提升量化后模型精度。
以下是常见的量化精度:
效果对比
5. 结果
总结训练的全流程链路,一般的模型后训练都可以采用这一套完整的流程,实现系统化、可复用、可扩展的实践:
模型的训练实际上是一个持续迭代优化的过程,没有一蹴而就的完美训练,更多的是要持续的调整训练参数、采用新的训练方法、对比实验等等来不断提升模型的准确率和性能,这需要反复迭代,才能达到理想中的最优效果;
在评价AIGC生成内容的质量时,有一些常用的量化指标,在本场景中,动态表现、相机控制、帧质量、准确性是衡量生成质量的重要参考因素,决定采用客观指标+主观打分的方法来综合评价模型训练的效果;客观指标决定采用行业通用的视频质量评价参数:
峰值信噪比(PSNR)
计算方式:基于均方误差(MSE)的对数转换
评价维度:像素级保真度,反映视频帧与参考帧的噪声差异;
典型阈值:10-50,优质视频≥32
结构相似性指数(SSIM)
计算方式:8*8 kernel的滑动窗口计算
评价维度:动态场景中空间结构与时间连续性的保持能力;
标准范围:0(完全失真)-1(完全匹配),优质生成需≥0.85;
这是设计的评测流程:
PSNR的计算代码:
import cv2 import numpy as np def calculate_psnr(video_ref_path, video_gen_path): cap_ref = cv2.VideoCapture(video_ref_path) cap_gen = cv2.VideoCapture(video_gen_path) psnr_list = [] while True: ret_ref, frame_ref = cap_ref.read() ret_gen, frame_gen = cap_gen.read() ifnot ret_ref ornot ret_gen: break # 确保分辨率一致 if frame_ref.shape != frame_gen.shape: frame_gen = cv2.resize(frame_gen, (frame_ref.shape[1], frame_ref.shape[0])) # 计算MSE mse = np.mean((frame_ref - frame_gen) ** 2) if mse == 0: psnr = float('inf') else: psnr = 20 * np.log10(255.0 / np.sqrt(mse)) psnr_list.append(psnr) return np.mean(psnr_list) psnr_value = calculate_psnr("reference_video.mp4", "generated_video.mp4") print(f"PSNR: {psnr_value:.2f} dB")
SSIM计算代码:
from skimage.metrics import structural_similarity as ssim import cv2 import numpy as np def calculate_ssim(video_ref_path, video_gen_path): cap_ref = cv2.VideoCapture(video_ref_path) cap_gen = cv2.VideoCapture(video_gen_path) ssim_list = [] while True: ret_ref, frame_ref = cap_ref.read() ret_gen, frame_gen = cap_gen.read() ifnot ret_ref ornot ret_gen: break if frame_ref.shape != frame_gen.shape: frame_gen = cv2.resize(frame_gen, (frame_ref.shape[1], frame_ref.shape[0])) # 转换为灰度图 gray_ref = cv2.cvtColor(frame_ref, cv2.COLOR_BGR2GRAY) gray_gen = cv2.cvtColor(frame_gen, cv2.COLOR_BGR2GRAY) score, _ = ssim(gray_ref, gray_gen, full=True) ssim_list.append(score) return np.mean(ssim_list) ssim_value = calculate_ssim("reference_video.mp4", "generated_video.mp4") print(f"SSIM: {ssim_value:.4f}")
主观评价由客户划定几个评价指标:动态表现、相机控制、帧质量、目标准确性四个方面分别做人工打分,评分1-5分;
最后综合客观得分+主观得分,综合确定视频生成的效果;计算方式是加权分,总分=主观分/20 * 50% + PSNR/50 * 25% + SSIM/1 *25%;
50组样本里,随机划分5个测试数据,对这5个用于测试的未参与训练的素材图片,测试四个模型的生成效果,再经过自动评分+人工评分,得到如下的得分结果:
最终根据生成视频的总体得分情况,wan2.1 full的生成效果最好,但整体的全量训练所需算力成本比lora大很多,综合比较wan2.1、wan2.2 的训练成本,包括训练时间、所需算力,综合选择性价比最高的wan2.1 lora模型作为上线生产环境的主力模型。
来源 | 阿里云开发者公众号
作者 | 李德