手动实现一个扩散模型DDPM(下)

简介: 手动实现一个扩散模型DDPM(下)

手动实现一个扩散模型DDPM(上):https://developer.aliyun.com/article/1480704

 位置嵌入



如何让网络知道目前处于K的哪一步?可以增加一个Time Embedding(类似于Positional embeddings)进行处理,通过将timestep编码进网络中,从而只需要训练一个共享的U-Net模型,就可以让网络知道现在处于哪一步了。
Time Embedding正是输入到ResNetBlock模块中,为U-Net引入了时间信息(时间步长T,T的大小代表了噪声扰动的强度),模拟一个随时间变化不断增加不同强度噪声扰动的过程,让SD模型能够更好地理解时间相关性
同时,在SD模型调用U-Net重复迭代去噪的过程中,我们希望在迭代的早期,能够先生成整幅图片的轮廓与边缘特征,随着迭代的深入,再补充生成图片的高频和细节特征信息。由于在每个ResNetBlock模块中都有Time Embedding,就能告诉U-Net现在是整个迭代过程的哪一步,并及时控制U-Net够根据不同的输入特征和迭代阶段而预测不同的噪声残差
从AI绘画应用视角解释一下Time Embedding的作用。Time Embedding能够让SD模型在生成图片时考虑时间的影响,使得生成的图片更具有故事性、情感和沉浸感等艺术效果。并且Time Embedding可以帮助SD模型在不同的时间点将生成的图片添加完善不同情感和主题的内容,从而增加了AI绘画的多样性和表现力。


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings



 U-net


基于上述定义的DM神经网络基础的层和模块,现在是时候把他组装拼接起来了:

  • 神经网络接受一批如下shape的噪声图像输入(batch_size, num_channels, height, width) 同时接受这批噪声水平,shape=(batch_size, 1)。返回一个张量,shape = (batch_size, num_channels, height, width)


按照如下步骤构建这个网络:

  • 首先,对噪声图像进行卷积处理,对噪声水平进行进行位置编码(embedding)
  • 然后,进入一个序列的下采样阶段,每个下采样阶段由两个ResNet/ConvNeXT模块+分组归一化+注意力模块+残差链接+下采样完成。
  • 在网络的中间层,再一次用ResNet/ConvNeXT模块,中间穿插着注意力模块(Attention)。
  • 下一个阶段,则是序列构成的上采样阶段,每个上采样阶段由两个ResNet/ConvNeXT模块+分组归一化+注意力模块+残差链接+上采样完成。
  • 最后,一个ResNet/ConvNeXT模块后面跟着一个卷积层。



class Unet(nn.Module):
    # 初始化函数,定义U-Net网络的结构和参数
    def __init__(
            self,
            dim,  # 基本隐藏层维度
            init_dim=None,  # 初始层维度,如果未提供则会根据dim计算得出
            out_dim=None,  # 输出维度,如果未提供则默认为输入图像的通道数
            dim_mults=(1, 2, 4, 8),  # 控制每个阶段隐藏层维度倍增的倍数
            channels=3,  # 输入图像的通道数,默认为3
            with_time_emb=True,  # 是否使用时间嵌入,这对于某些生成模型可能是必要的
            resnet_block_groups=8,  # ResNet块中的组数
            use_convnext=True,  # 是否使用ConvNeXt块而不是ResNet块
            convnext_mult=2,  # ConvNeXt块的维度倍增因子
    ):
        super().__init__()  # 调用父类构造函数
        # 确定各层维度
        self.channels = channels
        init_dim = default(init_dim, dim // 3 * 2)  # 设置或计算初始层维度
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)  # 初始卷积层,使用7x7卷积核和padding
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]  # 计算每个阶段的维度
        in_out = list(zip(dims[:-1], dims[1:]))  # 创建输入输出维度对
        # 根据use_convnext选择块类
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        # 时间嵌入层
        if with_time_emb:
            time_dim = dim * 4  # 时间嵌入的维度
            self.time_mlp = nn.Sequential(  # 时间嵌入的多层感知机
                SinusoidalPositionEmbeddings(dim),  # 正弦位置嵌入
                nn.Linear(dim, time_dim),  # 线性变换
                nn.GELU(),  # GELU激活函数
                nn.Linear(time_dim, time_dim),  # 再一次线性变换
            )
        else:
            time_dim = None
            self.time_mlp = None
        # 下采样层
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)  # 解析的层数
        # 构建下采样模块
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)  # 是否为最后一层
            self.downs.append(  # 添加下采样块
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),  # 卷积块
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),  # 卷积块
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),  # 残差连接和注意力模块
                        Downsample(dim_out) if not is_last else nn.Identity(),  # 下采样或恒等映射
                    ]
                )
            )
        # 中间层(瓶颈层)
        mid_dim = dims[-1]
        # 中间层(瓶颈层)
        # 第一个中间卷积块
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        # 中间层的注意力模块
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        # 第二个中间卷积块
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # 构建上采样模块
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)  # 是否是最后一次上采样,减2是因为我们需要留出一个输出层
            self.ups.append(
                nn.ModuleList(
                    [
                        # 卷积块,这里输入维度翻倍是因为上采样过程中会与编码器阶段的相应层进行拼接
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        # 卷积块
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        # 残差和注意力模块
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        # 上采样或恒等映射
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        # 设置或计算输出维度,如果未提供则默认为输入图像的通道数
        out_dim = default(out_dim, channels)
        # 最后的卷积层,将输出维度变换到期望的输出维度
        self.final_conv = nn.Sequential(
            block_klass(dim, dim),  # 卷积块
            nn.Conv2d(dim, out_dim, 1)  # 1x1卷积,用于输出维度变换
        )

    # 前向传播函数
    def forward(self, x, time):
        # 初始卷积层
        x = self.init_conv(x)
        # 如果存在时间嵌入层,则将时间编码
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        # 用于存储各个阶段的特征图
        h = []

        # 下采样过程
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)  # 应用卷积块
            x = block2(x, t)  # 应用卷积块
            x = attn(x)  # 应用注意力模块
            h.append(x)  # 存储特征图以便后续的拼接
            x = downsample(x)  # 应用下采样或恒等映射

        # 中间层或瓶颈层
        x = self.mid_block1(x, t)  # 第一个中间卷积块
        x = self.mid_attn(x)  # 中间层的注意力模块
        x = self.mid_block2(x, t)  # 第二个中间卷积块

        # 上采样过程
        for block1, block2, attn, upsample in self.ups:
            # 拼接特征图和对应的编码器阶段的特征图
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)  # 应用卷积块
            x = block2(x, t)  # 应用卷积块
            x = attn(x)  # 应用注意力模块
            x = upsample(x)  # 应用上采样或恒等映射

        # 最后的输出层,输出最终的特征图或图像
        return self.final_conv(x)


 损失函数



下面这段代码是为扩散模型中的去噪模型定义的损失函数。它计算由去噪模型预测的噪声和实际加入的噪声之间的差异。该函数支持不同类型的损失,包括L1损失、均方误差损失(L2损失)和Huber损失。选择适当的损失函数可以帮助模型更好地学习如何预测和去除生成数据中的噪声。

import torch
import torch.nn.functional as F

# 定义损失函数,它评估去噪模型的性能
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)  # 如果未提供噪声,则生成一个与x_start形状相同的随机噪声张量

    # 使用q_sample函数生成带有噪声的数据x_noisy,这模拟了扩散模型的前向过程
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    # 使用去噪模型对噪声数据x_noisy进行预测,试图恢复加入的噪声
    predicted_noise = denoise_model(x_noisy, t)

    # 根据指定的损失类型计算损失
    if loss_type == 'l1':  # 如果损失类型为L1损失
        loss = F.l1_loss(noise, predicted_noise)  # 使用L1损失函数计算真实噪声和预测噪声之间的差异
    elif loss_type == 'l2':  # 如果损失类型为L2损失(均方误差损失)
        loss = F.mse_loss(noise, predicted_noise)  # 使用均方误差损失函数计算真实噪声和预测噪声之间的差异
    elif loss_type == "huber":  # 如果损失类型为Huber损失
        loss = F.smooth_l1_loss(noise, predicted_noise)  # 使用Huber损失函数,这是L1和L2损失的结合,对异常值不那么敏感
    else:
        raise NotImplementedError()  # 如果指定了未实现的损失类型,则抛出异常

    return loss  # 返回计算得到的损失值


 开始训练

if __name__=="__main__":
    for epoch in range(epochs):
        for step, batch in tqdm(enumerate(dataloader), desc='Training'):
          optimizer.zero_grad()
          batch = batch[0]

          batch_size = batch.shape[0]
          batch = batch.to(device)
          # 国内版启用这段,注释上面两行
          # batch_size = batch[0].shape[0]
          # batch = batch[0].to(device)

          # Algorithm 1 line 3: sample t uniformally for every example in the batch
          t = torch.randint(0, timesteps, (batch_size,), device=device).long()

          loss = p_losses(model, batch, t, loss_type="huber")

          if step % 50 == 0:
            print("Loss:", loss.item())

          loss.backward()
          optimizer.step()

          # save generated images
          if step != 0 and step % save_and_sample_every == 0:
            milestone = step // save_and_sample_every
            batches = num_to_groups(4, batch_size)
            all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
            all_images = torch.cat(all_images_list, dim=0)
            all_images = (all_images + 1) * 0.5
            # save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)
            currentDateAndTime = datetime.now()
            torch.save(model,f"train.pt")


 推理结果


image.png

参考文献


  1. 深入学习:Diffusion Model 原理解析(地址:http://www.egbenz.com/#/my_article/12
  2. 【一个本子】Diffusion Model 原理详解(地址:https://zhuanlan.zhihu.com/p/582072317
  3. 深入浅出扩散模型(Diffusion Model)系列:基石DDPM(模型架构篇),最详细的DDPM架构图解(地址:https://zhuanlan.zhihu.com/p/637815071
  4. 一文读懂Transformer模型的位置编码(地址:https://zhuanlan.zhihu.com/p/637815071
  5. https://zhuanlan.zhihu.com/p/632809634


image.png

团队介绍


我们是淘天集团业务技术线的手猫营销&导购团队,专注于在手机天猫平台上探索创新商业化,我们依托淘天集团强大的互联网背景,致力于为手机天猫平台提供效率高、创新性强的技术支持。
我们的队员们来自各种营销和导购领域,拥有丰富的经验。通过不断地技术探索和商业创新,我们改善了用户的体验,并提升了平台的运营效率。
我们的团队持续不懈地探索和提升技术能力,坚持“技术领先、用户至上”,为手机天猫的导购场景和商业发展做出了显著贡献。


目录
相关文章
|
机器学习/深度学习 调度
详解 Diffusion (扩散) 模型
详解 Diffusion (扩散) 模型
1076 0
|
人工智能 资源调度 算法
AI 绘画Stable Diffusion 研究(八)sd采样方法详解
AI 绘画Stable Diffusion 研究(八)sd采样方法详解
2950 0
|
11月前
|
存储 安全 区块链
区块链在房地产交易中的应用:革新房产市场的未来
区块链在房地产交易中的应用:革新房产市场的未来
850 80
|
7月前
|
存储 机器学习/深度学习 缓存
vLLM 核心技术 PagedAttention 原理详解
本文系统梳理了 vLLM 核心技术 PagedAttention 的设计理念与实现机制。文章从 KV Cache 在推理中的关键作用与内存管理挑战切入,介绍了 vLLM 在请求调度、分布式执行及 GPU kernel 优化等方面的核心改进。PagedAttention 通过分页机制与动态映射,有效提升了显存利用率,使 vLLM 在保持低延迟的同时显著提升了吞吐能力。
3722 19
vLLM 核心技术 PagedAttention 原理详解
|
机器学习/深度学习 存储 人工智能
手动实现一个扩散模型DDPM(上)
手动实现一个扩散模型DDPM(上)
941 5
|
11月前
|
机器学习/深度学习 自然语言处理 搜索推荐
自注意力机制全解析:从原理到计算细节,一文尽览!
自注意力机制(Self-Attention)最早可追溯至20世纪70年代的神经网络研究,但直到2017年Google Brain团队提出Transformer架构后才广泛应用于深度学习。它通过计算序列内部元素间的相关性,捕捉复杂依赖关系,并支持并行化训练,显著提升了处理长文本和序列数据的能力。相比传统的RNN、LSTM和GRU,自注意力机制在自然语言处理(NLP)、计算机视觉、语音识别及推荐系统等领域展现出卓越性能。其核心步骤包括生成查询(Q)、键(K)和值(V)向量,计算缩放点积注意力得分,应用Softmax归一化,以及加权求和生成输出。自注意力机制提高了模型的表达能力,带来了更精准的服务。
12631 46
|
缓存 Shell iOS开发
修改 torch和huggingface 缓存路径
简介:本文介绍了如何修改 PyTorch 和 Huggingface Transformers 的缓存路径。通过设置环境变量 `TORCH_HOME` 和 `HF_HOME` 或 `TRANSFORMERS_CACHE`,可以在 Windows、Linux 和 MacOS 上指定自定义缓存目录。具体步骤包括设置环境变量、编辑 shell 配置文件、移动现有缓存文件以及创建符号链接(可选)。
3806 2
|
机器学习/深度学习 PyTorch 算法框架/工具
图像数据增强库综述:10个强大图像增强工具对比与分析
在深度学习和计算机视觉领域,数据增强是提升模型性能和泛化能力的关键技术。本文全面介绍了10个广泛使用的图像数据增强库,分析其特点和适用场景,帮助研究人员和开发者选择最适合需求的工具。这些库包括高性能的GPU加速解决方案(如Nvidia DALI)、灵活多功能的Albumentations和Imgaug,以及专注于特定框架的Kornia和Torchvision Transforms。通过详细比较各库的功能、特点和适用场景,本文为不同需求的用户提供丰富的选择,助力深度学习项目取得更好的效果。选择合适的数据增强库需考虑性能需求、任务类型、框架兼容性及易用性等因素。
1963 10
|
前端开发 JavaScript 搜索推荐
|
机器学习/深度学习 算法 PyTorch
使用Pytorch中从头实现去噪扩散概率模型(DDPM)
在本文中,我们将构建基础的无条件扩散模型,即去噪扩散概率模型(DDPM)。从探究算法的直观工作原理开始,然后在PyTorch中从头构建它。本文主要关注算法背后的思想和具体实现细节。
9276 3