使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

简介: 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类。

使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

接下来,我们测试一下鉴别器是否可以正常工作,并建立生成器。

1 测试鉴别器


# 数据类建立
celeba_dataset = CelebADataset(r'F:\学习\AI\对抗网络\face-data\celeba_aligned_small.h5py')
celeba_dataset.plot_image(66)
# 鉴别器类建立
D = Discriminator()
D.to(device)
for image_data_tensor in celeba_dataset:
    # real data
    D.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))
    # fake data
    D.train(generate_random_image((218,178,3)), torch.cuda.FloatTensor([0.0]))


此处我们调用了两个类,一个是celeba_dataset(Dataset)类,一个是D(Discriminator)类。两个类在博文的上篇中完成了定义。此处分别使用real数据与fake数据对模型进行训练。fake数据使用的是随机生成的不规则像素点,real数据使用的是真是人脸数据。

在使用GPU的情况,此处预计会消耗5分钟左右。

训练完成后,可以绘制损失值的变化以查看训练效果。


D.plot_progress()
plt.show()



6266aaa0f4874be58e92a5b79bf87de7.png



2 建立生成器


生成器与鉴别器高度类似,仅网络的结构和训练部分略有不同。

网格结构选取的是输入层为100个节点,中间层为单层结构,包含3*10*10个节点,输出层为3 * 218 * 178。输出层是完全根据照片的像素格式来确定的,输入层与中间层可以根据经验进行修改与优化。各层之间均采用全连接的连接方式。相关部分的代码如下:


class Generator(nn.Module):
    def __init__(self):
        # 父类继承
        super().__init__()
        # 定义神经网络
        self.model = nn.Sequential(
            nn.Linear(100, 3 * 10 * 10),
            nn.LeakyReLU(),
            nn.LayerNorm(3 * 10 * 10),
            nn.Linear(3 * 10 * 10, 3 * 218 * 178),
            nn.Sigmoid(),
            View((218, 178, 3))
        )


在进行损失计算时,我们将鉴别器的返回值作为实际输出,将torch.cuda.FloatTensor([1.0]作为目标输出,来计算损失。相关比分的代码如下:


class Generator(nn.Module):
    def train(self, D, inputs, targets):
        # 计算输出
        g_output = self.forward(inputs)
        # 将输出传至鉴别器
        d_output = D.forward(g_output)
        # 计算损失
        loss = D.loss_function(d_output, targets)


对于生成器的完整代码,也将在文末进行提供。


3 测试生成器


未经训练的生成器,应该具备生成类似雪花马赛克的随机图像能力。下面建立了一个生成器类,并用未经训练的生成器直接输出图像。


G = Generator()
G.to(device)
output = G.forward(generate_random_seed(100))
img = output.detach().cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()


如果代码运行正常,应得到类似下面的图象。



269efa8a3a9041689a420b4c798d846e.png




4 训练生成器


训练时,对数据集进行遍历,并且依次执行下面三步:


使用真实照片数据,对鉴别器进行训练,期望的鉴别器输出值为1;

使用生成器输出的fake数据,对鉴别器进行训练,期望的鉴别器输出值为0;

使用鉴别器的返回值,训练生成器,生成器所希望的鉴别器输出为1

具体代码如下:

for image_data_tensor in celeba_dataset:
    # train discriminator on true
    D.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))
    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random_seed(100)).detach(), torch.cuda.FloatTensor([0.0]))
    # train generator
    G.train(D, generate_random_seed(100), torch.cuda.FloatTensor([1.0]))


在训练后,可以分别查看鉴别器与生成器的损失变化曲线。

D.plot_progress()
G.plot_progress()


下图为鉴别器损失值变化曲线


fbeac81be0a045c49f55e66a72e01467.png


下图为生成器损失值变化曲22898c3939124c08a9982448021701f2.png


5 使用生成器




c253c56ab2aa45039715fc563bd1f02c.png


6 内存查看


最后可以查看一下本次训练的内存使用情况

(1)分配给张量的当前内存(输出单位是GB)


torch.cuda.memory_allocated(device) / (1024*1024*1024)


我的输出结果为:0.6999950408935547

(2)分配给张量的总内存(输出单位是GB)


torch.cuda.max_memory_allocated(device) / (1024*1024*1024)


我的输出结果为:0.962151050567627

(3)内存消耗汇总


print(torch.cuda.memory_summary(device, abbreviated=True))
1

输出结果如下:


|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  733998 KB |     985 MB |   14018 GB |   14017 GB |
|---------------------------------------------------------------------------|
| Active memory         |  733998 KB |     985 MB |   14018 GB |   14017 GB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    1086 MB |    1086 MB |    1086 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |    9426 KB |   12685 KB |  353393 MB |  353383 MB |
|---------------------------------------------------------------------------|
| Allocations           |      68    |      87    |    2580 K  |    2580 K  |
|---------------------------------------------------------------------------|
| Active allocs         |      68    |      87    |    2580 K  |    2580 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |      15    |      15    |      15    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      11    |      14    |    1410 K  |    1410 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|


相关实践学习
在云上部署ChatGLM2-6B大模型(GPU版)
ChatGLM2-6B是由智谱AI及清华KEG实验室于2023年6月发布的中英双语对话开源大模型。通过本实验,可以学习如何配置AIGC开发环境,如何部署ChatGLM2-6B大模型。
相关文章
|
6月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
446 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
6月前
|
网络协议 物联网
VB6网络通信软件上位机开发,TCP网络通信,读写数据并处理,完整源码下载
本文介绍使用VB6开发网络通信上位机客户端程序,涵盖Winsock控件的引入与使用,包括连接服务端、发送数据(如通过`Winsock1.SendData`方法)及接收数据(利用`Winsock1_DataArrival`事件)。代码实现TCP网络通信,可读写并处理16进制数据,适用于自动化和工业控制领域。提供完整源码下载,适合学习VB6网络程序开发。 下载链接:[完整源码](http://xzios.cn:86/WJGL/DownLoadDetial?Id=20)
231 12
|
6月前
|
前端开发 Java 关系型数据库
基于ssm的网络直播带货管理系统,附源码+数据库+论文
该项目为网络直播带货网站,包含管理员和用户两个角色。管理员可进行主页、个人中心、用户管理、商品分类与信息管理、系统及订单管理;用户可浏览主页、管理个人中心、收藏和订单。系统基于Java开发,采用B/S架构,前端使用Vue、JSP等技术,后端为SSM框架,数据库为MySQL。项目运行环境为Windows,支持JDK8、Tomcat8.5。提供演示视频和详细文档截图。
156 10
|
6月前
|
JavaScript 算法 前端开发
JS数组操作方法全景图,全网最全构建完整知识网络!js数组操作方法全集(实现筛选转换、随机排序洗牌算法、复杂数据处理统计等情景详解,附大量源码和易错点解析)
这些方法提供了对数组的全面操作,包括搜索、遍历、转换和聚合等。通过分为原地操作方法、非原地操作方法和其他方法便于您理解和记忆,并熟悉他们各自的使用方法与使用范围。详细的案例与进阶使用,方便您理解数组操作的底层原理。链式调用的几个案例,让您玩转数组操作。 只有锻炼思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~
|
存储 Java Unix
(八)Java网络编程之IO模型篇-内核Select、Poll、Epoll多路复用函数源码深度历险!
select/poll、epoll这些词汇相信诸位都不陌生,因为在Redis/Nginx/Netty等一些高性能技术栈的底层原理中,大家应该都见过它们的身影,接下来重点讲解这块内容。
281 0
|
JavaScript Java 测试技术
基于SpringBoot+Vue+uniapp的网络安全科普系统的详细设计和实现(源码+lw+部署文档+讲解等)
基于SpringBoot+Vue+uniapp的网络安全科普系统的详细设计和实现(源码+lw+部署文档+讲解等)
155 0
|
8天前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
39 1
|
4月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
639 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
8天前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
43 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
1月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
115 9

推荐镜像

更多