4 损失函数
使用二元交叉熵(Binary Cross Entropy, BCE)损失函数
![BCE loss](figures/BCE-loss.png)
# Loss function: binary cross-entropy, applied later to the discriminator's
# output against the real/fake target labels.
adversarial_loss = torch.nn.BCELoss()
5 Cuda加速
# Use the GPU when one is available; is_available() already returns a bool.
cuda = torch.cuda.is_available()
print("cuda_is_available =", cuda)

if cuda:
    # Move both networks and the loss module onto the GPU.
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Tensor constructor matching the selected device; the later steps use it to
# build inputs, noise and labels.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
cuda_is_available = True
6 优化器
使用Adam优化器
# Optimizers: one Adam instance per network, both using the learning rate and
# beta coefficients read from `opt`.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

print("learning_rate =", opt.lr)
learning_rate = 0.0002
7 创建输入
分别从数据集和随机向量中获取输入
from itertools import islice

# Take only the first batch for this walkthrough. islice stops after one item,
# so the remaining batches are never loaded (wrapping enumerate(...) in list()
# would iterate the whole dataloader before slicing).
for i, (imgs, labels) in islice(enumerate(dataloader), 1):
    # Configure input: cast the image batch to the active Tensor type
    real_imgs = Variable(imgs.type(Tensor))

    # Sample noise as generator input: one latent vector per image in the batch
    z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

    print("i =", i, '\n')
    print("shape of z =", z.shape, '\n')
    print("shape of real_imgs =", real_imgs.shape, '\n')
    print("z =", z, '\n')
    print("real_imgs =")
    # Display the first three real images of the batch
    for img in real_imgs[:3]:
        show_img(img)
i = 0
shape of z = torch.Size([64, 100])
shape of real_imgs = torch.Size([64, 1, 32, 32])
z = tensor([[ 3.1224e-01, -1.1344e-01, -1.0401e+00, ..., 1.8232e-01, -1.2940e+00, 1.3365e+00],
        [ 7.3029e-01, 4.0669e-01, -1.3267e-01, ..., -4.9197e-01, -7.5093e-01, -1.1240e+00],
        [ 1.2938e+00, 7.8608e-01, 1.8455e-01, ..., -5.0269e-01, 7.9739e-01, -5.3891e-02],
        ...,
        [-7.9207e-01, -4.8256e-02, 4.5883e-01, ..., 1.2142e+00, 6.2461e-01, -1.5289e+00],
        [-1.4916e-03, 4.8395e-01, -3.0754e-01, ..., -1.8773e-01, -5.0988e-01, -1.2065e+00],
        [ 1.2712e+00, -5.0849e-01, 6.2769e-01, ..., 1.0904e+00, 2.1514e-01, -4.0929e-01]], device='cuda:0')
real_imgs =
![real image 1](test_files/test_21_1.png)
![real image 2](test_files/test_21_2.png)
![real image 3](test_files/test_21_3.png)
8 计算loss,反向传播
分别对生成器和判别器计算loss,使用反向传播更新模型参数
# Adversarial ground truths: (batch_size, 1) label tensors for the BCE loss
batch_size = imgs.shape[0]
valid = Variable(Tensor(batch_size, 1).fill_(1.0), requires_grad=False)  # 1 means "judged real"
fake = Variable(Tensor(batch_size, 1).fill_(0.0), requires_grad=False)  # 0 means "judged fake"

# ---------------------
#  Train Generator
# ---------------------

optimizer_G.zero_grad()

# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

# Generate a batch of images
gen_imgs = generator(z)

# Loss measures generator's ability to fool the discriminator: the generated
# images are scored against the "valid" labels.
g_loss = adversarial_loss(discriminator(gen_imgs), valid)

g_loss.backward()
optimizer_G.step()

# ---------------------
#  Train Discriminator
# ---------------------

optimizer_D.zero_grad()

# Measure discriminator's ability to classify real from generated samples.
# gen_imgs is detached so this loss does not backpropagate into the generator.
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2

print("real_loss =", real_loss, '\n')
print("fake_loss =", fake_loss, '\n')
print("d_loss =", d_loss, '\n')

d_loss.backward()
optimizer_D.step()
real_loss = tensor(0.7088, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
fake_loss = tensor(0.6778, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
d_loss = tensor(0.6933, device='cuda:0', grad_fn=<DivBackward0>)
9 保存生成图像和模型文件
from torchvision.utils import save_image


def sample_image(n_row, batches_done):
    """Save a grid of generated images to images/<batches_done>.png.

    Args:
        n_row: Number of images per grid row; n_row ** 2 latent vectors are
            sampled, so the grid holds n_row ** 2 images.
        batches_done: Batch counter used as the output file name.
    """
    # Sample noise
    z = Variable(Tensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # No gradients are needed when only rendering samples.
    with torch.no_grad():
        gen_imgs = generator(z)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


epoch = 0  # temporary
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
    os.makedirs("images", exist_ok=True)
    sample_image(n_row=10, batches_done=batches_done)

    # Save the models.
    # NOTE(review): torch.save on a whole module pickles the object; loading
    # needs the same class definitions importable — confirm this is intended.
    os.makedirs("model", exist_ok=True)
    torch.save(generator, 'model/generator.pkl')
    torch.save(discriminator, 'model/discriminator.pkl')

    print("gen images saved!\n")
    print("model saved!")
gen images saved! model saved!
# NOTE(review): these statements repeat the sample/model saving step that
# already appears earlier in this file; presumably a duplicate left by the
# page extraction — confirm before removing.
epoch = 0  # temporary
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
    os.makedirs("images", exist_ok=True)
    sample_image(n_row=10, batches_done=batches_done)

    # Save the models
    os.makedirs("model", exist_ok=True)
    torch.save(generator, 'model/generator.pkl')
    torch.save(discriminator, 'model/discriminator.pkl')

    print("gen images saved!\n")
    print("model saved!")
gen images saved! model saved!