DCGAN Code Walkthrough
Today we give a simple walkthrough of DCGAN, one of the classic papers in the GAN literature, by dissecting a reference PyTorch implementation.
1 Initialization
```python
import argparse
import os

import numpy as np
import torch
```
```python
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args(args=[])  # args=[] lets this cell also run inside a notebook
print(opt)
```
```
Namespace(b1=0.5, b2=0.999, batch_size=64, channels=1, img_size=32, latent_dim=100, lr=0.0002, n_cpu=8, n_epochs=200, sample_interval=400)
```
2 Data Loading
After loading, each sample is a 32 × 32 grayscale image.
```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
```
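As a quick sanity check (this snippet is an addition, not part of the original script), we can pull a single batch and confirm that the loader yields `(batch_size, channels, img_size, img_size)` tensors scaled to roughly [-1, 1]:

```python
# Sanity check (not in the original script): fetch one batch and inspect it
imgs, labels = next(iter(dataloader))
print(imgs.shape)                            # torch.Size([64, 1, 32, 32])
print(imgs.min().item(), imgs.max().item())  # close to -1.0 and 1.0 after Normalize
```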
```python
from torch.autograd import Variable

import matplotlib.pyplot as plt


def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # move the channel dimension to the end
        plt.imshow(img[:, :, 0], cmap="gray")
    else:
        plt.imshow(img, cmap="gray")
    plt.show()


mnist = datasets.MNIST("../../data/mnist")

for i in range(3):
    sample = mnist[i][0]
    show_img(np.array(sample), trans=False)
```
*(output images lost in export: the first three MNIST training digits plotted by `show_img`, originally test_files/test_6_0.png – test_6_2.png)*
```python
trans_resize = transforms.Resize(opt.img_size)
trans_to_tensor = transforms.ToTensor()
trans_normalize = transforms.Normalize([0.5], [0.5])  # x_n = (x - 0.5) / 0.5

print("shape =", np.array(sample).shape, '\n')
print("data =", np.array(sample), '\n')

sample_resize = trans_resize(sample)
print("(trans_resize) shape =", np.array(sample_resize).shape, '\n')

sample_tensor = trans_to_tensor(sample_resize)
print("(trans_to_tensor) data =", sample_tensor, '\n')

sample_normalize = trans_normalize(sample_tensor)
print("(trans_normalize) data =", sample_normalize, '\n')
```
```
shape = (28, 28)

data = [[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 67 232 39 0 0 0 0 0]
 [ 0 0 0 0 62 81 0 0 0 0 0 0 0 0 0 0 0 0 0 0 120 180 39 0 0 0 0 0]
 [ 0 0 0 0 126 163 0 0 0 0 0 0 0 0 0 0 0 0 0 2 153 210 40 0 0 0 0 0]
 [ 0 0 0 0 220 163 0 0 0 0 0 0 0 0 0 0 0 0 0 27 254 162 0 0 0 0 0 0]
 [ 0 0 0 0 222 163 0 0 0 0 0 0 0 0 0 0 0 0 0 183 254 125 0 0 0 0 0 0]
 [ 0 0 0 46 245 163 0 0 0 0 0 0 0 0 0 0 0 0 0 198 254 56 0 0 0 0 0 0]
 [ 0 0 0 120 254 163 0 0 0 0 0 0 0 0 0 0 0 0 23 231 254 29 0 0 0 0 0 0]
 [ 0 0 0 159 254 120 0 0 0 0 0 0 0 0 0 0 0 0 163 254 216 16 0 0 0 0 0 0]
 [ 0 0 0 159 254 67 0 0 0 0 0 0 0 0 0 14 86 178 248 254 91 0 0 0 0 0 0 0]
 [ 0 0 0 159 254 85 0 0 0 47 49 116 144 150 241 243 234 179 241 252 40 0 0 0 0 0 0 0]
 [ 0 0 0 150 253 237 207 207 207 253 254 250 240 198 143 91 28 5 233 250 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 119 177 177 177 177 177 98 56 0 0 0 0 0 102 254 220 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169 254 137 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169 254 57 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169 254 57 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169 255 94 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169 254 96 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169 254 153 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169 255 153 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 96 254 153 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

(trans_resize) shape = (32, 32)

(trans_to_tensor) data = tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

(trans_normalize) data = tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]])
```
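The `Normalize([0.5], [0.5])` step is an affine map that takes `ToTensor`'s [0, 1] range to [-1, 1], which matches the Tanh output range of the generator defined below:

$$x_{\text{norm}} = \frac{x - 0.5}{0.5} = 2x - 1 \in [-1, 1]$$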
3 Model
3.1 Generator
The generator consists of one fully connected layer and three convolutional layers. It uses LeakyReLU and Tanh activations, and relies on BatchNorm and Upsample layers.
*(figure lost in export: the BatchNorm transform, originally figures/BN.png)*
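Since the figure is unavailable, here is the standard BatchNorm transform that `nn.BatchNorm2d` applies (presumably what the figure showed): each channel is normalized by batch statistics, then rescaled by learnable parameters $\gamma$ and $\beta$:

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$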
*(figure lost in export: nearest-neighbor upsampling, originally figures/resize.png)*
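In case that figure does not load either, here is a minimal illustration of what `nn.Upsample(scale_factor=2)` (default `mode="nearest"`) does: each pixel is simply repeated into a 2 × 2 block, doubling the height and width:

```python
import torch
import torch.nn as nn

up = nn.Upsample(scale_factor=2)      # mode="nearest" by default
x = torch.tensor([[[[1., 2.],
                    [3., 4.]]]])      # shape (1, 1, 2, 2)
print(up(x))
# tensor([[[[1., 1., 2., 2.],
#           [1., 1., 2., 2.],
#           [3., 3., 4., 4.],
#           [3., 3., 4., 4.]]]])
```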
```python
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4  # 32 // 4 = 8
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),  # note: the second positional argument of BatchNorm2d is eps, as the printout below shows
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


generator = Generator()
print(generator)
```
```
Generator(
  (l1): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
  )
  (conv_blocks): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2, mode=nearest)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Upsample(scale_factor=2, mode=nearest)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Tanh()
  )
)
```
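To make the shape flow concrete (this check is an addition, not part of the original post), we can push a small batch of random latent vectors through the generator:

```python
# Sanity check (not in the original post): trace the shape flow
# z: (N, 100) -> l1 -> (N, 8192) -> view -> (N, 128, 8, 8)
#   -> Upsample -> (N, 128, 16, 16) -> Conv -> (N, 128, 16, 16)
#   -> Upsample -> (N, 128, 32, 32) -> Conv -> (N, 64, 32, 32)
#   -> Conv + Tanh -> (N, 1, 32, 32)
z = torch.randn(4, opt.latent_dim)
fake = generator(z)
print(fake.shape)  # torch.Size([4, 1, 32, 32])
```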
3.2 Discriminator
The discriminator consists of four convolutional layers and one fully connected layer. It uses LeakyReLU and Sigmoid activations, applies Dropout and BatchNorm, and downsamples with strided convolutions.
*(figure lost in export: strided convolution for downsampling, originally figures/strided.png)*
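With kernel size $k = 3$, stride $s = 2$ and padding $p = 1$, each block halves the spatial resolution, as the standard convolution output-size formula shows:

$$H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p - k}{s} \right\rfloor + 1 = \left\lfloor \frac{32 + 2 - 3}{2} \right\rfloor + 1 = 16$$

So four blocks take the feature map from 32 → 16 → 8 → 4 → 2, which is why `ds_size = opt.img_size // 2 ** 4 = 2` and the final linear layer has `128 * 2 * 2 = 512` input features.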
```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of the downsampled image
        ds_size = opt.img_size // 2 ** 4  # 32 // 16 = 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity


discriminator = Discriminator()
print(discriminator)
```
```
Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Sigmoid()
  )
)
```
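Again a quick sanity check of my own (not in the original post): feeding generated images through the discriminator should yield one validity score in (0, 1) per image:

```python
# Sanity check (not in the original post): one validity score per input image
fake = generator(torch.randn(4, opt.latent_dim))
validity = discriminator(fake)
print(validity.shape)                                 # torch.Size([4, 1])
print(validity.min().item(), validity.max().item())   # both in (0, 1) due to Sigmoid
```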
3.3 Weight Initialization
The convolutional and BatchNorm layers are initialized following the DCGAN paper: conv weights are drawn from a normal distribution with mean 0 and std 0.02, while BatchNorm weights are drawn from a normal distribution with mean 1 and std 0.02 with biases set to 0.
```python
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
```
```
Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Sigmoid()
  )
)
```
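The cell output above is simply the return value of `nn.Module.apply`, i.e., the discriminator itself. To confirm the initialization took effect (a check of my own, not in the original post), we can inspect the empirical statistics of one Conv2d layer and one BatchNorm2d layer:

```python
# Verify the DCGAN-style init (not in the original post):
# conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02), BatchNorm biases = 0
conv0 = discriminator.model[0]  # first Conv2d
bn6 = discriminator.model[6]    # first BatchNorm2d
print(conv0.weight.mean().item(), conv0.weight.std().item())  # ≈ 0.0, ≈ 0.02
print(bn6.weight.mean().item(), bn6.weight.std().item())      # ≈ 1.0, ≈ 0.02
print(bn6.bias.abs().max().item())                            # 0.0
```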