使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成下

简介: 生成器的结构应与鉴别器相逆,因此生成器不再使用卷积操作,而是使用卷积的逆向操作,我们称之为转置卷积(transposed convolution)。

阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读。

相关资料链接为:使用PyTorch构建GAN生成对抗

本次训练代码使用了本地GPU计算。

文章的上篇讲解了数据集class和鉴别器class,下面将会继续建立生成器class,并完成鉴别器与生成器的对抗。


1 转置卷积


生成器的结构应与鉴别器相逆,因此生成器不再使用卷积操作,而是使用卷积的逆向操作,我们称之为转置卷积(transposed convolution)。

与普通的卷积不同,转置卷积中的卷积核是从输入图像的底部向上卷的,而不是从顶部向下卷。这是通过交换卷积核的行和列,并在每次卷积时对卷积核进行逆时针旋转180度来实现的。

转置卷积的步骤如下:


对卷积核进行行列交换,将其变为与输入图像矩阵相同的大小。

对卷积核进行逆时针旋转180度。

将卷积核应用于输入图像。

因此,转置卷积不同于普通卷积,它不需要对卷积核进行任何变换,因此不需要学习任何参数。此外,转置卷积的输出图像的大小一般大于输入图像的大小,因此可以用于图像放大等应用。




d9738bc27aaa4e329797723c74f24e39.png



图片来源:《Python生成对抗网络编程》


转置卷积用到的函数是nn.ConvTranspose2d。其中的参数有in_channels:输入通道数;out_channels:输出通道数;kernel_size:卷积核的大小;stride:卷积步长;padding:卷积边缘填充;output_padding:输出大小增加的填充;groups:卷积核分组数;bias:是否使用偏置。


2 生成器修改


此次代码中,生成器的总体结构为:输入 — 全连接层 — 转置卷积层1 — 转置卷积层2 — 转置卷积层3 — 输出

卷积层后没有接新的全连接层而直接输出,是为了保障将能够将局部特征直接生成最终图像。


class Generator(nn.Module):
    def __init__(self):
        # 修改后的网络结构
        self.model = nn.Sequential(
            # 输入是一个一维数组
            nn.Linear(100, 3 * 11 * 11),
            nn.LeakyReLU(0.2),
            # 转换成四维
            View((1, 3, 11, 11)),
            nn.ConvTranspose2d(3, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 3, kernel_size=8, stride=2, padding=1),
            nn.BatchNorm2d(3),
            nn.Sigmoid()
        )



结构讲解:


生成器的输入仍然保留为100个,使用全连接的方式,连接至(3*11*11)维度的张量,激活函数使用漏斗形非线性激活函数LeakReLu。

第一个转置卷积层接受了3个11*11的输入,卷积核大小为8,步长为2,不补全,则输出为28*28。激活函数同样使用LeakReLu。

转置卷积输出的计算方法:

输出 = 输入 + (输入 − 1 ) × (步长 − 1 ) + (卷积核 − 1 ) − 补全 × 2 输出=输入+(输入-1)×(步长-1)+(卷积核-1)-补全×2

输出=输入+(输入−1)×(步长−1)+(卷积核−1)−补全×2

上面的4部分代表了原始输入的大小、因步长而增加的大小、因卷积核而增加的大小(卷积核在边缘时为其增加的空白像素)、因补全而降低的大小(是否补全与补全在函数中定义,本处未定义默认不补全,此处补全的含义与词语含义相反)。

第二个转置卷积层接受了256个28*28的输入,卷积核大小为8,步长为2,不补全,则输出为62*62。激活函数再次使用LeakReLu。

第三个转置卷积层接受了256个28*28的输入,卷积核大小为8,这里卷积核个数恢复为3(如输出通道的RGB颜色相匹配),步长为2,不补全,则输出为128*128。使用常见的sigmoid函数进行输出。

3 生成器测试


在不进行训练的情况下,直接生成图片,已检查生成器class是否存在明显BUG。


G = Generator()
G.to(device)
output = G.forward(generate_random_seed(100))
img = output.detach().permute(0,2,3,1).view(128,128,3).cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()


输出如下:




20041c73edb84cd1bfa129b9ea6c7cec.png


因为增加了转置卷积,所以与单纯的随机生成图片不同,边缘颜色的色彩鲜明都要明显弱于中心。


4 训练模型


训练部分与普通的GAN模型相比不需任何修改,此处不再贴代码,有需要可以直接下载查看。

鉴别器和生成器在训练后期,均呈现稳态震荡。




7fae3a01e675448db1b837251b57941d.png2be0de4a07b643979a77c17bc61aaee8.png1e9080e9ff65466d808899dc9279d836.png



与非卷积相比,卷积GAN生成的图片具备明显的五官特征。但五官匹配并不完美,增加训练集与训练次数后在一定程度上可以解决此问题。


下面查看使用卷积后,内存消耗是否有优化


print(torch.cuda.memory_allocated(device) / (1024*1024*1024))
print(torch.cuda.max_memory_allocated(device) / (1024*1024*1024))


结果分别为:

2da9470e49f04e169bcb778c9104c61e.png



可以看出与之前超过1个G的内存占用相比,优化明显。

本案例的完整代码链接可点此下载或文末留言申请。

————————————————


相关文章
|
6月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
446 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
7月前
|
机器学习/深度学习 算法 安全
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
本文详细介绍了DeepSeek R1模型的构建过程,涵盖从基础模型选型到多阶段训练流程,再到关键技术如强化学习、拒绝采样和知识蒸馏的应用。
667 3
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
|
10月前
|
并行计算 监控 搜索推荐
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
195 5
|
10月前
|
机器学习/深度学习 人工智能 PyTorch
使用Pytorch构建视觉语言模型(VLM)
视觉语言模型(Vision Language Model,VLM)正在改变计算机对视觉和文本信息的理解与交互方式。本文将介绍 VLM 的核心组件和实现细节,可以让你全面掌握这项前沿技术。我们的目标是理解并实现能够通过指令微调来执行有用任务的视觉语言模型。
260 2
|
11月前
|
机器学习/深度学习 数据采集 自然语言处理
【NLP自然语言处理】基于PyTorch深度学习框架构建RNN经典案例:构建人名分类器
【NLP自然语言处理】基于PyTorch深度学习框架构建RNN经典案例:构建人名分类器
|
8天前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
39 1
|
4月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
639 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
8天前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
43 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
1月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
115 9
|
3月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
130 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统

热门文章

最新文章

推荐镜像

更多