使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

简介: 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类。

使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

接下来,我们测试一下鉴别器是否可以正常工作,并建立生成器。

1 测试鉴别器


# 数据类建立
celeba_dataset = CelebADataset(r'F:\学习\AI\对抗网络\face-data\celeba_aligned_small.h5py')
celeba_dataset.plot_image(66)
# 鉴别器类建立
D = Discriminator()
D.to(device)
for image_data_tensor in celeba_dataset:
    # real data
    D.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))
    # fake data
    D.train(generate_random_image((218,178,3)), torch.cuda.FloatTensor([0.0]))


此处我们调用了两个类,一个是celeba_dataset(Dataset)类,一个是D(Discriminator)类。两个类在博文的上篇中完成了定义。此处分别使用real数据与fake数据对模型进行训练。fake数据使用的是随机生成的不规则像素点,real数据使用的是真是人脸数据。

在使用GPU的情况,此处预计会消耗5分钟左右。

训练完成后,可以绘制损失值的变化以查看训练效果。


D.plot_progress()
plt.show()



6266aaa0f4874be58e92a5b79bf87de7.png



2 建立生成器


生成器与鉴别器高度类似,仅网络的结构和训练部分略有不同。

网格结构选取的是输入层为100个节点,中间层为单层结构,包含3*10*10个节点,输出层为3 * 218 * 178。输出层是完全根据照片的像素格式来确定的,输入层与中间层可以根据经验进行修改与优化。各层之间均采用全连接的连接方式。相关部分的代码如下:


class Generator(nn.Module):
    def __init__(self):
        # 父类继承
        super().__init__()
        # 定义神经网络
        self.model = nn.Sequential(
            nn.Linear(100, 3 * 10 * 10),
            nn.LeakyReLU(),
            nn.LayerNorm(3 * 10 * 10),
            nn.Linear(3 * 10 * 10, 3 * 218 * 178),
            nn.Sigmoid(),
            View((218, 178, 3))
        )


在进行损失计算时,我们将鉴别器的返回值作为实际输出,将torch.cuda.FloatTensor([1.0]作为目标输出,来计算损失。相关比分的代码如下:


class Generator(nn.Module):
    def train(self, D, inputs, targets):
        # 计算输出
        g_output = self.forward(inputs)
        # 将输出传至鉴别器
        d_output = D.forward(g_output)
        # 计算损失
        loss = D.loss_function(d_output, targets)


对于生成器的完整代码,也将在文末进行提供。


3 测试生成器


未经训练的生成器,应该具备生成类似雪花马赛克的随机图像能力。下面建立了一个生成器类,并用未经训练的生成器直接输出图像。


G = Generator()
G.to(device)
output = G.forward(generate_random_seed(100))
img = output.detach().cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()


如果代码运行正常,应得到类似下面的图象。



269efa8a3a9041689a420b4c798d846e.png




4 训练生成器


训练时,对数据集进行遍历,并且依次执行下面三步:


使用真实照片数据,对鉴别器进行训练,期望的鉴别器输出值为1;

使用生成器输出的fake数据,对鉴别器进行训练,期望的鉴别器输出值为0;

使用鉴别器的返回值,训练生成器,生成器所希望的鉴别器输出为1

具体代码如下:

for image_data_tensor in celeba_dataset:
    # train discriminator on true
    D.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))
    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random_seed(100)).detach(), torch.cuda.FloatTensor([0.0]))
    # train generator
    G.train(D, generate_random_seed(100), torch.cuda.FloatTensor([1.0]))


在训练后,可以分别查看鉴别器与生成器的损失变化曲线。

D.plot_progress()
G.plot_progress()


下图为鉴别器损失值变化曲线


fbeac81be0a045c49f55e66a72e01467.png


下图为生成器损失值变化曲22898c3939124c08a9982448021701f2.png


5 使用生成器




c253c56ab2aa45039715fc563bd1f02c.png


6 内存查看


最后可以查看一下本次训练的内存使用情况

(1)分配给张量的当前内存(输出单位是GB)


torch.cuda.memory_allocated(device) / (1024*1024*1024)


我的输出结果为:0.6999950408935547

(2)分配给张量的总内存(输出单位是GB)


torch.cuda.max_memory_allocated(device) / (1024*1024*1024)


我的输出结果为:0.962151050567627

(3)内存消耗汇总


print(torch.cuda.memory_summary(device, abbreviated=True))
1

输出结果如下:


|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  733998 KB |     985 MB |   14018 GB |   14017 GB |
|---------------------------------------------------------------------------|
| Active memory         |  733998 KB |     985 MB |   14018 GB |   14017 GB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    1086 MB |    1086 MB |    1086 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |    9426 KB |   12685 KB |  353393 MB |  353383 MB |
|---------------------------------------------------------------------------|
| Allocations           |      68    |      87    |    2580 K  |    2580 K  |
|---------------------------------------------------------------------------|
| Active allocs         |      68    |      87    |    2580 K  |    2580 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |      15    |      15    |      15    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      11    |      14    |    1410 K  |    1410 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|


相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
145 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 中的动态计算图:实现灵活的神经网络架构
【8月更文第27天】PyTorch 是一款流行的深度学习框架,它以其灵活性和易用性而闻名。与 TensorFlow 等其他框架相比,PyTorch 最大的特点之一是支持动态计算图。这意味着开发者可以在运行时定义网络结构,这为构建复杂的模型提供了极大的便利。本文将深入探讨 PyTorch 中动态计算图的工作原理,并通过一些示例代码展示如何利用这一特性来构建灵活的神经网络架构。
338 1
|
3天前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本文详细介绍了如何在昇腾平台上使用PyTorch实现GraphSage算法,在CiteSeer数据集上进行图神经网络的分类训练。内容涵盖GraphSage的创新点、算法原理、网络架构及实战代码分析,通过采样和聚合方法高效处理大规模图数据。实验结果显示,模型在CiteSeer数据集上的分类准确率达到66.5%。
|
21天前
|
域名解析 运维 网络协议
网络诊断指南:网络故障排查步骤与技巧
网络诊断指南:网络故障排查步骤与技巧
81 7
|
1月前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:从漏洞到加密,保护数据的关键步骤
【10月更文挑战第24天】在数字化时代,网络安全和信息安全是维护个人隐私和企业资产的前线防线。本文将探讨网络安全中的常见漏洞、加密技术的重要性以及如何通过提高安全意识来防范潜在的网络威胁。我们将深入理解网络安全的基本概念,学习如何识别和应对安全威胁,并掌握保护信息不被非法访问的策略。无论你是IT专业人士还是日常互联网用户,这篇文章都将为你提供宝贵的知识和技能,帮助你在网络世界中更安全地航行。
|
1月前
|
数据采集 Java API
java怎么设置代理ip:简单步骤,实现高效网络请求
本文介绍了在Java中设置代理IP的方法,包括使用系统属性设置HTTP和HTTPS代理、在URL连接中设置代理、设置身份验证代理,以及使用第三方库如Apache HttpClient进行更复杂的代理配置。这些方法有助于提高网络请求的安全性和灵活性。
|
4月前
|
机器学习/深度学习 人工智能 PyTorch
【深度学习】使用PyTorch构建神经网络:深度学习实战指南
PyTorch是一个开源的Python机器学习库,特别专注于深度学习领域。它由Facebook的AI研究团队开发并维护,因其灵活的架构、动态计算图以及在科研和工业界的广泛支持而受到青睐。PyTorch提供了强大的GPU加速能力,使得在处理大规模数据集和复杂模型时效率极高。
203 59
|
3月前
|
机器学习/深度学习
小土堆-pytorch-神经网络-损失函数与反向传播_笔记
在使用损失函数时,关键在于匹配输入和输出形状。例如,在L1Loss中,输入形状中的N代表批量大小。以下是具体示例:对于相同形状的输入和目标张量,L1Loss默认计算差值并求平均;此外,均方误差(MSE)也是常用损失函数。实战中,损失函数用于计算模型输出与真实标签间的差距,并通过反向传播更新模型参数。
|
4月前
|
SQL 安全 网络安全
网络安全漏洞与加密技术:提升安全意识的关键步骤
【8月更文挑战第31天】在数字时代的浪潮中,网络安全已成为保护个人信息和企业资产的前沿防线。本文将深入探讨网络安全的重要性,分析常见的安全漏洞,介绍加密技术如何加固这道防线,并强调提升个人和组织的安全意识的必要性。通过实例和代码演示,我们将一窥网络防御的艺术,旨在启发读者构建更安全的网络环境。
|
4月前
|
监控 安全 iOS开发

热门文章

最新文章

下一篇
DataWorks