手动实现一个扩散模型DDPM(上):https://developer.aliyun.com/article/1480704
▐ 位置嵌入
如何让网络知道目前处于K的哪一步?可以增加一个Time Embedding(类似于Positional embeddings)进行处理,通过将timestep编码进网络中,从而只需要训练一个共享的U-Net模型,就可以让网络知道现在处于哪一步了。
Time Embedding正是输入到ResNetBlock模块中,为U-Net引入了时间信息(时间步长T,T的大小代表了噪声扰动的强度),模拟一个随时间变化不断增加不同强度噪声扰动的过程,让SD模型能够更好地理解时间相关性。
同时,在SD模型调用U-Net重复迭代去噪的过程中,我们希望在迭代的早期,能够先生成整幅图片的轮廓与边缘特征,随着迭代的深入,再补充生成图片的高频和细节特征信息。由于在每个ResNetBlock模块中都有Time Embedding,就能告诉U-Net现在是整个迭代过程的哪一步,并及时控制U-Net够根据不同的输入特征和迭代阶段而预测不同的噪声残差。
从AI绘画应用视角解释一下Time Embedding的作用。Time Embedding能够让SD模型在生成图片时考虑时间的影响,使得生成的图片更具有故事性、情感和沉浸感等艺术效果。并且Time Embedding可以帮助SD模型在不同的时间点将生成的图片添加完善不同情感和主题的内容,从而增加了AI绘画的多样性和表现力。
class SinusoidalPositionEmbeddings(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, time): device = time.device half_dim = self.dim // 2 embeddings = math.log(10000) / (half_dim - 1) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) embeddings = time[:, None] * embeddings[None, :] embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) return embeddings
▐ U-net
基于上述定义的DM神经网络基础的层和模块,现在是时候把他组装拼接起来了:
- 神经网络接受一批如下shape的噪声图像输入(batch_size, num_channels, height, width) 同时接受这批噪声水平,shape=(batch_size, 1)。返回一个张量,shape = (batch_size, num_channels, height, width)
按照如下步骤构建这个网络:
- 首先,对噪声图像进行卷积处理,对噪声水平进行进行位置编码(embedding)
- 然后,进入一个序列的下采样阶段,每个下采样阶段由两个ResNet/ConvNeXT模块+分组归一化+注意力模块+残差链接+下采样完成。
- 在网络的中间层,再一次用ResNet/ConvNeXT模块,中间穿插着注意力模块(Attention)。
- 下一个阶段,则是序列构成的上采样阶段,每个上采样阶段由两个ResNet/ConvNeXT模块+分组归一化+注意力模块+残差链接+上采样完成。
- 最后,一个ResNet/ConvNeXT模块后面跟着一个卷积层。
class Unet(nn.Module): # 初始化函数,定义U-Net网络的结构和参数 def __init__( self, dim, # 基本隐藏层维度 init_dim=None, # 初始层维度,如果未提供则会根据dim计算得出 out_dim=None, # 输出维度,如果未提供则默认为输入图像的通道数 dim_mults=(1, 2, 4, 8), # 控制每个阶段隐藏层维度倍增的倍数 channels=3, # 输入图像的通道数,默认为3 with_time_emb=True, # 是否使用时间嵌入,这对于某些生成模型可能是必要的 resnet_block_groups=8, # ResNet块中的组数 use_convnext=True, # 是否使用ConvNeXt块而不是ResNet块 convnext_mult=2, # ConvNeXt块的维度倍增因子 ): super().__init__() # 调用父类构造函数 # 确定各层维度 self.channels = channels init_dim = default(init_dim, dim // 3 * 2) # 设置或计算初始层维度 self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) # 初始卷积层,使用7x7卷积核和padding dims = [init_dim, *map(lambda m: dim * m, dim_mults)] # 计算每个阶段的维度 in_out = list(zip(dims[:-1], dims[1:])) # 创建输入输出维度对 # 根据use_convnext选择块类 if use_convnext: block_klass = partial(ConvNextBlock, mult=convnext_mult) else: block_klass = partial(ResnetBlock, groups=resnet_block_groups) # 时间嵌入层 if with_time_emb: time_dim = dim * 4 # 时间嵌入的维度 self.time_mlp = nn.Sequential( # 时间嵌入的多层感知机 SinusoidalPositionEmbeddings(dim), # 正弦位置嵌入 nn.Linear(dim, time_dim), # 线性变换 nn.GELU(), # GELU激活函数 nn.Linear(time_dim, time_dim), # 再一次线性变换 ) else: time_dim = None self.time_mlp = None # 下采样层 self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) # 解析的层数 # 构建下采样模块 for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) # 是否为最后一层 self.downs.append( # 添加下采样块 nn.ModuleList( [ block_klass(dim_in, dim_out, time_emb_dim=time_dim), # 卷积块 block_klass(dim_out, dim_out, time_emb_dim=time_dim), # 卷积块 Residual(PreNorm(dim_out, LinearAttention(dim_out))), # 残差连接和注意力模块 Downsample(dim_out) if not is_last else nn.Identity(), # 下采样或恒等映射 ] ) ) # 中间层(瓶颈层) mid_dim = dims[-1] # 中间层(瓶颈层) # 第一个中间卷积块 self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) # 中间层的注意力模块 self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) # 第二个中间卷积块 self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) # 构建上采样模块 for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) # 是否是最后一次上采样,减2是因为我们需要留出一个输出层 self.ups.append( nn.ModuleList( [ # 卷积块,这里输入维度翻倍是因为上采样过程中会与编码器阶段的相应层进行拼接 block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), # 卷积块 block_klass(dim_in, dim_in, time_emb_dim=time_dim), # 残差和注意力模块 Residual(PreNorm(dim_in, LinearAttention(dim_in))), # 上采样或恒等映射 Upsample(dim_in) if not is_last else nn.Identity(), ] ) ) # 设置或计算输出维度,如果未提供则默认为输入图像的通道数 out_dim = default(out_dim, channels) # 最后的卷积层,将输出维度变换到期望的输出维度 self.final_conv = nn.Sequential( block_klass(dim, dim), # 卷积块 nn.Conv2d(dim, out_dim, 1) # 1x1卷积,用于输出维度变换 ) # 前向传播函数 def forward(self, x, time): # 初始卷积层 x = self.init_conv(x) # 如果存在时间嵌入层,则将时间编码 t = self.time_mlp(time) if exists(self.time_mlp) else None # 用于存储各个阶段的特征图 h = [] # 下采样过程 for block1, block2, attn, downsample in self.downs: x = block1(x, t) # 应用卷积块 x = block2(x, t) # 应用卷积块 x = attn(x) # 应用注意力模块 h.append(x) # 存储特征图以便后续的拼接 x = downsample(x) # 应用下采样或恒等映射 # 中间层或瓶颈层 x = self.mid_block1(x, t) # 第一个中间卷积块 x = self.mid_attn(x) # 中间层的注意力模块 x = self.mid_block2(x, t) # 第二个中间卷积块 # 上采样过程 for block1, block2, attn, upsample in self.ups: # 拼接特征图和对应的编码器阶段的特征图 x = torch.cat((x, h.pop()), dim=1) x = block1(x, t) # 应用卷积块 x = block2(x, t) # 应用卷积块 x = attn(x) # 应用注意力模块 x = upsample(x) # 应用上采样或恒等映射 # 最后的输出层,输出最终的特征图或图像 return self.final_conv(x)
▐ 损失函数
下面这段代码是为扩散模型中的去噪模型定义的损失函数。它计算由去噪模型预测的噪声和实际加入的噪声之间的差异。该函数支持不同类型的损失,包括L1损失、均方误差损失(L2损失)和Huber损失。选择适当的损失函数可以帮助模型更好地学习如何预测和去除生成数据中的噪声。
import torch import torch.nn.functional as F # 定义损失函数,它评估去噪模型的性能 def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"): if noise is None: noise = torch.randn_like(x_start) # 如果未提供噪声,则生成一个与x_start形状相同的随机噪声张量 # 使用q_sample函数生成带有噪声的数据x_noisy,这模拟了扩散模型的前向过程 x_noisy = q_sample(x_start=x_start, t=t, noise=noise) # 使用去噪模型对噪声数据x_noisy进行预测,试图恢复加入的噪声 predicted_noise = denoise_model(x_noisy, t) # 根据指定的损失类型计算损失 if loss_type == 'l1': # 如果损失类型为L1损失 loss = F.l1_loss(noise, predicted_noise) # 使用L1损失函数计算真实噪声和预测噪声之间的差异 elif loss_type == 'l2': # 如果损失类型为L2损失(均方误差损失) loss = F.mse_loss(noise, predicted_noise) # 使用均方误差损失函数计算真实噪声和预测噪声之间的差异 elif loss_type == "huber": # 如果损失类型为Huber损失 loss = F.smooth_l1_loss(noise, predicted_noise) # 使用Huber损失函数,这是L1和L2损失的结合,对异常值不那么敏感 else: raise NotImplementedError() # 如果指定了未实现的损失类型,则抛出异常 return loss # 返回计算得到的损失值
▐ 开始训练
if __name__=="__main__": for epoch in range(epochs): for step, batch in tqdm(enumerate(dataloader), desc='Training'): optimizer.zero_grad() batch = batch[0] batch_size = batch.shape[0] batch = batch.to(device) # 国内版启用这段,注释上面两行 # batch_size = batch[0].shape[0] # batch = batch[0].to(device) # Algorithm 1 line 3: sample t uniformally for every example in the batch t = torch.randint(0, timesteps, (batch_size,), device=device).long() loss = p_losses(model, batch, t, loss_type="huber") if step % 50 == 0: print("Loss:", loss.item()) loss.backward() optimizer.step() # save generated images if step != 0 and step % save_and_sample_every == 0: milestone = step // save_and_sample_every batches = num_to_groups(4, batch_size) all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches)) all_images = torch.cat(all_images_list, dim=0) all_images = (all_images + 1) * 0.5 # save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6) currentDateAndTime = datetime.now() torch.save(model,f"train.pt")
▐ 推理结果
参考文献
- 深入学习:Diffusion Model 原理解析(地址:http://www.egbenz.com/#/my_article/12)
- 【一个本子】Diffusion Model 原理详解(地址:https://zhuanlan.zhihu.com/p/582072317)
- 深入浅出扩散模型(Diffusion Model)系列:基石DDPM(模型架构篇),最详细的DDPM架构图解(地址:https://zhuanlan.zhihu.com/p/637815071)
- 一文读懂Transformer模型的位置编码(地址:https://zhuanlan.zhihu.com/p/637815071
- https://zhuanlan.zhihu.com/p/632809634
团队介绍
我们是淘天集团业务技术线的手猫营销&导购团队,专注于在手机天猫平台上探索创新商业化,我们依托淘天集团强大的互联网背景,致力于为手机天猫平台提供效率高、创新性强的技术支持。
我们的队员们来自各种营销和导购领域,拥有丰富的经验。通过不断地技术探索和商业创新,我们改善了用户的体验,并提升了平台的运营效率。
我们的团队持续不懈地探索和提升技术能力,坚持“技术领先、用户至上”,为手机天猫的导购场景和商业发展做出了显著贡献。