使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成下

简介: 生成器的结构应与鉴别器相逆,因此生成器不再使用卷积操作,而是使用卷积的逆向操作,我们称之为转置卷积(transposed convolution)。

阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读。

相关资料链接为:使用PyTorch构建GAN生成对抗

本次训练代码使用了本地GPU计算。

文章的上篇讲解了数据集class和鉴别器class,下面将会继续建立生成器class,并完成鉴别器与生成器的对抗。


1 转置卷积


生成器的结构应与鉴别器相逆,因此生成器不再使用卷积操作,而是使用卷积的逆向操作,我们称之为转置卷积(transposed convolution)。

与普通的卷积不同,转置卷积中的卷积核是从输入图像的底部向上卷的,而不是从顶部向下卷。这是通过交换卷积核的行和列,并在每次卷积时对卷积核进行逆时针旋转180度来实现的。

转置卷积的步骤如下:


对卷积核进行行列交换,将其变为与输入图像矩阵相同的大小。

对卷积核进行逆时针旋转180度。

将卷积核应用于输入图像。

因此,转置卷积不同于普通卷积,它不需要对卷积核进行任何变换,因此不需要学习任何参数。此外,转置卷积的输出图像的大小一般大于输入图像的大小,因此可以用于图像放大等应用。




d9738bc27aaa4e329797723c74f24e39.png



图片来源:《Python生成对抗网络编程》


转置卷积用到的函数是nn.ConvTranspose2d。其中的参数有in_channels:输入通道数;out_channels:输出通道数;kernel_size:卷积核的大小;stride:卷积步长;padding:卷积边缘填充;output_padding:输出大小增加的填充;groups:卷积核分组数;bias:是否使用偏置。


2 生成器修改


此次代码中,生成器的总体结构为:输入 — 全连接层 — 转置卷积层1 — 转置卷积层2 — 转置卷积层3 — 输出

卷积层后没有接新的全连接层而直接输出,是为了保障将能够将局部特征直接生成最终图像。


class Generator(nn.Module):
    def __init__(self):
        # 修改后的网络结构
        self.model = nn.Sequential(
            # 输入是一个一维数组
            nn.Linear(100, 3 * 11 * 11),
            nn.LeakyReLU(0.2),
            # 转换成四维
            View((1, 3, 11, 11)),
            nn.ConvTranspose2d(3, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 3, kernel_size=8, stride=2, padding=1),
            nn.BatchNorm2d(3),
            nn.Sigmoid()
        )



结构讲解:


生成器的输入仍然保留为100个,使用全连接的方式,连接至(3*11*11)维度的张量,激活函数使用漏斗形非线性激活函数LeakReLu。

第一个转置卷积层接受了3个11*11的输入,卷积核大小为8,步长为2,不补全,则输出为28*28。激活函数同样使用LeakReLu。

转置卷积输出的计算方法:

输出 = 输入 + (输入 − 1 ) × (步长 − 1 ) + (卷积核 − 1 ) − 补全 × 2 输出=输入+(输入-1)×(步长-1)+(卷积核-1)-补全×2

输出=输入+(输入−1)×(步长−1)+(卷积核−1)−补全×2

上面的4部分代表了原始输入的大小、因步长而增加的大小、因卷积核而增加的大小(卷积核在边缘时为其增加的空白像素)、因补全而降低的大小(是否补全与补全在函数中定义,本处未定义默认不补全,此处补全的含义与词语含义相反)。

第二个转置卷积层接受了256个28*28的输入,卷积核大小为8,步长为2,不补全,则输出为62*62。激活函数再次使用LeakReLu。

第三个转置卷积层接受了256个28*28的输入,卷积核大小为8,这里卷积核个数恢复为3(如输出通道的RGB颜色相匹配),步长为2,不补全,则输出为128*128。使用常见的sigmoid函数进行输出。

3 生成器测试


在不进行训练的情况下,直接生成图片,已检查生成器class是否存在明显BUG。


G = Generator()
G.to(device)
output = G.forward(generate_random_seed(100))
img = output.detach().permute(0,2,3,1).view(128,128,3).cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()


输出如下:




20041c73edb84cd1bfa129b9ea6c7cec.png


因为增加了转置卷积,所以与单纯的随机生成图片不同,边缘颜色的色彩鲜明都要明显弱于中心。


4 训练模型


训练部分与普通的GAN模型相比不需任何修改,此处不再贴代码,有需要可以直接下载查看。

鉴别器和生成器在训练后期,均呈现稳态震荡。




7fae3a01e675448db1b837251b57941d.png2be0de4a07b643979a77c17bc61aaee8.png1e9080e9ff65466d808899dc9279d836.png



与非卷积相比,卷积GAN生成的图片具备明显的五官特征。但五官匹配并不完美,增加训练集与训练次数后在一定程度上可以解决此问题。


下面查看使用卷积后,内存消耗是否有优化


print(torch.cuda.memory_allocated(device) / (1024*1024*1024))
print(torch.cuda.max_memory_allocated(device) / (1024*1024*1024))


结果分别为:

2da9470e49f04e169bcb778c9104c61e.png



可以看出与之前超过1个G的内存占用相比,优化明显。

本案例的完整代码链接可点此下载或文末留言申请。

————————————————


目录
打赏
0
0
0
0
21
分享
相关文章
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
37 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
本文详细介绍了DeepSeek R1模型的构建过程,涵盖从基础模型选型到多阶段训练流程,再到关键技术如强化学习、拒绝采样和知识蒸馏的应用。
293 3
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
99 5
使用Pytorch构建视觉语言模型(VLM)
视觉语言模型(Vision Language Model,VLM)正在改变计算机对视觉和文本信息的理解与交互方式。本文将介绍 VLM 的核心组件和实现细节,可以让你全面掌握这项前沿技术。我们的目标是理解并实现能够通过指令微调来执行有用任务的视觉语言模型。
113 2
【NLP自然语言处理】基于PyTorch深度学习框架构建RNN经典案例:构建人名分类器
【NLP自然语言处理】基于PyTorch深度学习框架构建RNN经典案例:构建人名分类器
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
270 66
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
833 2
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
126 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
387 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力

热门文章

最新文章

AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等