生成对抗网络——GAN

简介: 生成对抗网络——GAN

1.生成模型训练


import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch import optim
import os
# 设置超参数
batch_size = 64     
learning_rate = 0.0002
epochsize = 100
sample_dir = "images"
# 创建生成图像的目录
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# 生成器结构
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh())
    def forward(self, x):
        output = self.model(x)      # torch.Size([batch_size, 784])
        output = output.view(x.size(0), 1, 28, 28)   # torch.Size([batch_size, 1, 28, 28])
        return output
# 鉴别器结构
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid())   # 使用BCELoss注意添加Sigmoid层,BCEWithLogitsLoss则不用
    def forward(self, x):
        x = x.view(x.size(0), -1)  # torch.Size([batch_size, 784])
        output = self.model(x)     # torch.Size([batch_size, 1])
        return output
# 训练集下载
mnist_traindata = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],
                         std=[0.5])
]), download=False)
mnist_train = DataLoader(mnist_traindata, batch_size=batch_size, shuffle=True)
# GPU加速
device = torch.device('cuda')
G = Generator().to(device)
D = Discriminator().to(device)
# 导入之前的训练模型
# G.load_state_dict(torch.load('G.ckpt'))
# D.load_state_dict(torch.load('D.ckpt'))
# 设置优化器与损失函数,二分类的时候使用BCELoss较好,BCEWithLogitsLoss是自带一层Sigmoid
# criteon = nn.BCEWithLogitsLoss()
criteon = nn.BCELoss()
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
# 设置对比标签
# realimage_label = torch.ones(batch_size, 1).to(device)  # value:1 torch.Size([128, 1])
# fakeimage_label = torch.zeros(batch_size, 1).to(device)  # value:0 torch.Size([128, 1])
# 开始训练
print("start training")
for epoch in range(epochsize):
    D_loss_total = 0
    G_loss_total = 0
    total_num = 0
    # 这里的RealImageLabel是没有用上的
    for batchidx, (realimage, _) in enumerate(mnist_train):
        realimage = realimage.to(device)
        # realimage = realimage.reshape(realimage.size(0), -1).to(device)
        # 设置标签值
        realimage_label = torch.ones(realimage.size(0), 1).to(device)  # value:1 torch.Size([128, 1])
        fakeimage_label = torch.zeros(realimage.size(0), 1).to(device)  # value:0 torch.Size([128, 1])
        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #
        # 计算出真实图像通过鉴别器之后的loss, 真实图像希望输出值为1,使用realimage_label
        realimage_output = D(realimage)  # torch.Size([128, 1])
        d_realimage_loss = criteon(realimage_output, realimage_label)  # criteon中两者的shape都是相同的
        # 计算出生成图像通过鉴别器之后的loss, 生成图像希望输出值为0,使用fakeimage_label
        z = torch.randn(realimage.size(0), 100).to(device)
        fakeimage = G(z)
        fakeimage_output = D(fakeimage)
        d_fakeimage_loss = criteon(fakeimage_output, fakeimage_label)  # criteon中两者的shape都是相同的
        # 总的损失为两者相加
        D_loss = d_realimage_loss + d_fakeimage_loss
        # 参数训练三个步骤
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()
        # 计算一次epoch的总损失
        D_loss_total += D_loss
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #
        # 计算出生成图像通过鉴别器之后的loss, 生成图像希望输出值为1,使用realimage_label
        z = torch.randn(realimage.size(0), 100).to(device)
        fakeimage = G(z)
        fakeimage_output = D(fakeimage)
        # 生成器只有一个损失
        G_loss = criteon(fakeimage_output, realimage_label)
        # 参数训练三个步骤
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()
        # 计算一次epoch的总损失
        G_loss_total += G_loss
        # 计算训练图像个数
        total_num += realimage.size(0)
        # 打印相关的loss值
        if batchidx % 200 == 0:
            print("batchidx:{}/{}, D_loss:{}, G_loss:{}, total_num:{}".format(batchidx, len(mnist_train), D_loss, G_loss, total_num))
    # 打印一次训练的loss值
    print('Epoch:{}/{}, D_loss:{}, G_loss:{}, total_num:{}'.format(epoch, epochsize, D_loss_total/len(mnist_train), G_loss_total/len(mnist_train), total_num))
    # 保存生成图像
    save_image(fakeimage, os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)), nrow=8, normalize=True)
# 保存网络结构
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')


2.生成模型结果展示


一开始训练的结果基本只是噪声

image.png

epoch1生成的图像

image.png

epoch20生成的图像

image.png

epoch60生成的图像

image.png

epoch100生成的图像


3.网络训练中出现的问题


  • 问题1:生成出来的图像全是空白噪声且毫无规律

image.png


出现这种问题可能是生成器结构的问题,其可能没有输出一个符合格式的图像,比如手写数字的shape是(1,28,28),所以生成器的输出也应该符合对应格式;

也有可能是模型出现了梯度弥散与梯度爆炸象限,归根到底是模型问题。


  • 问题2:出现类似dim=0不符合dim=1的提示错误

一般出现这种问题是输入的shape与要求的不一致,不如相关的损失函数的输入输出,以BCELoss为例,其输入与输出的shape一定是相同的,而且要为0-1之间,所以这时候在鉴别器的最后一层就要加上Sigmoid函数,或者使用BCEWithLogitsLoss。


一般用到的损失函数介绍可以查看我的这篇blog:pytorch中的损失函数


  • 问题3:batch_size的设置

以mnist数据集为例,其训练集是有60000张图像,设置batch_size为100的时候,一次训练过程只需要600轮,而且这时候,图像的张数与batch_size是整数倍的关系,所以每一轮的epoch是不会剩余未训练的图像的。而如果设置的batch_size为64,128等非整数倍,一轮训练下来肯定会剩余一些低于batch_size数量的图像。


  • 问题4:生成图像格式设置

在以上的代码中有一个小漏洞,就是明明设置了一个batch_size为64,结果生成的图像只有32个,这是因为保存的是训练过程中,最后一个剩余训练样本数量的量,就只有剩余了32个,所以最后生成器与会生成32个图像,所以在保存图像之前重新生成再保存就好了。


# 保存生成图像
z = torch.randn(batch_size, 100).to(device)
fakeimage = G(z)
save_image(fakeimage.data[:64], os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)), nrow=8, normalize=True)

image.png


可以看见生成的是8*8=64张手写数字,并且按照每行每列都是8来排列。


4.自定义损失函数


其实我们可以不适应pytorch提供的损失函数自己写一个,相关代码如下,同样可以训练起来的,但是速度比较慢,所以还是用官方提供的比较好。训练过程一样。


# 设置优化器与损失函数,二分类的时候使用BCELoss较好,BCEWithLogitsLoss是自带一层Sigmoid
# criteon = nn.BCEWithLogitsLoss()
# criteon = nn.BCELoss()
def BCELoss(output, target):
    return -1*torch.mean(target*torch.log(output) + (1-target)*torch.log(1-output))
z = torch.randn(realimage.size(0), 100)
# 训练鉴别器————总的损失为两者相加
d_realimage_loss = BCELoss(D(realimage), realimage_label)
d_fakeimage_loss = BCELoss(D(G(z)), fakeimage_label)
#print("d_realimage_loss:",d_realimage_loss)
#print("d_fakeimage_loss:", d_fakeimage_loss)

image.png


目录
相关文章
|
3月前
|
机器学习/深度学习 API 算法框架/工具
【Tensorflow+keras】Keras API两种训练GAN网络的方式
使用Keras API以两种不同方式训练条件生成对抗网络(CGAN)的示例代码:一种是使用train_on_batch方法,另一种是使用tf.GradientTape进行自定义训练循环。
37 5
|
3月前
|
机器学习/深度学习 数据可视化 算法框架/工具
【深度学习】Generative Adversarial Networks ,GAN生成对抗网络分类
文章概述了生成对抗网络(GANs)的不同变体,并对几种经典GAN模型进行了简介,包括它们的结构特点和应用场景。此外,文章还提供了一个GitHub项目链接,该项目汇总了使用Keras实现的各种GAN模型的代码。
68 0
|
5月前
|
机器学习/深度学习 自然语言处理 算法
生成对抗网络(GAN):创造与竞争的艺术
【6月更文挑战第14天】**生成对抗网络(GANs)**是深度学习中的亮点,由生成器和判别器两部分构成,通过博弈式训练实现数据生成。GAN已应用于图像生成、修复、自然语言处理和音频生成等领域,但还面临训练不稳定性、可解释性差和计算资源需求高等挑战。未来,随着技术发展,GAN有望克服这些问题并在更多领域发挥潜力。
|
4月前
|
机器学习/深度学习 PyTorch API
生成对抗网络(GAN)由两部分组成:生成器(Generator)和判别器(Discriminator)。
生成对抗网络(GAN)由两部分组成:生成器(Generator)和判别器(Discriminator)。
|
6月前
|
机器学习/深度学习 数据可视化 PyTorch
使用Python实现深度学习模型:生成对抗网络(GAN)
使用Python实现深度学习模型:生成对抗网络(GAN)
136 3
|
6月前
|
机器学习/深度学习 JavaScript 算法
深度学习500问——Chapter07:生成对抗网络(GAN)(1)
深度学习500问——Chapter07:生成对抗网络(GAN)(1)
114 3
|
6月前
|
机器学习/深度学习 人工智能 编解码
【AI 生成式】生成对抗网络 (GAN) 的概念
【5月更文挑战第4天】【AI 生成式】生成对抗网络 (GAN) 的概念
【AI 生成式】生成对抗网络 (GAN) 的概念
|
6月前
|
机器学习/深度学习
GAN网络的代码实现(学习ing)
GAN网络的代码实现(学习ing)
|
6月前
|
机器学习/深度学习 编解码 自然语言处理
深度学习500问——Chapter07:生成对抗网络(GAN)(3)
深度学习500问——Chapter07:生成对抗网络(GAN)(3)
95 0
|
6月前
|
机器学习/深度学习 JavaScript Linux
深度学习500问——Chapter07:生成对抗网络(GAN)(2)
深度学习500问——Chapter07:生成对抗网络(GAN)(2)
102 0