❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发感兴趣,我会每日分享大模型与 AI 领域的开源项目和应用,提供运行实例和实用教程,帮助你快速上手AI技术!
🥦 AI 在线答疑 -> 智能检索历史文章和开源项目 -> 丰富的 AI 工具库 -> 每日更新 -> 尽在微信公众号 -> 搜一搜:蚝油菜花 🥦
🎨 「设计师集体失业?复旦开源模型5亿参数秒出商业级高清图」
大家好,我是蚝油菜花。你是否也经历过这些设计噩梦——
- 👉 通宵改稿第18版,甲方最后说"还是用第一版吧"
- 👉 想尝试新风格却卡在素材搜索,试到灵感枯竭
- 👉 用AI生成图像,结果分辨率一放大就糊成马赛克...
今天要拆解的 SimpleAR ,正在颠覆图像生成规则!这个由复旦&字节联手打造的开源神器:
- ✅ 小身材大能量:仅5亿参数生成1024×1024高清图,GenEval得分0.59
- ✅ 三阶段训练法:预训练+SFT+强化学习,文本跟随能力吊打同类
- ✅ 14秒极速出图:兼容vLLM加速技术,商业应用零门槛
已有广告团队用它1天做完季度提案,接下来将揭秘这套「参数少质量高」的黑科技原理!
SimpleAR 是什么
SimpleAR 是复旦大学视觉与学习实验室和字节 Seed 团队联合推出的纯自回归图像生成模型。该模型采用简洁的自回归架构,通过优化训练和推理过程,实现了高质量的图像生成。
SimpleAR 仅用 5 亿参数即可生成 1024×1024 分辨率的图像,在 GenEval 等基准测试中取得了优异成绩。训练采用"预训练 – 有监督微调 – 强化学习"的三阶段方法,显著提升了文本跟随能力和生成效果。SimpleAR 兼容现有加速技术,推理时间可缩短至 14 秒以内。
SimpleAR 的主要功能
- 高质量文本到图像生成:SimpleAR 是纯自回归的视觉生成框架,仅用 5 亿参数就能生成 1024×1024 分辨率的高质量图像,在 GenEval 等基准测试中取得了 0.59 的优异成绩。
- 多模态融合生成:将文本和视觉 token 平等对待,集成在一个统一的 Transformer 架构中,支持多模态建模,能更好地进行文本引导的图像生成。
SimpleAR 的技术原理
- 自回归生成机制:SimpleAR 采用经典的自回归生成方式,通过"下一个 token 预测"的形式逐步生成图像内容。这种机制将图像分解为一系列离散的 token,然后逐个预测这些 token,从而构建出完整的图像。
- 多模态融合:SimpleAR 将文本编码和视觉生成集成在一个 decoder-only 的 Transformer 架构中。提高了参数的利用效率,更好地支持了文本和视觉模态之间的联合建模,使模型能更自然地理解和生成与文本描述对应的图像。
- 三阶段训练方法:
- 预训练:通过大规模数据预训练,学习通用的视觉和语言模式。
- 有监督微调(SFT):在预训练基础上,通过有监督学习进一步提升生成质量和指令跟随能力。
- 强化学习(GRPO):基于简单的 reward 函数(如 CLIP)进行后训练,优化生成内容的美学性和多模态对齐。
- 推理加速技术:SimpleAR 通过 vLLM 等技术优化推理过程,显著缩短了图像生成时间。例如,0.5B 参数的模型可以在 14 秒内生成 1024×1024 分辨率的高质量图像。
- 视觉 tokenizer 的选择:SimpleAR 使用 Cosmos 作为视觉 tokenizer,在低分辨率图像和细节重建上存在局限,仍有改进空间。
如何运行 SimpleAR
安装环境
python3 -m venv env
source env/bin/activate
pip install -e ".[train]"
模型下载
这里提供SFT和RL两种checkpoint:
name | GenEval | DPG | HF权重 |
---|---|---|---|
SimpleAR-0.5B-SFT | 0.53 | 79.34 | simplear-0.5B-sft |
SimpleAR-0.5B-RL | 0.59 | 79.66 | simplear-0.5B-grpo |
下载视觉tokenizer:
cd checkpoints
git lfs install
git clone https://huggingface.co/nvidia/Cosmos-1.0-Tokenizer-DV8x16x16
生成图像示例代码
import os
import torch
from torchvision.utils import save_image
from transformers import AutoTokenizer
from simpar.model.tokenizer.cosmos_tokenizer.networks import TokenizerConfigs
from simpar.model.tokenizer.cosmos_tokenizer.video_lib import CausalVideoTokenizer as CosmosTokenizer
from simpar.model.language_model.simpar_qwen2 import SimpARForCausalLM
device = "cuda:0"
model_name = "Daniel0724/SimpleAR-0.5B-RL"
# define your prompt here:
prompt = "Inside a warm room with a large window showcasing a picturesque winter landscape, three gleaming ruby red necklaces are elegantly laid out on the plush surface of a deep purple velvet jewelry box. The gentle glow from the overhead light accentuates the rich color and intricate design of the necklaces. Just beyond the glass pane, snowflakes can be seen gently falling to coat the ground outside in a blanket of white."
# Load LLM and tokenizer
model = SimpARForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load Cosmos tokenizer
tokenizer_config = TokenizerConfigs["DV"].value
tokenizer_config.update(dict(spatial_compression=16, temporal_compression=8))
vq_model = CosmosTokenizer(checkpoint_enc=f"./checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/encoder.jit", checkpoint_dec=f"./checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/decoder.jit", tokenizer_config=tokenizer_config)
vq_model.eval()
vq_model.requires_grad_(False)
codebook_size = 64000
latent_size = 64
format_prompt = "<|t2i|>" + "A highly realistic image of " + prompt + "<|soi|>"
input_ids = tokenizer(format_prompt, return_tensors="pt").input_ids.to(device)
uncond_prompt = "<|t2i|>" + "An image of aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" + "<|soi|>"
uncond_input_ids = tokenizer(uncond_prompt, return_tensors="pt").input_ids.to(device)
# next token prediction
with torch.inference_mode():
output_ids = model.generate_visual(
input_ids,
negative_prompt_ids=uncond_input_ids,
cfg_scale=6.0,
do_sample=True,
temperature=1.0,
top_p=1.0,
top_k=64000,
max_new_tokens=4096,
use_cache=True
)
index_sample = output_ids[:, input_ids.shape[1]: input_ids.shape[1] + 4096].clone()
index_sample = index_sample - len(tokenizer)
index_sample = torch.clamp(index_sample, min=0, max=codebook_size-1)
index_sample = index_sample.reshape(-1, latent_size, latent_size).unsqueeze(1)
# decode with tokenizer
with torch.inference_mode():
samples = vq_model.decode(index_sample)
samples = samples.squeeze(2)
save_image(samples, os.path.join(f"{prompt[:50]}.png"), normalize=True, value_range=(-1, 1))
资源
- GitHub 仓库:https://github.com/wdrink/SimpleAR
❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发感兴趣,我会每日分享大模型与 AI 领域的开源项目和应用,提供运行实例和实用教程,帮助你快速上手AI技术!
🥦 AI 在线答疑 -> 智能检索历史文章和开源项目 -> 丰富的 AI 工具库 -> 每日更新 -> 尽在微信公众号 -> 搜一搜:蚝油菜花 🥦