使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成下

简介: 生成器的结构应与鉴别器相逆,因此生成器不再使用卷积操作,而是使用卷积的逆向操作,我们称之为转置卷积(transposed convolution)。

阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读。

相关资料链接为:使用PyTorch构建GAN生成对抗

本次训练代码使用了本地GPU计算。

文章的上篇讲解了数据集class和鉴别器class,下面将会继续建立生成器class,并完成鉴别器与生成器的对抗。


1 转置卷积


生成器的结构应与鉴别器相逆,因此生成器不再使用卷积操作,而是使用卷积的逆向操作,我们称之为转置卷积(transposed convolution)。

与普通的卷积不同,转置卷积中的卷积核是从输入图像的底部向上卷的,而不是从顶部向下卷。这是通过交换卷积核的行和列,并在每次卷积时对卷积核进行逆时针旋转180度来实现的。

转置卷积的步骤如下:


对卷积核进行行列交换,将其变为与输入图像矩阵相同的大小。

对卷积核进行逆时针旋转180度。

将卷积核应用于输入图像。

因此,转置卷积不同于普通卷积,它不需要对卷积核进行任何变换,因此不需要学习任何参数。此外,转置卷积的输出图像的大小一般大于输入图像的大小,因此可以用于图像放大等应用。




d9738bc27aaa4e329797723c74f24e39.png



图片来源:《Python生成对抗网络编程》


转置卷积用到的函数是nn.ConvTranspose2d。其中的参数有in_channels:输入通道数;out_channels:输出通道数;kernel_size:卷积核的大小;stride:卷积步长;padding:卷积边缘填充;output_padding:输出大小增加的填充;groups:卷积核分组数;bias:是否使用偏置。


2 生成器修改


此次代码中,生成器的总体结构为:输入 — 全连接层 — 转置卷积层1 — 转置卷积层2 — 转置卷积层3 — 输出

卷积层后没有接新的全连接层而直接输出,是为了保障将能够将局部特征直接生成最终图像。


class Generator(nn.Module):
    def __init__(self):
        # 修改后的网络结构
        self.model = nn.Sequential(
            # 输入是一个一维数组
            nn.Linear(100, 3 * 11 * 11),
            nn.LeakyReLU(0.2),
            # 转换成四维
            View((1, 3, 11, 11)),
            nn.ConvTranspose2d(3, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 3, kernel_size=8, stride=2, padding=1),
            nn.BatchNorm2d(3),
            nn.Sigmoid()
        )



结构讲解:


生成器的输入仍然保留为100个,使用全连接的方式,连接至(3*11*11)维度的张量,激活函数使用漏斗形非线性激活函数LeakReLu。

第一个转置卷积层接受了3个11*11的输入,卷积核大小为8,步长为2,不补全,则输出为28*28。激活函数同样使用LeakReLu。

转置卷积输出的计算方法:

输出 = 输入 + (输入 − 1 ) × (步长 − 1 ) + (卷积核 − 1 ) − 补全 × 2 输出=输入+(输入-1)×(步长-1)+(卷积核-1)-补全×2

输出=输入+(输入−1)×(步长−1)+(卷积核−1)−补全×2

上面的4部分代表了原始输入的大小、因步长而增加的大小、因卷积核而增加的大小(卷积核在边缘时为其增加的空白像素)、因补全而降低的大小(是否补全与补全在函数中定义,本处未定义默认不补全,此处补全的含义与词语含义相反)。

第二个转置卷积层接受了256个28*28的输入,卷积核大小为8,步长为2,不补全,则输出为62*62。激活函数再次使用LeakReLu。

第三个转置卷积层接受了256个28*28的输入,卷积核大小为8,这里卷积核个数恢复为3(如输出通道的RGB颜色相匹配),步长为2,不补全,则输出为128*128。使用常见的sigmoid函数进行输出。

3 生成器测试


在不进行训练的情况下,直接生成图片,已检查生成器class是否存在明显BUG。


G = Generator()
G.to(device)
output = G.forward(generate_random_seed(100))
img = output.detach().permute(0,2,3,1).view(128,128,3).cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()


输出如下:




20041c73edb84cd1bfa129b9ea6c7cec.png


因为增加了转置卷积,所以与单纯的随机生成图片不同,边缘颜色的色彩鲜明都要明显弱于中心。


4 训练模型


训练部分与普通的GAN模型相比不需任何修改,此处不再贴代码,有需要可以直接下载查看。

鉴别器和生成器在训练后期,均呈现稳态震荡。




7fae3a01e675448db1b837251b57941d.png2be0de4a07b643979a77c17bc61aaee8.png1e9080e9ff65466d808899dc9279d836.png



与非卷积相比,卷积GAN生成的图片具备明显的五官特征。但五官匹配并不完美,增加训练集与训练次数后在一定程度上可以解决此问题。


下面查看使用卷积后,内存消耗是否有优化


print(torch.cuda.memory_allocated(device) / (1024*1024*1024))
print(torch.cuda.max_memory_allocated(device) / (1024*1024*1024))


结果分别为:

2da9470e49f04e169bcb778c9104c61e.png



可以看出与之前超过1个G的内存占用相比,优化明显。

本案例的完整代码链接可点此下载或文末留言申请。

————————————————


目录
打赏
0
0
0
0
22
分享
相关文章
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
本文详细介绍了DeepSeek R1模型的构建过程,涵盖从基础模型选型到多阶段训练流程,再到关键技术如强化学习、拒绝采样和知识蒸馏的应用。
546 3
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
148 5
使用Pytorch构建视觉语言模型(VLM)
视觉语言模型(Vision Language Model,VLM)正在改变计算机对视觉和文本信息的理解与交互方式。本文将介绍 VLM 的核心组件和实现细节,可以让你全面掌握这项前沿技术。我们的目标是理解并实现能够通过指令微调来执行有用任务的视觉语言模型。
193 2
|
10月前
|
使用PyTorch从零构建Llama 3
本文将详细指导如何从零开始构建完整的Llama 3模型架构,并在自定义数据集上执行训练和推理。
154 1
【NLP自然语言处理】基于PyTorch深度学习框架构建RNN经典案例:构建人名分类器
【NLP自然语言处理】基于PyTorch深度学习框架构建RNN经典案例:构建人名分类器
在Windows平台使用源码编译和安装PyTorch3D指定版本
【10月更文挑战第6天】在 Windows 平台上,编译和安装指定版本的 PyTorch3D 需要先安装 Python、Visual Studio Build Tools 和 CUDA(如有需要),然后通过 Git 获取源码。建议创建虚拟环境以隔离依赖,并使用 `pip` 安装所需库。最后,在源码目录下运行 `python setup.py install` 进行编译和安装。完成后即可在 Python 中导入 PyTorch3D 使用。
838 0
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
225 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
261 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
本文将深入探讨L1、L2和ElasticNet正则化技术,重点关注其在PyTorch框架中的具体实现。关于这些技术的理论基础,建议读者参考相关理论文献以获得更深入的理解。
57 4
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现

热门文章

最新文章

推荐镜像

更多
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等