Pytorch和DCGAN生成肖像画(下)

简介: Pytorch和DCGAN生成肖像画

生成器

我们的生成器模型将完全遵循DCGAN的完全相同的体系结构。

640.png

与生成器类似,但是相反,我们不是将图像卷积为一组特征,而是采用随机噪声输入“ z”并运行一系列反卷积以达到所需的图像形状(3,64,64 )。

classGenerator(nn.Module):
def__init__(self, channels_noise, channels_img, features_g):
super(Generator, self).__init__()
self.net=nn.Sequential(
#Input: Nxchannels_noisex1x1self._block(channels_noise, features_g*16, 4, 1, 0), #img: 4x4self._block(features_g*16, features_g*8, 4, 2, 1), #img: 8x8self._block(features_g*8, features_g*4, 4, 2, 1), #img: 16x16self._block(features_g*4, features_g*2, 4, 2, 1), #img: 32x32nn.ConvTranspose2d(
features_g*2, channels_img, kernel_size=4, stride=2, padding=1            ),
#Output: Nxchannels_imgx64x64nn.Tanh(),
        )
def_block(self, in_channels, out_channels, kernel_size, stride, padding):
returnnn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
            ),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
        )
defforward(self, x):
returnself.net(x)

最后,我们只需要在训练之前将initialize_weights()函数添加到两个模型即可:

definitialize_weights(model):
#InitializesweightsaccordingtotheDCGANpaperforminmodel.modules():
ifisinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
nn.init.normal_(m.weight.data, 0.0, 0.02)

训练

现在我们有了数据和模型,我们可以从train.py文件开始,在这里我们将加载数据,初始化模型并运行主训练循环。

超参数和导入

importnumpyasnpimporttorchimporttorch.nnasnnimporttorch.optimasoptimimporttorchvisionimporttorchvision.datasetsasdatasetsimporttorchvision.transformsastransformsfromtorch.utils.dataimportDatasetfromtorch.utils.dataimportDataLoaderfromtorch.utils.tensorboardimportSummaryWriterfrommodelimportDiscriminator, Generator, initialize_weightsimportrandomimportosimportnatsortfromPILimportImage, ImageOps, ImageEnhancedevice=torch.device("cuda"iftorch.cuda.is_available() else"cpu")
D_LEARNING_RATE=2e-4G_LEARNING_RATE=1e-4BATCH_SIZE=64IMAGE_SIZE=64CHANNELS_IMG=3NOISE_DIM=128NUM_EPOCHS=100FEATURES_DISC=64FEATURES_GEN=64

两个优化器可以使用相同的学习率,但是我发现对鉴别器使用稍高的学习率被证明是更有效的。在更复杂的数据集上,我发现较小的批次大小(例如16或8)可以帮助避免过度拟合。

随机增强

改善GAN训练并从数据集中获得最大收益的技术之一是应用随机图像增强。在原始论文中,它们还提供了一种在生成器端还原增强图像的机制,因为我们不希望生成器生成增强图像。但是,在这种情况下,我认为应用这些简单的增幅就足够了,而这些增幅并不会真正影响画质。如果我们试图获得照片般逼真的结果,那么使用它来进行全面实施可能是一个更好的主意。

defrandom_augmentation(img):
#randommirroring/flippingimagerand_mirror=random.randint(0,1)
#randomsaturationadjustmentrand_sat=random.uniform(0.5,1.5)
#randomsharpnessadjustmentrand_sharp=random.uniform(0.5,1.5)
converter=ImageEnhance.Color(img)
img=converter.enhance(rand_sat)
converter=ImageEnhance.Sharpness(img)
img=converter.enhance(rand_sharp)
ifrand_mirror==0:
img=ImageOps.mirror(img)
returnimg

我们对图像执行3个操作-镜像,饱和度调整和锐度调整。镜像图像对我们的图像质量没有影响,因为我们只是在翻转图像。对于饱和度和清晰度,我使用了一个较小的系数范围(0.5、1.5),以免对原始图像造成很大的影响。

数据加载器

为了应用我们之前构建的随机增强方法并加载数据,我编写了一个使用其下定义的转换的自定义数据集。

classCustomDataSet(Dataset):
def__init__(self, main_dir, transform):
self.main_dir=main_dirself.transform=transformall_imgs=os.listdir(main_dir)
self.total_imgs=natsort.natsorted(all_imgs)
def__len__(self):
returnlen(self.total_imgs)
def__getitem__(self, idx):
img_loc=os.path.join(self.main_dir, self.total_imgs[idx])
image=Image.open(img_loc).convert('RGB')
image=random_augmentation(image)
tensor_image=self.transform(image)
returntensor_imagetransforms=transforms.Compose(
    [
transforms.Resize((IMAGE_SIZE,IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(
            [0.5for_inrange(CHANNELS_IMG)], [0.5for_inrange(CHANNELS_IMG)]
        ),
    ]
)
dataset=CustomDataSet("./data/128_portraits/", transform=transforms)
dataloader=DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

训练

最后,我们可以初始化我们的网络并开始对其进行训练。对于鉴别器训练,我使用均方误差作为损失函数。我也尝试使用二进制交叉熵,但MSELoss最有效。在训练循环之前,我们还初始化张量板编写器以在tensorboard上实时查看我们的图像。

gen=Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc=Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)
###uncommenttoworkfromsavedmodels####gen.load_state_dict(torch.load('saved_models/generator_model.pt'))
#disc.load_state_dict(torch.load('saved_models/discriminator_model.pt'))
opt_gen=optim.Adam(gen.parameters(), lr=G_LEARNING_RATE, betas=(0.5, 0.99))
opt_disc=optim.Adam(disc.parameters(), lr=D_LEARNING_RATE, betas=(0.5, 0.99))
criterion=nn.MSELoss()
fixed_noise=torch.randn(16, NOISE_DIM, 1, 1).to(device)
writer_real=SummaryWriter(f"logs/real")
writer_fake=SummaryWriter(f"logs/fake")
step=0gen.train()
disc.train()
forepochinrange(NUM_EPOCHS):
forbatch_idx, realinenumerate(dataloader):
real=real.to(device)
noise=torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
fake=gen(noise)
###TrainDiscriminatordisc_real=disc(real).reshape(-1)
loss_disc_real=criterion(disc_real, torch.ones_like(disc_real))
disc_fake=disc(fake.detach()).reshape(-1)
loss_disc_fake=criterion(disc_fake, torch.zeros_like(disc_fake))
loss_disc= (loss_disc_real+loss_disc_fake) /2disc.zero_grad()
loss_disc.backward()
opt_disc.step()
###TrainGeneratorusingfeaturematchingoutput=disc(fake).reshape(-1)
loss_gen=criterion(output, torch.ones_like(output))
gen.zero_grad()
loss_gen.backward()
opt_gen.step()
#Printlossesoccasionallyandprinttotensorboardifbatch_idx%10==0:                
torch.save(gen.state_dict(), 'generator_model.pt')
torch.save(disc.state_dict(), 'discriminator_model.pt')
print(
f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"            )
withtorch.no_grad():
fake=gen(fixed_noise)
img_grid_real=torchvision.utils.make_grid(
real[:16], normalize=True                )
img_grid_fake=torchvision.utils.make_grid(
fake[:16], normalize=True                )
writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
step+=1

在训练的第一部分中,我们使用MSELoss在真实图像和伪图像上训练鉴别器。之后,我们使用特征匹配来训练我们的生成器。之前,我们在鉴别器的前向传递中添加了变量“ feature_matching”,以从图像中提取感知特征。在传统的DCGAN中,您只需训练生成器以伪造的图像来欺骗鉴别器,而在这里,我们试图训练生成器以生成与真实图像的特征紧密匹配的图像。此技术通常可以提高训练的稳定性。

经过100个批次后,我获得了以下结果。我尝试对模型进行更多的迭代训练,但是图像质量没有太大改善。

640.png

结论与最终想法

本文的目的是记录我从事该项目的过程。尽管在线上有很多资源和论文探讨了这个令人兴奋的概念的不同方面,但我发现有些东西是只能通过经验学习的……与其他任何东西一样。但我希望您能在本文中找到一些可以在自己的GAN项目中应用或试验的东西。由于我们获得的结果并不完美,因此我打算应用本文中提出的EvolGAN来优化我的生成器。

目录
相关文章
|
8月前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch实现DCGAN(生成对抗网络)生成新的假名人照片实战(附源码和数据集)
PyTorch实现DCGAN(生成对抗网络)生成新的假名人照片实战(附源码和数据集)
138 1
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch和DCGAN生成肖像画(上)
Pytorch和DCGAN生成肖像画
143 0
Pytorch和DCGAN生成肖像画(上)
|
机器学习/深度学习 存储 并行计算
一个快速构造GAN的教程:如何用pytorch构造DCGAN(下)
一个快速构造GAN的教程:如何用pytorch构造DCGAN
175 0
一个快速构造GAN的教程:如何用pytorch构造DCGAN(下)
|
机器学习/深度学习 存储 PyTorch
一个快速构造GAN的教程:如何用pytorch构造DCGAN(上)
一个快速构造GAN的教程:如何用pytorch构造DCGAN
175 0
一个快速构造GAN的教程:如何用pytorch构造DCGAN(上)
|
3月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
528 2
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
72 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
3月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
108 7
利用 PyTorch Lightning 搭建一个文本分类模型
|
3月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
231 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
4月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
329 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
4月前
|
机器学习/深度学习 监控 PyTorch
PyTorch 模型调试与故障排除指南
在深度学习领域,PyTorch 成为开发和训练神经网络的主要框架之一。本文为 PyTorch 开发者提供全面的调试指南,涵盖从基础概念到高级技术的内容。目标读者包括初学者、中级开发者和高级工程师。本文探讨常见问题及解决方案,帮助读者理解 PyTorch 的核心概念、掌握调试策略、识别性能瓶颈,并通过实际案例获得实践经验。无论是在构建简单神经网络还是复杂模型,本文都将提供宝贵的洞察和实用技巧,帮助开发者更高效地开发和优化 PyTorch 模型。
63 3
PyTorch 模型调试与故障排除指南

热门文章

最新文章