使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

简介: 此项目使用的是著名的celebA(CelebFaces Attribute)数据集。其包含10,177个名人身份的202,599张人脸图片,每张图片都做好了特征标记,包含人脸bbox标注框、5个人脸特征点坐标以及40个属性标记,数据由香港中文大学开放提供(不包含商业用途的使用)。

1 数据集描述


此项目使用的是著名的celebA(CelebFaces Attribute)数据集。其包含10,177个名人身份的202,599张人脸图片,每张图片都做好了特征标记,包含人脸bbox标注框、5个人脸特征点坐标以及40个属性标记,数据由香港中文大学开放提供(不包含商业用途的使用)。

ac4268ff3ede44a5ae9aaf486e88fcaf.png



在实际训练前,已经将数据处理成了HDF5的数据集格式。使用h5py处理HDF5数据集可以提供很多方便,使得数据处理更加高效、灵活、可扩展,显著提升训练过程的文件读取速度。可以使用h5py包自行对数据进行处理,也可直接下载我已经处理好的HDF5数据格式。

如需了解更多h5py相关知识,可以查看HDF5补充知识。


2 GPU设置


前面几篇博客的内容,都是对手写数字这个数据集的处理,CPU还能吃得消。这次数据输入明显增加,需要使用GPU处理数据。如电脑无NAVIDIA独显,建议使用Google Colab执行代码,Colab提供了免费的GPU算力。


if torch.cuda.is_available():
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
  print("using cuda:", torch.cuda.get_device_name(0))
  pass
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


这段代码的作用是,如果当前设备有可用的CUDA,则将默认的张量类型设置为CUDA浮点张量并输出使用的CUDA设备的名称。然后,它将设备设置为CUDA设备(如果有)或CPU。


3 设置Dataset类


基于面向对象编程的基本原则,我们建立一个Dataset类,使类具有数据读取、获取指定索引的数据与绘制指定索引的图像,具体代码如下:


class CelebADataset(Dataset):
    def __init__(self, file):
        self.file_object = h5py.File(file, 'r')
        self.dataset = self.file_object['img_align_celeba']
        pass
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index):
        if index >= len(self.dataset):
            raise IndexError()
        img = numpy.array(self.dataset[str(index) + '.jpg'])
        return torch.cuda.FloatTensor(img) / 255.0
    def plot_image(self, index):
        plt.imshow(numpy.array(self.dataset[str(index) + '.jpg']), interpolation='nearest')
        plt.show()


在获取指定索引对应的数据时,如果指定数大于索引的最大值,我们命令程序返回一个IndexError()错误,以便于快速查找问题所在。

为了理解这一数据类,我们对类进行使用:

celeba_dataset = CelebADataset('文件地址.h5py')


这里创建了一个名为celeba_dataset 的CelebADataset类,并传入了文件的所在路径file。在__init__中,使用h5py.File方法读取路经所在的文件。


celeba_dataset.plot_image(66)

绘制数据集中66.jpg图形。如果前面代码正确,此处将绘制出数据集中的人脸头像。如果为绘制出图形并产生报错,考虑路径是否有误以及数据格式是否正确。4471d2bdd6464396b51cfcd0c11bd769.png

4 设置辨别器类


本项目的核心类为鉴别器类与生成器类,下面开始编写鉴别器类。首先建立神经网络框架:

class Discriminator(nn.Module):
    def __init__(self):
        # 父类继承
        super().__init__()
        # 神经网络定义
        self.model = nn.Sequential(
            View(218 * 178 * 3),
            nn.Linear(3 * 218 * 178, 100),
            nn.LeakyReLU(),
            nn.LayerNorm(100),
            nn.Linear(100, 1),
            nn.Sigmoid()
        )
        # 创建损失函数
        self.loss_function = nn.BCELoss()
        # 创建优化器
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
        # 初始化计数器
        self.counter = 0
        self.progress = []



这段代码定义了一个名为Discriminator的类,继承了PyTorch中nn.Module类。在__init__函数中,通过nn.Sequential定义了一个神经网络模型,包括三个线性层,两个激活函数,一个归一化层。一开始的View(218*178*3)是新代码。它的作用是将大小为(218, 178, 3) 的三维图像张量重塑成一个长度为218×178×3的一维张量。基于自上而下的编程习惯,我们会在后面对View进行定义。

在此基础上,定义了损失函数nn.BCELoss()和优化器Adam,并定义了一个计数器和一个存储进度的列表。


class Discriminator(nn.Module):
    def forward(self, inputs):
        # simply run model
        return self.model(inputs)
    def train(self, inputs, targets):
        # calculate the output of the network
        outputs = self.forward(inputs)
        # calculate loss
        loss = self.loss_function(outputs, targets)
        # increase counter and accumulate error every 10
        self.counter += 1
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
        if (self.counter % 1000 == 0):
            print("counter = ", self.counter)
        # 梯度归零,向后传递,优化执行
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))


接下来定义forward功能,train功能,plot_progress功能。在forward()函数中,它只是让模型对输入数据进行前向传播并返回网络的输出。在train()函数中,它使用输入数据和目标数据来计算网络的损失,并使用优化器来更新网络的参数。最后,plot_progress()函数可以用来绘制训练进度。以上类方法与手写字体识别博文中的定义完全相同,如有需要可找到对应博文查看。


5 辅助函数与辅助类


class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape, # 逗号不是多打的,代表这是元组
    def forward(self, x):
        return x.view(*self.shape)


在前面定义鉴别器类时,我们已经使用了View,此处对View进行补充定义。在 forward 方法中,它对输入的 x 应用了 view 方法,并将 shape 属性作为参数传入。这个模型的作用是将输入的张量的形状调整为 shape 属性所指定的形状。


def generate_random_image(size):
    random_data = torch.rand(size)
    return random_data
def generate_random_seed(size):
    random_data = torch.randn(size)
    return random_data


以上两个随机张量生成器,其作用与手写数字识别中的作用完全相同,在此不再赘述。后续在使用时也会再进行介绍。

截至目前,我们已经建立好了模型所必需的鉴别器类与Dataset类。下一篇会讲解最重要的鉴别器类以及对模型的训练与使用。


目录
打赏
0
0
0
0
21
分享
相关文章
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
50 1
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
本文详细介绍了DeepSeek R1模型的构建过程,涵盖从基础模型选型到多阶段训练流程,再到关键技术如强化学习、拒绝采样和知识蒸馏的应用。
221 3
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
RT-DETR改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR LSKNet (附网络详解和完整配置步骤)
RT-DETR改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR LSKNet (附网络详解和完整配置步骤)
68 13
RT-DETR改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR LSKNet (附网络详解和完整配置步骤)
RT-DETR改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR ConvNeXt V2 (附网络详解和完整配置步骤)
RT-DETR改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR ConvNeXt V2 (附网络详解和完整配置步骤)
74 11
RT-DETR改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR ConvNeXt V2 (附网络详解和完整配置步骤)
YOLOv11改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
YOLOv11改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
90 13
YOLOv11改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
YOLOv11改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR LSKNet (附网络详解和完整配置步骤)
YOLOv11改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR LSKNet (附网络详解和完整配置步骤)
153 0
YOLOv11改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR LSKNet (附网络详解和完整配置步骤)
YOLOv11改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR ConvNeXt V2 (附网络详解和完整配置步骤)
YOLOv11改进策略【Backbone/主干网络】| 替换骨干网络为2023-CVPR ConvNeXt V2 (附网络详解和完整配置步骤)
113 0
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
网络诊断指南:网络故障排查步骤与技巧
网络诊断指南:网络故障排查步骤与技巧
911 7
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
92 5

热门文章

最新文章