使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上

简介: 使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上

阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读。

相关资料链接为:使用PyTorch构建GAN生成对抗

本次训练代码使用了本地GPU计算。


1 CelebADataset类的修改


原则上这一类不需要修改,但为了提升模型运行速度,可以对图片周边适当裁剪,保留五官等重要内容。


# 设置裁剪功能(辅助函数)
def crop_centre(img, new_width, new_height):
    height, width, _ = img.shape
    startx = width//2 - new_width//2
    starty = height//2 - new_height//2
    return img[  starty:starty + new_height, startx:startx + new_width, :]


上面这个函数可以用来从图像的中心裁剪。该函数接收三个参数:


img:原始图像,需要是 numpy 数组形式

new_width:裁剪后图像的新宽度

new_height:裁剪后图像的新高度

该函数通过计算原始图像的中心位置,以及所需裁剪图像的起始位置,从而在 numpy 数组上实现裁剪。最后,函数返回裁剪后的图像。

有了这个函数后,可以在类中预置对图像的裁剪功能,需要对类的__getitem__方法和plot_image方法进行优化。


class CelebADataset(Dataset):
    def __getitem__(self, index):
        if index >= len(self.dataset):
            raise IndexError()
        img = numpy.array(self.dataset[str(index) + '.jpg'])
        img = crop_centre(img, 128, 128)
        return torch.cuda.FloatTensor(img).permute(2,0,1).view(1,3,128,128) / 255.0
    def plot_image(self, index):
        img = numpy.array(self.dataset[str(index)+'.jpg'])
        img = crop_centre(img, 128, 128)
        plt.imshow(img, interpolation='nearest')


2 鉴别器类的修改


鉴别器的网络结构是卷积GAN需要重点修改的地方。此次的卷积GAN设置了3个卷积层和1个全连接层。


class Discriminator(nn.Module):
    def __init__(self): 
  self.model = nn.Sequential(
      nn.Conv2d(3, 256, kernel_size=8, stride=2),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2),
      nn.Conv2d(256, 256, kernel_size=8, stride=2),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2),
      nn.Conv2d(256, 3, kernel_size=8, stride=2),
      nn.LeakyReLU(0.2),
      View(3*10*10),
      nn.Linear(3*10*10, 1),
      nn.Sigmoid()
  )



经过裁剪的图片的小为128*128;

第一个卷积层使用了256个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出256个特征图,特征图的大小为 128 − 8 2 + 1 \frac{128-8}{2}+1 2128−8 +1 ,即61*61;

第二个卷积层使用了256个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出256个特征图,特征图的大小为 61 − 8 2 + 1 \frac{61-8}{2}+1 261−8 +1 ,即27*27;

第二个卷积层使用了3个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出3个特征图,特征图的大小为 27 − 8 2 + 1 \frac{27-8}{2}+1 227−8 +1 ,即10*10;

经过了3层的卷积后,图片的大小已经降到了(3*10*10)。


3 鉴别器测试


修改完鉴别器之后,可以使用真实图像和随即图像,初步判断鉴别器的能力与测试这部分修改后的代码是否存在BUG。


# 鉴别器类建立
D = Discriminator()
D.to(device)
# 测试鉴别器
for image_data_tensor in celeba_dataset:
    # real data
    D.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))
    # fake data
    D.train(generate_random_image((1,3,128,128)), torch.cuda.FloatTensor([0.0]))
    pass


同样,可以查看损失函数的变化情况并使用测试集进行测试。

for image_data_tensor in celeba_dataset:
    # real data
    D.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))
    # fake data
    D.train(generate_random_image((1,3,128,128)), torch.cuda.FloatTensor([0.0]))
    pass
D.plot_progress()
for i in range(4):
  image_data_tensor = celeba_dataset[random.randint(0,20000)]
  print( D.forward( image_data_tensor ).item() )
  pass
for i in range(4):
  print( D.forward( generate_random_image((1,3,128,128))).item() )
  pass


5db66f404c8b4cb7b3d7739c32b2d2c0.png

可以看出,鉴别器对于数据的判断非常有信息。


之后还需对生成器进行同步修改,并使用代码生成图像,这部分内容将放在下篇。


相关文章
|
21天前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
2月前
|
并行计算 开发工具 异构计算
在Windows平台使用源码编译和安装PyTorch3D指定版本
【10月更文挑战第6天】在 Windows 平台上,编译和安装指定版本的 PyTorch3D 需要先安装 Python、Visual Studio Build Tools 和 CUDA(如有需要),然后通过 Git 获取源码。建议创建虚拟环境以隔离依赖,并使用 `pip` 安装所需库。最后,在源码目录下运行 `python setup.py install` 进行编译和安装。完成后即可在 Python 中导入 PyTorch3D 使用。
299 0
|
6月前
|
机器学习/深度学习 并行计算 PyTorch
安装PyTorch详细步骤
安装PyTorch时,选择CPU或GPU版本。有Nvidia显卡需装CUDA和cuDNN,可从NVIDIA官网下载CUDA 11.8和对应版本cuDNN。无Nvidia显卡则安装CPU版。安装PyTorch通过conda或pip,GPU版指定`cu118`或`rocm5.4.2`镜像源。验证安装成功使用`torch._version_`和`torch.cuda.is_available()`。
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
|
6月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
6月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
2月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
388 2
|
23天前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
42 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
2月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
76 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
2月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
131 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力