使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

简介: 此项目使用的是著名的celebA(CelebFaces Attribute)数据集。其包含10,177个名人身份的202,599张人脸图片,每张图片都做好了特征标记,包含人脸bbox标注框、5个人脸特征点坐标以及40个属性标记,数据由香港中文大学开放提供(不包含商业用途的使用)。

1 数据集描述


此项目使用的是著名的celebA(CelebFaces Attribute)数据集。其包含10,177个名人身份的202,599张人脸图片,每张图片都做好了特征标记,包含人脸bbox标注框、5个人脸特征点坐标以及40个属性标记,数据由香港中文大学开放提供(不包含商业用途的使用)。

ac4268ff3ede44a5ae9aaf486e88fcaf.png



在实际训练前,已经将数据处理成了HDF5的数据集格式。使用h5py处理HDF5数据集可以提供很多方便,使得数据处理更加高效、灵活、可扩展,显著提升训练过程的文件读取速度。可以使用h5py包自行对数据进行处理,也可直接下载我已经处理好的HDF5数据格式。

如需了解更多h5py相关知识,可以查看HDF5补充知识。


2 GPU设置


前面几篇博客的内容,都是对手写数字这个数据集的处理,CPU还能吃得消。这次数据输入明显增加,需要使用GPU处理数据。如电脑无NAVIDIA独显,建议使用Google Colab执行代码,Colab提供了免费的GPU算力。


if torch.cuda.is_available():
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
  print("using cuda:", torch.cuda.get_device_name(0))
  pass
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


这段代码的作用是,如果当前设备有可用的CUDA,则将默认的张量类型设置为CUDA浮点张量并输出使用的CUDA设备的名称。然后,它将设备设置为CUDA设备(如果有)或CPU。


3 设置Dataset类


基于面向对象编程的基本原则,我们建立一个Dataset类,使类具有数据读取、获取指定索引的数据与绘制指定索引的图像,具体代码如下:


class CelebADataset(Dataset):
    def __init__(self, file):
        self.file_object = h5py.File(file, 'r')
        self.dataset = self.file_object['img_align_celeba']
        pass
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index):
        if index >= len(self.dataset):
            raise IndexError()
        img = numpy.array(self.dataset[str(index) + '.jpg'])
        return torch.cuda.FloatTensor(img) / 255.0
    def plot_image(self, index):
        plt.imshow(numpy.array(self.dataset[str(index) + '.jpg']), interpolation='nearest')
        plt.show()


在获取指定索引对应的数据时,如果指定数大于索引的最大值,我们命令程序返回一个IndexError()错误,以便于快速查找问题所在。

为了理解这一数据类,我们对类进行使用:

celeba_dataset = CelebADataset('文件地址.h5py')


这里创建了一个名为celeba_dataset 的CelebADataset类,并传入了文件的所在路径file。在__init__中,使用h5py.File方法读取路经所在的文件。


celeba_dataset.plot_image(66)

绘制数据集中66.jpg图形。如果前面代码正确,此处将绘制出数据集中的人脸头像。如果为绘制出图形并产生报错,考虑路径是否有误以及数据格式是否正确。4471d2bdd6464396b51cfcd0c11bd769.png

4 设置辨别器类


本项目的核心类为鉴别器类与生成器类,下面开始编写鉴别器类。首先建立神经网络框架:

class Discriminator(nn.Module):
    def __init__(self):
        # 父类继承
        super().__init__()
        # 神经网络定义
        self.model = nn.Sequential(
            View(218 * 178 * 3),
            nn.Linear(3 * 218 * 178, 100),
            nn.LeakyReLU(),
            nn.LayerNorm(100),
            nn.Linear(100, 1),
            nn.Sigmoid()
        )
        # 创建损失函数
        self.loss_function = nn.BCELoss()
        # 创建优化器
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
        # 初始化计数器
        self.counter = 0
        self.progress = []



这段代码定义了一个名为Discriminator的类,继承了PyTorch中nn.Module类。在__init__函数中,通过nn.Sequential定义了一个神经网络模型,包括三个线性层,两个激活函数,一个归一化层。一开始的View(218*178*3)是新代码。它的作用是将大小为(218, 178, 3) 的三维图像张量重塑成一个长度为218×178×3的一维张量。基于自上而下的编程习惯,我们会在后面对View进行定义。

在此基础上,定义了损失函数nn.BCELoss()和优化器Adam,并定义了一个计数器和一个存储进度的列表。


class Discriminator(nn.Module):
    def forward(self, inputs):
        # simply run model
        return self.model(inputs)
    def train(self, inputs, targets):
        # calculate the output of the network
        outputs = self.forward(inputs)
        # calculate loss
        loss = self.loss_function(outputs, targets)
        # increase counter and accumulate error every 10
        self.counter += 1
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
        if (self.counter % 1000 == 0):
            print("counter = ", self.counter)
        # 梯度归零,向后传递,优化执行
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))


接下来定义forward功能,train功能,plot_progress功能。在forward()函数中,它只是让模型对输入数据进行前向传播并返回网络的输出。在train()函数中,它使用输入数据和目标数据来计算网络的损失,并使用优化器来更新网络的参数。最后,plot_progress()函数可以用来绘制训练进度。以上类方法与手写字体识别博文中的定义完全相同,如有需要可找到对应博文查看。


5 辅助函数与辅助类


class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape, # 逗号不是多打的,代表这是元组
    def forward(self, x):
        return x.view(*self.shape)


在前面定义鉴别器类时,我们已经使用了View,此处对View进行补充定义。在 forward 方法中,它对输入的 x 应用了 view 方法,并将 shape 属性作为参数传入。这个模型的作用是将输入的张量的形状调整为 shape 属性所指定的形状。


def generate_random_image(size):
    random_data = torch.rand(size)
    return random_data
def generate_random_seed(size):
    random_data = torch.randn(size)
    return random_data


以上两个随机张量生成器,其作用与手写数字识别中的作用完全相同,在此不再赘述。后续在使用时也会再进行介绍。

截至目前,我们已经建立好了模型所必需的鉴别器类与Dataset类。下一篇会讲解最重要的鉴别器类以及对模型的训练与使用。


相关文章
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
145 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
17天前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
21天前
|
域名解析 运维 网络协议
网络诊断指南:网络故障排查步骤与技巧
网络诊断指南:网络故障排查步骤与技巧
82 7
|
1月前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:从漏洞到加密,保护数据的关键步骤
【10月更文挑战第24天】在数字化时代,网络安全和信息安全是维护个人隐私和企业资产的前线防线。本文将探讨网络安全中的常见漏洞、加密技术的重要性以及如何通过提高安全意识来防范潜在的网络威胁。我们将深入理解网络安全的基本概念,学习如何识别和应对安全威胁,并掌握保护信息不被非法访问的策略。无论你是IT专业人士还是日常互联网用户,这篇文章都将为你提供宝贵的知识和技能,帮助你在网络世界中更安全地航行。
|
1月前
|
数据采集 Java API
java怎么设置代理ip:简单步骤,实现高效网络请求
本文介绍了在Java中设置代理IP的方法,包括使用系统属性设置HTTP和HTTPS代理、在URL连接中设置代理、设置身份验证代理,以及使用第三方库如Apache HttpClient进行更复杂的代理配置。这些方法有助于提高网络请求的安全性和灵活性。
|
6月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
2月前
|
并行计算 开发工具 异构计算
在Windows平台使用源码编译和安装PyTorch3D指定版本
【10月更文挑战第6天】在 Windows 平台上,编译和安装指定版本的 PyTorch3D 需要先安装 Python、Visual Studio Build Tools 和 CUDA(如有需要),然后通过 Git 获取源码。建议创建虚拟环境以隔离依赖,并使用 `pip` 安装所需库。最后,在源码目录下运行 `python setup.py install` 进行编译和安装。完成后即可在 Python 中导入 PyTorch3D 使用。
280 0
|
4月前
|
SQL 安全 网络安全
网络安全漏洞与加密技术:提升安全意识的关键步骤
【8月更文挑战第31天】在数字时代的浪潮中,网络安全已成为保护个人信息和企业资产的前沿防线。本文将深入探讨网络安全的重要性,分析常见的安全漏洞,介绍加密技术如何加固这道防线,并强调提升个人和组织的安全意识的必要性。通过实例和代码演示,我们将一窥网络防御的艺术,旨在启发读者构建更安全的网络环境。
|
4月前
|
监控 安全 iOS开发
|
4月前
|
网络协议 网络架构
OSPF邻居关系建立失败?揭秘网络工程师面试中最常见的难题,这些关键步骤你掌握了吗?网络配置的陷阱就在这里!
【8月更文挑战第19天】OSPF是网络工程中确保数据高效传输的关键协议。但常遇难题:路由器间无法建立OSPF邻居关系,影响网络稳定并成为面试热点。解决此问题需检查网络连通性(如使用`ping`),确认OSPF区域配置一致(通过`show running-config`),校准Hello与Dead计时器(配置`hello`和`dead`命令),及核查IP地址和子网掩码正确无误(使用`ip address`)。系统排查上述因素可确保OSPF稳定运行。
85 2
下一篇
DataWorks