3 具体实现
使用原生GAN实现
加载MNIST数据
import os

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# Configure the MNIST data loader.
# `opt` (img_size, batch_size) is defined elsewhere in the script, not in this snippet.
DATA_ROOT = "../../data/mnist"
os.makedirs(DATA_ROOT, exist_ok=True)
dataloader = DataLoader(
    datasets.MNIST(
        DATA_ROOT,
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(opt.img_size),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5 maps pixels from [0, 1] to [-1, 1],
                # the same range as the generator's Tanh output.
                transforms.Normalize([0.5], [0.5]),
            ]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
这里随机取几张图片观察。
def show_img(img, trans=True):
    """Display one image in grayscale with matplotlib.

    Args:
        img: If ``trans`` is True, a channel-first tensor (C, H, W); only its
            first channel is drawn. Otherwise any object ``plt.imshow``
            accepts (e.g. an image straight from the raw dataset).
        trans: Whether ``img`` is a channel-first tensor that must be detached,
            moved to NumPy and transposed before plotting.
    """
    if trans:
        # Detach from the autograd graph, bring to CPU, then (C, H, W) -> (H, W, C).
        channels_last = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
        picture = channels_last[:, :, 0]
    else:
        picture = img
    plt.imshow(picture, cmap="gray")
    plt.show()


# Raw MNIST dataset without any transform, convenient for show_img(..., trans=False).
mnist = datasets.MNIST("../../data/mnist")
构建生成器
仿照下图的原生GAN的结构来搭建。
![原生GAN网络结构](https://i.loli.net/2021/10/19/HYN87qkdefZhmyl.png)
生成器包含5个全连接层:前4层后接LeakyReLU激活,其中除第一层外都在激活前加了BatchNorm;最后一层接Tanh激活,把输出映射到[-1, 1]。
class Generator(nn.Module):
    """Fully connected GAN generator mapping a latent vector to an image.

    Four hidden stages (Linear -> optional BatchNorm1d -> LeakyReLU(0.2)) widen
    the features opt.latent_dim -> 128 -> 256 -> 512 -> 1024. The first stage
    has no BatchNorm. A final Linear projects to prod(img_shape) values, and
    Tanh squashes them into [-1, 1]. Relies on the globals ``opt`` and
    ``img_shape`` defined elsewhere.
    """

    def __init__(self):
        super(Generator, self).__init__()

        def hidden(width_in, width_out, use_bn=True):
            # One hidden stage: Linear, optional BatchNorm1d, LeakyReLU.
            stage = [nn.Linear(width_in, width_out)]
            if use_bn:
                # NOTE(review): BatchNorm1d's second positional parameter is eps,
                # so the original `BatchNorm1d(n, 0.8)` sets eps=0.8 (not momentum).
                # Kept unchanged so trained weights/behavior stay the same.
                stage.append(nn.BatchNorm1d(width_out, eps=0.8))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        widths = [opt.latent_dim, 128, 256, 512, 1024]
        layers = []
        for idx, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers += hidden(w_in, w_out, use_bn=idx > 0)
        layers += [nn.Linear(widths[-1], int(np.prod(img_shape))), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        flat = self.model(z)
        # Reshape (N, prod(img_shape)) back to (N, *img_shape).
        return flat.view(flat.size(0), *img_shape)
结构如下:
![生成器结构](https://i.loli.net/2021/10/18/ZsElTonhgqWweQv.png)
构建判别器
仿照原生GAN,使用全连接网络,把原论文中的Maxout激活函数换为LeakyReLU与Sigmoid。
包含3个全连接层,使用LeakyReLU和Sigmoid激活函数。
class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    Flattens each image and applies
    Linear(prod(img_shape) -> 512) -> LeakyReLU -> Linear(512 -> 256) -> LeakyReLU
    -> Linear(256 -> 1) -> Sigmoid, yielding one score in (0, 1) per sample.
    Relies on the global ``img_shape`` defined elsewhere.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        n_pixels = int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # (N, ...) -> (N, prod(img_shape)) before the dense layers.
        return self.model(img.view(img.size(0), -1))


discriminator = Discriminator()
print(discriminator)
![判别器结构](https://i.loli.net/2021/10/18/4ALVpd8lhnOPWzi.png)
损失函数与优化
损失函数使用二元交叉熵(Binary Cross Entropy Loss),生成器和判别器的损失都用它计算。
生成器和判别器都使用Adam优化器,学习率lr = 0.0002。
# Binary cross-entropy criterion. The MLP discriminator ends in Sigmoid, so its
# outputs are already in (0, 1), which is what nn.BCELoss expects.
adversarial_loss = torch.nn.BCELoss()

# Both networks use Adam with the same hyper-parameters taken from `opt`
# (lr and the beta coefficients b1/b2); `opt` is defined elsewhere.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
随机采样
从100维的标准正态分布N(0, 1)中采样,作为生成器的噪声输入z。
一个batch有64组输入。
# One latent vector per image in the batch, drawn from a standard normal N(0, 1);
# shape (batch_size, opt.latent_dim).
noise = np.random.normal(loc=0.0, scale=1.0, size=(imgs.shape[0], opt.latent_dim))
z = Variable(Tensor(noise))
交替训练
n_real = imgs.size(0)

# Adversarial targets: 1.0 = "real", 0.0 = "fake". Constants, so no gradients.
valid = Variable(Tensor(n_real, 1).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(n_real, 1).fill_(0.0), requires_grad=False)
real_imgs = Variable(imgs.type(Tensor))

# ---- Generator step: push D to classify generated images as real ----
optimizer_G.zero_grad()
z = Variable(Tensor(np.random.normal(0, 1, (n_real, opt.latent_dim))))
gen_imgs = generator(z)
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()

# ---- Discriminator step: real -> 1, generated -> 0 ----
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(real_imgs), valid)
# detach() so the discriminator loss does not backpropagate into the generator.
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
生成结果
每400次迭代观察一次当前生成图像。
最开始,生成全是杂讯。
![初始生成结果](https://i.loli.net/2021/10/19/gdMUDkeC42OJ6sR.png)
开始设置的epoch数很少,结果很差,下图是第6000次迭代的结果:
20000次:
100000次:
200个epoch以后,也就是十八万多次迭代以后的最终结果:
感觉没有很好的结果,还需要继续train下去,但没有继续尝试了。
使用CNN+GAN实现
更改生成网络结构
class Generator(nn.Module):
    """DCGAN-style convolutional generator.

    A Linear layer projects the latent vector to a 128-channel feature map of
    size init_size x init_size, where init_size = opt.img_size // 4. Two
    Upsample(x2) + Conv blocks enlarge it to 4 * init_size. A final Conv + Tanh
    outputs opt.channels channels in [-1, 1]. The output size equals
    opt.img_size only if opt.img_size is divisible by 4.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # The two x2 upsamplings below restore full resolution, so start at 1/4.
        # Without this attribute, both l1 and forward() fail with AttributeError.
        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # NOTE(review): the positional 0.8 is BatchNorm2d's eps, not momentum.
            # Kept as-is to preserve behavior.
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        # Reshape the flat projection into a (N, 128, init_size, init_size) feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
网络结构为:
更改判别网络结构
class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    Four stride-2 conv stages (16 -> 32 -> 64 -> 128 channels) shrink the image.
    A Linear + Sigmoid head then gives one score in (0, 1) per sample.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Return a stride-2 Conv -> LeakyReLU -> Dropout2d [-> BatchNorm2d] stage."""
            block = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if bn:
                # NOTE(review): the positional 0.8 is eps, not momentum. Kept as-is.
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # A 3x3 conv with stride 2 and padding 1 maps size s to ceil(s / 2).
        # Apply that four times. This equals img_size // 16 when img_size is a
        # multiple of 16, and stays correct for other sizes such as 28.
        ds_size = opt.img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        # Flatten (N, 128, ds_size, ds_size) -> (N, 128 * ds_size**2) for the head.
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
网络结构为:
训练过程
生成结果
比用原生GAN的结果好很多。
比如:
第6000次迭代:
第20000次迭代:
第100个epoch:
第120个epoch:
观察linearly interpolating结果
在潜空间中随机选取两个点,在两点之间的连线上均匀取10个点(包含两个端点),观察生成图像的变化过程:
# Load the trained generator. Render n_steps images whose latent codes are evenly
# spaced on the straight line between two random latent vectors (endpoints included).
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# torch.load unpickles arbitrary objects: only load model files you trust.
g = torch.load('model/generator.pkl')
# Use BatchNorm running statistics, so each image depends only on its own z
# and not on the other samples in this small batch.
g.eval()

latent_dim = 100
n_steps = 10
z = Tensor(np.random.normal(0, 1, (2, latent_dim)))
# Row k of b is z0 + alpha_k * (z1 - z0), with alpha_k = k / (n_steps - 1).
# Building it on z's device avoids hard-coding 'cuda'.
alphas = torch.linspace(0, 1, n_steps, device=z.device).unsqueeze(1)
b = z[0] + alphas * (z[1] - z[0])
with torch.no_grad():
    gen_imgs = g(b)
save_image(gen_imgs, "images_trans.png", normalize=True)
再次尝试观察更细致的变化:
使用CGAN实现
为了能够控制生成图片的数字类别,可以使用条件GAN(CGAN),把类别标签同时输入生成器和判别器。
在原生GAN结构基础上,更改网络结构如下:
更改生成网络结构
更改判别网络结构
class Discriminator(nn.Module):
    """Conditional (CGAN) discriminator.

    Embeds the class label into an opt.n_classes-dim vector and concatenates it
    after the flattened image. A 512-wide MLP with Dropout(0.4) then produces
    one unbounded score per sample.

    NOTE(review): there is no final Sigmoid, so the score is not a probability.
    It fits a loss that takes raw scores (e.g. MSELoss or BCEWithLogitsLoss).
    nn.BCELoss requires inputs in [0, 1], so confirm which adversarial_loss
    this section uses.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        in_features = opt.n_classes + int(np.prod(img_shape))
        layers = [nn.Linear(in_features, 512), nn.LeakyReLU(0.2, inplace=True)]
        for _ in range(2):
            layers += [nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2, inplace=True)]
        layers.append(nn.Linear(512, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        flat_img = img.view(img.size(0), -1)
        label_vec = self.label_embedding(labels)
        # Condition on the class by appending its embedding to the pixel vector.
        return self.model(torch.cat((flat_img, label_vec), -1))
结构如下:
交替训练
把标签引入训练。
batch_size = imgs.shape[0]

# Adversarial targets: 1 = judged real, 0 = judged fake. Constants, so no gradients.
valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

# Convert the real batch and its labels to the tensor types used by the models.
real_imgs = Variable(imgs.type(FloatTensor))
labels = Variable(labels.type(LongTensor))

# ---- Update the generator ----
optimizer_G.zero_grad()

# Sample fresh noise and random target classes for every batch. Sizing both by
# batch_size keeps them aligned with each other and with a shorter last batch.
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))

gen_imgs = generator(z, gen_labels)
# Debug output. show_img calls plt.show(), which can block until the window is
# closed, so remove or guard this for long training runs.
print("gen_imgs =")
for img in gen_imgs[:3]:
    show_img(img)

validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
print("g_loss =", g_loss, '\n')
g_loss.backward()
optimizer_G.step()

# ---- Update the discriminator ----
optimizer_D.zero_grad()
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)
# detach() so this loss does not backpropagate into the generator.
validity_fake = discriminator(gen_imgs.detach(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)
d_loss = (d_real_loss + d_fake_loss) / 2
print("real_loss =", d_real_loss, '\n')
print("fake_loss =", d_fake_loss, '\n')
print("d_loss =", d_loss, '\n')
d_loss.backward()
optimizer_D.step()
生成结果:
100个epoch后的结果