Introduction
One of the core technical ideas behind OpenAI's recently released and much-discussed Sora model is converting visual data into a unified patch representation and combining Transformers with diffusion models, which together exhibit excellent scaling behavior.
The paper "Scalable Diffusion Models with Transformers" (DiT), widely shared on Twitter, is regarded as an important foundation behind Sora. Its path to publication was not smooth: the work was once rejected by CVPR 2023.
Coincidentally, while DiT was rejected, a CVPR 2023 paper from Tsinghua University, Renmin University of China, the Beijing Academy of Artificial Intelligence, and other institutions, U-ViT ("All are Worth Words: A ViT Backbone for Diffusion Models"), proposed a simple and general ViT-based architecture (U-ViT) that replaces the convolutional neural network (CNN) in U-Net for diffusion-based image generation.
The work is open source; we welcome everyone to check it out:
GitHub link:
https://github.com/baofff/U-ViT
Paper link:
https://arxiv.org/abs/2209.12152 (All are Worth Words: A ViT Backbone for Diffusion Models)
Model link:
imagenet256_uvit_huge · ModelScope Model Hub (modelscope.cn)
At the same time, we recognize that scaling up a Transformer-based diffusion model the way Sora does requires not only expert-level understanding of the underlying algorithms but also a firm command of the whole deep learning engineering stack, which is considerably harder than producing a workable architecture on academic datasets.
Technical Walkthrough of the Paper and Code
1. Model Architecture Overview
The CNN-based U-Net had been the dominant backbone in earlier diffusion models. It is characterized by a group of downsampling blocks, a group of upsampling blocks, and long skip connections between the two groups. Meanwhile, Vision Transformers (ViT) have shown great promise across a wide range of vision tasks, matching or even outperforming CNN-based approaches. The paper therefore opens with a question: is the CNN-based U-Net really necessary in diffusion models?
Following the design methodology of Transformers, U-ViT treats all inputs, including the time step, the condition, and the noisy image patches, as tokens. Crucially, inspired by U-Net, U-ViT adds long skip connections between the shallow and deep layers. Low-level features matter for the pixel-level prediction objective of diffusion models, and these long skip connections ease the training of the corresponding prediction network. In addition, U-ViT optionally adds an extra 3×3 convolutional block before the output to improve visual quality. See the model architecture figure in the paper.
The U-ViT architecture for diffusion models treats all inputs, including the time step, the condition, and the noisy image patches, as tokens, and uses (#blocks - 1)/2 long skip connections between the shallow and deep layers.
With this overview in mind, the sections below walk through each module in detail.
2. Model Inputs
Let us first look at what the model takes as input and in what form.
The inputs echo the paper's title, "All are Worth Words": the time step, the condition, and the noisy image patches are all represented as tokens and then passed through embedding layers.
An embedding layer converts input data of a given format into a vector representation the model can process, capturing the information contained in the raw data.
timestep_embedding
timestep_embedding generates sinusoidal embeddings for the diffusion time steps, which is how temporal information is injected into the model:
import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
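As a quick, hypothetical sanity check (the timestep values below are chosen only for illustration):

t = torch.tensor([1., 250., 999.])       # one (possibly fractional) timestep per batch element
emb = timestep_embedding(t, dim=768)     # -> torch.Size([3, 768])
print(emb.shape)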
patchify
The patchify function splits an image into non-overlapping patches.
Code:
import einops

def patchify(imgs, patch_size):
    # [B, C, H, W] -> [B, num_patches, patch_size * patch_size * C]
    x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
    return x
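For intuition, a hypothetical shape check (sizes chosen only for illustration):

imgs = torch.randn(2, 3, 32, 32)         # a batch of 2 RGB images
tokens = patchify(imgs, patch_size=2)    # -> torch.Size([2, 256, 12]): 16x16 patches, each flattened to 2*2*3 values
print(tokens.shape)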
PatchEmbed:
PatchEmbed converts an image into patch tokens with a convolution, i.e. it splits the image into patches and projects each one into a vector space of the specified dimension.
Code:
import torch.nn as nn

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding """
    def __init__(self, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
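Compared with patchify, PatchEmbed produces the same number of tokens, but each token is a learned embed_dim-dimensional projection rather than raw pixel values. A hypothetical check, using the latent-space configuration of the ImageNet 256 model shown in the demo later:

embedder = PatchEmbed(patch_size=2, in_chans=4, embed_dim=1152)
z = torch.randn(2, 4, 32, 32)            # latents rather than raw images (see the autoencoder section below)
tokens = embedder(z)                     # -> torch.Size([2, 256, 1152])
print(tokens.shape)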
3. Autoencoder (Encoder/Decoder) Design
Reference code:
https://github.com/baofff/U-ViT/blob/main/libs/autoencoder.py
The Upsample module doubles the spatial resolution of the input feature map with nearest-neighbor interpolation and optionally applies a 3×3 convolution afterwards for additional feature extraction.
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
The Downsample module halves the spatial resolution of the input 2-D feature map (e.g. an image), either with a strided 3×3 convolution (after manual asymmetric padding) or with 2×2 average pooling.
class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
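A quick hypothetical shape check for the two modules (the random feature map is chosen only for illustration):

feat = torch.randn(1, 64, 32, 32)
print(Upsample(64, with_conv=True)(feat).shape)     # torch.Size([1, 64, 64, 64])
print(Downsample(64, with_conv=True)(feat).shape)   # torch.Size([1, 64, 16, 16])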
The Encoder follows a deep residual design and adds attention at selected resolutions to strengthen feature learning; its goal is to capture the key information in the input image and compress it into a latent representation suited to downstream tasks such as image generation.
# ResnetBlock, make_attn, Normalize and nonlinearity are helpers defined elsewhere in libs/autoencoder.py
class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels,
                                        kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
The Decoder mirrors the Encoder: it stacks convolutions, residual blocks (ResnetBlock), attention (make_attn), and Upsample layers to progressively map a low-dimensional latent vector z back to a high-resolution image. See the Decoder class in the autoencoder.py file linked above for the full implementation.
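To make the shapes concrete, here is a hypothetical decode-only sketch built on the get_model and decode calls that appear in the demo section below; the shape comments assume the factor-8, 4-channel KL autoencoder used there.

# Shape intuition: image [B, 3, 256, 256] --Encoder--> latent z [B, 4, 32, 32] --Decoder--> image [B, 3, 256, 256]
import torch
import libs.autoencoder   # from the U-ViT repo; requires the repo on PYTHONPATH as set up in the demo below

autoencoder = libs.autoencoder.get_model('autoencoder_kl_ema.pth')   # weight file downloaded in the demo
z = torch.randn(1, 4, 32, 32)        # a random latent of the size U-ViT operates on
img = autoencoder.decode(z)          # -> [1, 3, 256, 256]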
4. Transformer Design
Attention
U-ViT implements multi-head attention and selects among several computation paths according to a global ATTENTION_MODE flag: Flash Attention (PyTorch's scaled_dot_product_attention), xFormers memory-efficient attention, or a plain "math" implementation.
The Attention module is implemented as follows:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, L, C = x.shape

        qkv = self.qkv(x)
        if ATTENTION_MODE == 'flash':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'xformers':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        elif ATTENTION_MODE == 'math':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            raise NotImplementedError  # ('NotImplemented' in the original is not an exception class)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
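ATTENTION_MODE is a module-level flag in uvit.py. Roughly, it can be chosen by feature detection along the lines of the following sketch (an approximation for illustration, not the repo's exact code):

import torch
try:
    import xformers
    import xformers.ops
    HAS_XFORMERS = True
except ImportError:
    HAS_XFORMERS = False

if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):   # PyTorch >= 2.0
    ATTENTION_MODE = 'flash'
elif HAS_XFORMERS:
    ATTENTION_MODE = 'xformers'
else:
    ATTENTION_MODE = 'math'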
Transformer Block Design
The Block class is a Transformer block consisting of a multi-head attention layer and an MLP layer; it can optionally consume a long skip connection and use gradient checkpointing to save memory.
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
        else:
            return self._forward(x, skip)

    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
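To see how the long skip connection enters a block, here is a hypothetical call; it assumes the Attention class above, an Mlp helper as used in the repo, and ATTENTION_MODE are all in scope, and the tensor sizes match the ImageNet 256 configuration (2 extra tokens plus 256 patch tokens).

blk = Block(dim=1152, num_heads=16, skip=True)
x = torch.randn(2, 258, 1152)        # current activations: [B, extras + num_patches, embed_dim]
skip = torch.randn(2, 258, 1152)     # activations saved from a shallow in_block
out = blk(x, skip)                   # concatenated along the channel dim, projected back to 1152 -> [2, 258, 1152]
print(out.shape)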
5. The U-ViT Backbone
Initialization sets up the patch embedding module, the time embedding (optionally an MLP-style time embedding), the class embedding, and the position embedding.
It defines the in_blocks and out_blocks, with a mid_block in between.
Finally, the decoder_pred linear layer predicts patch-level features, which are reassembled back to the original image resolution.
In the forward pass, the model first patchifies the input image, prepends the time token (built from the provided timestep) and, if a label is given, the class token, and adds the position embedding. After the encoder blocks, the mid block, and the decoder blocks, the output is mapped back to patch-level features, unpatchified to image resolution, and finally passed through an optional 3×3 convolution (or an identity mapping) to produce the result.
class UViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm,
                 mlp_time_embed=False, num_classes=-1, use_checkpoint=False, conv=True, skip=True):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_classes = num_classes
        self.in_chans = in_chans

        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = (img_size // patch_size) ** 2

        self.time_embed = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.SiLU(),
            nn.Linear(4 * embed_dim, embed_dim),
        ) if mlp_time_embed else nn.Identity()

        if self.num_classes > 0:
            self.label_emb = nn.Embedding(self.num_classes, embed_dim)
            self.extras = 2
        else:
            self.extras = 1

        self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))

        self.in_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        self.mid_block = Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer, use_checkpoint=use_checkpoint)

        self.out_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,
                skip=skip, use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        self.norm = norm_layer(embed_dim)
        self.patch_dim = patch_size ** 2 * in_chans
        self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
        self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}

    def forward(self, x, timesteps, y=None):
        x = self.patch_embed(x)
        B, L, D = x.shape

        time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
        time_token = time_token.unsqueeze(dim=1)
        x = torch.cat((time_token, x), dim=1)
        if y is not None:
            label_emb = self.label_emb(y)
            label_emb = label_emb.unsqueeze(dim=1)
            x = torch.cat((label_emb, x), dim=1)
        x = x + self.pos_embed

        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)

        x = self.mid_block(x)

        for blk in self.out_blocks:
            x = blk(x, skips.pop())

        x = self.norm(x)
        x = self.decoder_pred(x)
        assert x.size(1) == self.extras + L
        x = x[:, self.extras:, :]
        x = unpatchify(x, self.in_chans)
        x = self.final_layer(x)
        return x
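Putting it together, a hypothetical forward pass with the ImageNet 256 configuration used in the demo below (random tensors for illustration; this assumes the helpers used above, such as timestep_embedding, PatchEmbed, Block, trunc_normal_ and unpatchify, are in scope as they are in libs/uvit.py):

nnet = UViT(img_size=32, patch_size=2, in_chans=4, embed_dim=1152, depth=28,
            num_heads=16, num_classes=1001, conv=False)
z = torch.randn(2, 4, 32, 32)        # noisy latents
t = torch.randint(0, 1000, (2,))     # diffusion timesteps
y = torch.tensor([207, 1000])        # class labels; 1000 is the unconditional id used for classifier-free guidance
out = nnet(z, t, y=y)                # -> torch.Size([2, 4, 32, 32]), the predicted noise in latent space
print(out.shape)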
Hands-On Example
The example below uses a diffusion model with U-ViT as the backbone to generate image samples conditioned on class labels and displays the results.
It is adapted from the official U-ViT demo notebook (https://github.com/baofff/U-ViT/blob/main/UViT_ImageNet_demo.ipynb) and can be run directly on the free compute provided by the ModelScope community.
ModelScope notebook link:
modelscope/examples/pytorch/UViT_ImageNet_demo.ipynb at master · modelscope/modelscope · GitHub
1. Install dependencies
!git clone https://github.com/baofff/U-ViT
!pip install einops

import os
os.chdir('/mnt/workspace/U-ViT')
os.environ['PYTHONPATH'] = '/env/python:/mnt/workspace/U-ViT'

import torch
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
import libs.autoencoder
from libs.uvit import UViT
import einops
from torchvision.utils import save_image
from PIL import Image
from modelscope.hub.file_download import model_file_download
2. Load the U-ViT model
Set the image size to 256, download the corresponding U-ViT checkpoint, compute the spatial size of the latent representation, build the U-ViT model, and load the weights.
image_size = "256"  #@param [256, 512]
image_size = int(image_size)

if image_size == 256:
    # download the U-ViT checkpoint
    model_file_download(model_id='thu-ml/imagenet256_uvit_huge',
                        file_path='imagenet256_uvit_huge.pth', cache_dir='/mnt/workspace')
    !mv /mnt/workspace/thu-ml/imagenet256_uvit_huge/imagenet256_uvit_huge.pth /mnt/workspace/U-ViT
else:
    model_file_download(model_id='thu-ml/imagenet512_uvit_huge',
                        file_path='imagenet512_uvit_huge.pth', cache_dir='/mnt/workspace')
    !mv /mnt/workspace/thu-ml/imagenet512_uvit_huge/imagenet512_uvit_huge.pth /mnt/workspace/U-ViT

z_size = image_size // 8
patch_size = 2 if image_size == 256 else 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

nnet = UViT(img_size=z_size,
            patch_size=patch_size,
            in_chans=4,
            embed_dim=1152,
            depth=28,
            num_heads=16,
            num_classes=1001,
            conv=False)

nnet.to(device)
nnet.load_state_dict(torch.load(f'imagenet{image_size}_uvit_huge.pth', map_location='cpu'))
nnet.eval()
3. Download and load the autoencoder
model_file_download(model_id='AI-ModelScope/autoencoder_kl_ema',
                    file_path='autoencoder_kl_ema.pth', cache_dir='/mnt/workspace')
!mv /mnt/workspace/AI-ModelScope/autoencoder_kl_ema/autoencoder_kl_ema.pth /mnt/workspace/U-ViT

autoencoder = libs.autoencoder.get_model('autoencoder_kl_ema.pth')
autoencoder.to(device)
4. Image generation with U-ViT and diffusion
With U-ViT as the backbone, the diffusion model generates image samples conditioned on the given class labels and displays the results; sampling uses DPM-Solver.
seed = 42  #@param {type:"number"}
steps = 25  #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 3  #@param {type:"slider", min:0, max:10, step:0.1}
class_labels = 207, 360, 387, 974, 88, 979, 417, 279  #@param {type:"raw"}
samples_per_row = 4  #@param {type:"number"}

torch.manual_seed(seed)

def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()

_betas = stable_diffusion_beta_schedule()
# set the noise schedule
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

y = torch.tensor(class_labels, device=device)
y = einops.repeat(y, 'B -> (B N)', N=samples_per_row)

def model_fn(x, t_continuous):
    t = t_continuous * len(_betas)
    _cond = nnet(x, t, y=y)
    _uncond = nnet(x, t, y=torch.tensor([1000] * x.size(0), device=device))
    return _cond + cfg_scale * (_cond - _uncond)  # classifier free guidance

z_init = torch.randn(len(y), 4, z_size, z_size, device=device)
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)

with torch.no_grad():
    with torch.cuda.amp.autocast():  # inference with mixed precision
        z = dpm_solver.sample(z_init, steps=steps, eps=1. / len(_betas), T=1.)
        samples = autoencoder.decode(z)

samples = 0.5 * (samples + 1.)
samples.clamp_(0., 1.)
save_image(samples, "sample.png", nrow=samples_per_row * 2, padding=0)
samples = Image.open("sample.png")
display(samples)
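The model_fn above implements classifier-free guidance: the network is evaluated once with the class labels and once with the unconditional class id 1000, and the two predictions are combined as

\hat{\epsilon} = \epsilon_\theta(x_t, c) + s \cdot \big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\big)

where s is cfg_scale; larger s pushes the samples more strongly toward the class condition at the cost of diversity.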
Summary
As reported in the paper, a latent diffusion model with a U-ViT backbone achieves a record FID of 2.29 for class-conditional image generation on ImageNet 256×256 and an FID of 5.48 for text-to-image generation on MS-COCO. We thank the authors for open-sourcing U-ViT, and we believe that, building on this work, developers can better pursue frontier research in the direction of Sora.
GitHub link: https://github.com/baofff/U-ViT