手动实现一个扩散模型DDPM(下)

简介: 手动实现一个扩散模型DDPM(下)

手动实现一个扩散模型DDPM(上):https://developer.aliyun.com/article/1480704

 位置嵌入



如何让网络知道目前处于K的哪一步?可以增加一个Time Embedding(类似于Positional embeddings)进行处理,通过将timestep编码进网络中,从而只需要训练一个共享的U-Net模型,就可以让网络知道现在处于哪一步了。
Time Embedding正是输入到ResNetBlock模块中,为U-Net引入了时间信息(时间步长T,T的大小代表了噪声扰动的强度),模拟一个随时间变化不断增加不同强度噪声扰动的过程,让SD模型能够更好地理解时间相关性
同时,在SD模型调用U-Net重复迭代去噪的过程中,我们希望在迭代的早期,能够先生成整幅图片的轮廓与边缘特征,随着迭代的深入,再补充生成图片的高频和细节特征信息。由于在每个ResNetBlock模块中都有Time Embedding,就能告诉U-Net现在是整个迭代过程的哪一步,并及时控制U-Net够根据不同的输入特征和迭代阶段而预测不同的噪声残差
从AI绘画应用视角解释一下Time Embedding的作用。Time Embedding能够让SD模型在生成图片时考虑时间的影响,使得生成的图片更具有故事性、情感和沉浸感等艺术效果。并且Time Embedding可以帮助SD模型在不同的时间点将生成的图片添加完善不同情感和主题的内容,从而增加了AI绘画的多样性和表现力。


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings



 U-net


基于上述定义的DM神经网络基础的层和模块,现在是时候把他组装拼接起来了:

  • 神经网络接受一批如下shape的噪声图像输入(batch_size, num_channels, height, width) 同时接受这批噪声水平,shape=(batch_size, 1)。返回一个张量,shape = (batch_size, num_channels, height, width)


按照如下步骤构建这个网络:

  • 首先,对噪声图像进行卷积处理,对噪声水平进行进行位置编码(embedding)
  • 然后,进入一个序列的下采样阶段,每个下采样阶段由两个ResNet/ConvNeXT模块+分组归一化+注意力模块+残差链接+下采样完成。
  • 在网络的中间层,再一次用ResNet/ConvNeXT模块,中间穿插着注意力模块(Attention)。
  • 下一个阶段,则是序列构成的上采样阶段,每个上采样阶段由两个ResNet/ConvNeXT模块+分组归一化+注意力模块+残差链接+上采样完成。
  • 最后,一个ResNet/ConvNeXT模块后面跟着一个卷积层。



class Unet(nn.Module):
    # 初始化函数,定义U-Net网络的结构和参数
    def __init__(
            self,
            dim,  # 基本隐藏层维度
            init_dim=None,  # 初始层维度,如果未提供则会根据dim计算得出
            out_dim=None,  # 输出维度,如果未提供则默认为输入图像的通道数
            dim_mults=(1, 2, 4, 8),  # 控制每个阶段隐藏层维度倍增的倍数
            channels=3,  # 输入图像的通道数,默认为3
            with_time_emb=True,  # 是否使用时间嵌入,这对于某些生成模型可能是必要的
            resnet_block_groups=8,  # ResNet块中的组数
            use_convnext=True,  # 是否使用ConvNeXt块而不是ResNet块
            convnext_mult=2,  # ConvNeXt块的维度倍增因子
    ):
        super().__init__()  # 调用父类构造函数
        # 确定各层维度
        self.channels = channels
        init_dim = default(init_dim, dim // 3 * 2)  # 设置或计算初始层维度
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)  # 初始卷积层,使用7x7卷积核和padding
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]  # 计算每个阶段的维度
        in_out = list(zip(dims[:-1], dims[1:]))  # 创建输入输出维度对
        # 根据use_convnext选择块类
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        # 时间嵌入层
        if with_time_emb:
            time_dim = dim * 4  # 时间嵌入的维度
            self.time_mlp = nn.Sequential(  # 时间嵌入的多层感知机
                SinusoidalPositionEmbeddings(dim),  # 正弦位置嵌入
                nn.Linear(dim, time_dim),  # 线性变换
                nn.GELU(),  # GELU激活函数
                nn.Linear(time_dim, time_dim),  # 再一次线性变换
            )
        else:
            time_dim = None
            self.time_mlp = None
        # 下采样层
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)  # 解析的层数
        # 构建下采样模块
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)  # 是否为最后一层
            self.downs.append(  # 添加下采样块
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),  # 卷积块
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),  # 卷积块
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),  # 残差连接和注意力模块
                        Downsample(dim_out) if not is_last else nn.Identity(),  # 下采样或恒等映射
                    ]
                )
            )
        # 中间层(瓶颈层)
        mid_dim = dims[-1]
        # 中间层(瓶颈层)
        # 第一个中间卷积块
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        # 中间层的注意力模块
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        # 第二个中间卷积块
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # 构建上采样模块
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)  # 是否是最后一次上采样,减2是因为我们需要留出一个输出层
            self.ups.append(
                nn.ModuleList(
                    [
                        # 卷积块,这里输入维度翻倍是因为上采样过程中会与编码器阶段的相应层进行拼接
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        # 卷积块
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        # 残差和注意力模块
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        # 上采样或恒等映射
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        # 设置或计算输出维度,如果未提供则默认为输入图像的通道数
        out_dim = default(out_dim, channels)
        # 最后的卷积层,将输出维度变换到期望的输出维度
        self.final_conv = nn.Sequential(
            block_klass(dim, dim),  # 卷积块
            nn.Conv2d(dim, out_dim, 1)  # 1x1卷积,用于输出维度变换
        )

    # 前向传播函数
    def forward(self, x, time):
        # 初始卷积层
        x = self.init_conv(x)
        # 如果存在时间嵌入层,则将时间编码
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        # 用于存储各个阶段的特征图
        h = []

        # 下采样过程
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)  # 应用卷积块
            x = block2(x, t)  # 应用卷积块
            x = attn(x)  # 应用注意力模块
            h.append(x)  # 存储特征图以便后续的拼接
            x = downsample(x)  # 应用下采样或恒等映射

        # 中间层或瓶颈层
        x = self.mid_block1(x, t)  # 第一个中间卷积块
        x = self.mid_attn(x)  # 中间层的注意力模块
        x = self.mid_block2(x, t)  # 第二个中间卷积块

        # 上采样过程
        for block1, block2, attn, upsample in self.ups:
            # 拼接特征图和对应的编码器阶段的特征图
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)  # 应用卷积块
            x = block2(x, t)  # 应用卷积块
            x = attn(x)  # 应用注意力模块
            x = upsample(x)  # 应用上采样或恒等映射

        # 最后的输出层,输出最终的特征图或图像
        return self.final_conv(x)


 损失函数



下面这段代码是为扩散模型中的去噪模型定义的损失函数。它计算由去噪模型预测的噪声和实际加入的噪声之间的差异。该函数支持不同类型的损失,包括L1损失、均方误差损失(L2损失)和Huber损失。选择适当的损失函数可以帮助模型更好地学习如何预测和去除生成数据中的噪声。

import torch
import torch.nn.functional as F

# 定义损失函数,它评估去噪模型的性能
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)  # 如果未提供噪声,则生成一个与x_start形状相同的随机噪声张量

    # 使用q_sample函数生成带有噪声的数据x_noisy,这模拟了扩散模型的前向过程
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    # 使用去噪模型对噪声数据x_noisy进行预测,试图恢复加入的噪声
    predicted_noise = denoise_model(x_noisy, t)

    # 根据指定的损失类型计算损失
    if loss_type == 'l1':  # 如果损失类型为L1损失
        loss = F.l1_loss(noise, predicted_noise)  # 使用L1损失函数计算真实噪声和预测噪声之间的差异
    elif loss_type == 'l2':  # 如果损失类型为L2损失(均方误差损失)
        loss = F.mse_loss(noise, predicted_noise)  # 使用均方误差损失函数计算真实噪声和预测噪声之间的差异
    elif loss_type == "huber":  # 如果损失类型为Huber损失
        loss = F.smooth_l1_loss(noise, predicted_noise)  # 使用Huber损失函数,这是L1和L2损失的结合,对异常值不那么敏感
    else:
        raise NotImplementedError()  # 如果指定了未实现的损失类型,则抛出异常

    return loss  # 返回计算得到的损失值


 开始训练

if __name__=="__main__":
    for epoch in range(epochs):
        for step, batch in tqdm(enumerate(dataloader), desc='Training'):
          optimizer.zero_grad()
          batch = batch[0]

          batch_size = batch.shape[0]
          batch = batch.to(device)
          # 国内版启用这段,注释上面两行
          # batch_size = batch[0].shape[0]
          # batch = batch[0].to(device)

          # Algorithm 1 line 3: sample t uniformally for every example in the batch
          t = torch.randint(0, timesteps, (batch_size,), device=device).long()

          loss = p_losses(model, batch, t, loss_type="huber")

          if step % 50 == 0:
            print("Loss:", loss.item())

          loss.backward()
          optimizer.step()

          # save generated images
          if step != 0 and step % save_and_sample_every == 0:
            milestone = step // save_and_sample_every
            batches = num_to_groups(4, batch_size)
            all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
            all_images = torch.cat(all_images_list, dim=0)
            all_images = (all_images + 1) * 0.5
            # save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)
            currentDateAndTime = datetime.now()
            torch.save(model,f"train.pt")


 推理结果


image.png

参考文献


  1. 深入学习:Diffusion Model 原理解析(地址:http://www.egbenz.com/#/my_article/12
  2. 【一个本子】Diffusion Model 原理详解(地址:https://zhuanlan.zhihu.com/p/582072317
  3. 深入浅出扩散模型(Diffusion Model)系列:基石DDPM(模型架构篇),最详细的DDPM架构图解(地址:https://zhuanlan.zhihu.com/p/637815071
  4. 一文读懂Transformer模型的位置编码(地址:https://zhuanlan.zhihu.com/p/637815071
  5. https://zhuanlan.zhihu.com/p/632809634


image.png

团队介绍


我们是淘天集团业务技术线的手猫营销&导购团队,专注于在手机天猫平台上探索创新商业化,我们依托淘天集团强大的互联网背景,致力于为手机天猫平台提供效率高、创新性强的技术支持。
我们的队员们来自各种营销和导购领域,拥有丰富的经验。通过不断地技术探索和商业创新,我们改善了用户的体验,并提升了平台的运营效率。
我们的团队持续不懈地探索和提升技术能力,坚持“技术领先、用户至上”,为手机天猫的导购场景和商业发展做出了显著贡献。


目录
相关文章
|
机器学习/深度学习 调度
详解 Diffusion (扩散) 模型
详解 Diffusion (扩散) 模型
199 0
|
12天前
|
机器学习/深度学习 调度 知识图谱
TimeDART:基于扩散自回归Transformer 的自监督时间序列预测方法
近年来,深度神经网络成为时间序列预测的主流方法。自监督学习通过从未标记数据中学习,能够捕获时间序列的长期依赖和局部特征。TimeDART结合扩散模型和自回归建模,创新性地解决了时间序列预测中的关键挑战,在多个数据集上取得了最优性能,展示了强大的泛化能力。
53 0
TimeDART:基于扩散自回归Transformer 的自监督时间序列预测方法
|
2月前
|
机器学习/深度学习 自然语言处理 并行计算
扩散模型
本文详细介绍了扩散模型(Diffusion Models, DM),一种在计算机视觉和自然语言处理等领域取得显著进展的生成模型。文章分为四部分:基本原理、处理过程、应用和代码实战。首先,阐述了扩散模型的两个核心过程:前向扩散(加噪)和逆向扩散(去噪)。接着,介绍了训练和生成的具体步骤。最后,展示了模型在图像生成、视频生成和自然语言处理等领域的广泛应用,并提供了一个基于Python和PyTorch的代码示例,帮助读者快速入门。
|
4月前
|
机器学习/深度学习 算法 PyTorch
使用Pytorch中从头实现去噪扩散概率模型(DDPM)
在本文中,我们将构建基础的无条件扩散模型,即去噪扩散概率模型(DDPM)。从探究算法的直观工作原理开始,然后在PyTorch中从头构建它。本文主要关注算法背后的思想和具体实现细节。
8695 3
|
5月前
|
人工智能 vr&ar 计算机视觉
CVPR 2024:让图像扩散模型生成高质量360度场景,只需要一个语言模型
【6月更文挑战第20天】CVPR 2024研究表明,结合语言模型的图像扩散模型能高效生成360度全景图像,减少对标注数据的依赖。该框架利用语言模型的语义信息引导细节丰富的图像生成,解决了传统方法的标注难题。然而,方法的准确性和计算资源需求是挑战。这一进展推动了VR/AR图像生成技术的发展。[论文链接](https://arxiv.org/pdf/2406.01843)**
62 6
|
5月前
|
机器学习/深度学习 Python
扩散模型的基本原理
扩散模型的基本原理
98 2
|
6月前
|
机器学习/深度学习 人工智能 测试技术
世界模型也扩散!训练出的智能体竟然不错
【5月更文挑战第30天】研究人员提出了一种名为DIAMOND的新方法,将扩散模型应用于世界模型以增强强化学习智能体的训练。DIAMOND在Atari 100k基准测试中实现了1.46的人类标准化得分,刷新了完全在世界模型中训练的智能体的记录。通过生成视觉细节,智能体在多个游戏中超越人类玩家,特别是在需要精细细节识别的游戏上。不过,DIAMOND在连续控制环境和长期记忆方面的应用仍需改进。这项工作开源了代码和模型,促进了未来相关研究。论文链接:[https://arxiv.org/abs/2405.12399](https://arxiv.org/abs/2405.12399)
116 2
|
6月前
|
人工智能 计算机视觉
论文介绍:MDTv2——提升图像合成能力的掩码扩散变换器
【5月更文挑战第18天】MDTv2是掩码扩散变换器的升级版,旨在增强图像合成模型DPMs处理语义关系的能力。通过掩码操作和不对称扩散变换,MDTv2能学习图像的完整语义信息,提升学习效率和图像质量。MDTv2采用优化的网络结构和训练策略,如长快捷方式、密集输入和时间步适应损失权重,实现SOTA性能,FID分数达到1.58,训练速度比DiT快10倍。尽管计算成本高和泛化能力待验证,MDTv2为图像合成领域开辟了新方向。[链接: https://arxiv.org/abs/2303.14389]
146 1
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
论文介绍:探索离散状态空间中的去噪扩散模型
【4月更文挑战第8天】新研究提出离散去噪扩散概率模型(D3PMs),扩展了在图像和音频生成上成功的DDPMs,专注于离散数据如文本和图像分割。D3PMs通过结构化的离散腐败过程改进生成质量,无需将数据转化为连续空间,允许嵌入领域知识。实验显示,D3PMs在字符级文本生成和CIFAR-10图像数据集上表现出色。尽管有局限性,如在某些任务上不及自回归模型,D3PMs的灵活性使其适用于多样化场景。
73 2
论文介绍:探索离散状态空间中的去噪扩散模型
|
6月前
|
机器学习/深度学习 存储 人工智能
手动实现一个扩散模型DDPM(上)
手动实现一个扩散模型DDPM(上)
268 5