DCGAN代码解析(一)

简介: DCGAN代码解析(一)

DCGAN代码解析


今天我们将对GAN领域中经典的论文DCGAN做一个简单的解析。


论文地址:Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks-ReadPaper论文阅读平台


1 初始化

import argparse
import os
import numpy as np
import torch
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args(args=[])
print(opt)


Namespace(b1=0.5, b2=0.999, batch_size=64, channels=1, img_size=32, latent_dim=100, lr=0.0002, n_cpu=8, n_epochs=200, sample_interval=400)


2 数据加载

加载后的数据为 32 * 32 的灰度图

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
from torch.autograd import Variable
import matplotlib.pyplot as plt
def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # 把channel维度放到最后
        plt.imshow(img[:, :, 0], cmap="gray")
    else:
        plt.imshow(img, cmap="gray")
    plt.show()
mnist = datasets.MNIST("../../data/mnist")
for i in range(3):
    sample = mnist[i][0]
    show_img(np.array(sample), trans=False)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eozTGysh-1664249499185)(test_files/test_6_0.png)]


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kI2wdLrr-1664249499187)(test_files/test_6_1.png)]


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eC4sWl8G-1664249499188)(test_files/test_6_2.png)]

trans_resize = transforms.Resize(opt.img_size)
trans_to_tensor = transforms.ToTensor()
trans_normalize = transforms.Normalize([0.5], [0.5]) # x_n = (x - 0.5) / 0.5
print("shape =", np.array(sample).shape, '\n')
print("data =", np.array(sample), '\n')
sample_resize = trans_resize(sample) 
print("(trans_resize) shape =", np.array(sample_resize).shape, '\n')
sample_tensor = trans_to_tensor(sample_resize)
print("(trans_to_tensor) data =", sample_tensor, '\n')
sample_normalize = trans_normalize(sample_tensor)
print("(trans_normalize) data =", sample_normalize, '\n')
shape = (28, 28) 
data = [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0  67 232  39   0   0   0   0   0]
 [  0   0   0   0  62  81   0   0   0   0   0   0   0   0   0   0   0   0
    0   0 120 180  39   0   0   0   0   0]
 [  0   0   0   0 126 163   0   0   0   0   0   0   0   0   0   0   0   0
    0   2 153 210  40   0   0   0   0   0]
 [  0   0   0   0 220 163   0   0   0   0   0   0   0   0   0   0   0   0
    0  27 254 162   0   0   0   0   0   0]
 [  0   0   0   0 222 163   0   0   0   0   0   0   0   0   0   0   0   0
    0 183 254 125   0   0   0   0   0   0]
 [  0   0   0  46 245 163   0   0   0   0   0   0   0   0   0   0   0   0
    0 198 254  56   0   0   0   0   0   0]
 [  0   0   0 120 254 163   0   0   0   0   0   0   0   0   0   0   0   0
   23 231 254  29   0   0   0   0   0   0]
 [  0   0   0 159 254 120   0   0   0   0   0   0   0   0   0   0   0   0
  163 254 216  16   0   0   0   0   0   0]
 [  0   0   0 159 254  67   0   0   0   0   0   0   0   0   0  14  86 178
  248 254  91   0   0   0   0   0   0   0]
 [  0   0   0 159 254  85   0   0   0  47  49 116 144 150 241 243 234 179
  241 252  40   0   0   0   0   0   0   0]
 [  0   0   0 150 253 237 207 207 207 253 254 250 240 198 143  91  28   5
  233 250   0   0   0   0   0   0   0   0]
 [  0   0   0   0 119 177 177 177 177 177  98  56   0   0   0   0   0 102
  254 220   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254 137   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  57   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  57   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  255  94   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  96   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  255 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  96
  254 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]] 
(trans_resize) shape = (32, 32) 
(trans_to_tensor) data = tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]]) 
(trans_normalize) data = tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]])


3 模型

3.1生成器

包含1个全连接层和3个卷积层,使用LeakyReLU和Tanh激活函数,使用了BatchNorm和Upsample

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-y8xGwXe6-1664249499189)(figures/BN.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FeR2L0uH-1664249499190)(figures/resize.png)]

import torch.nn as nn
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )
    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
generator = Generator()
print(generator)
Generator(
  (l1): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
  )
  (conv_blocks): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2, mode=nearest)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Upsample(scale_factor=2, mode=nearest)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Tanh()
  )
)


3.2判别器

包含4个卷积层和1个全连接层,使用LeakyReLU和Sigmoid激活函数,使用了Dropout和BatchNorm,使用Strided Conv进行下采样

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OlQtOQTB-1664249499191)(figures/strided.png)]

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block
        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
discriminator = Discriminator()
print(discriminator)
Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Sigmoid()
  )
)


3.3初始化

对卷积层和BatchNorm层进行参数初始化

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Sigmoid()
  )
)


目录
相关文章
|
6月前
|
算法 PyTorch 算法框架/工具
昇腾 msmodelslim w8a8量化代码解析
msmodelslim w8a8量化算法原理和代码解析
456 5
|
8月前
|
搜索推荐 UED Python
实现一个带有昼夜背景切换的动态时钟:从代码到功能解析
本文介绍了一个使用Python和Tkinter库实现的动态时钟程序,具有昼夜背景切换、指针颜色随机变化及整点和半点报时功能。通过设置不同的背景颜色和随机变换指针颜色,增强视觉吸引力;利用多线程技术确保音频播放不影响主程序运行。该程序结合了Tkinter、Pygame、Pytz等库,提供了一个美观且实用的时间显示工具。欢迎点赞、关注、转发、收藏!
379 94
|
6月前
|
传感器 监控 Java
Java代码结构解析:类、方法、主函数(1分钟解剖室)
### Java代码结构简介 掌握Java代码结构如同拥有程序世界的建筑蓝图,类、方法和主函数构成“黄金三角”。类是独立的容器,承载成员变量和方法;方法实现特定功能,参数控制输入环境;主函数是程序入口。常见错误包括类名与文件名不匹配、忘记static修饰符和花括号未闭合。通过实战案例学习电商系统、游戏角色控制和物联网设备监控,理解类的作用、方法类型和主函数任务,避免典型错误,逐步提升编程能力。 **脑图速记法**:类如太空站,方法即舱段;main是发射台,static不能换;文件名对仗,括号要成双;参数是坐标,void不返航。
249 5
|
7月前
|
人工智能 文字识别 自然语言处理
保单AI识别技术及代码示例解析
车险保单包含基础信息、车辆信息、人员信息、保险条款及特别约定等关键内容。AI识别技术通过OCR、文档结构化解析和数据校验,实现对保单信息的精准提取。然而,版式多样性、信息复杂性、图像质量和法律术语解析是主要挑战。Python代码示例展示了如何使用PaddleOCR进行保单信息抽取,并提出了定制化训练、版式分析等优化方向。典型应用场景包括智能录入、快速核保、理赔自动化等。未来将向多模态融合、自适应学习和跨区域兼容性发展。
|
9月前
|
自然语言处理 搜索推荐 数据安全/隐私保护
鸿蒙登录页面好看的样式设计-HarmonyOS应用开发实战与ArkTS代码解析【HarmonyOS 5.0(Next)】
鸿蒙登录页面设计展示了 HarmonyOS 5.0(Next)的未来美学理念,结合科技与艺术,为用户带来视觉盛宴。该页面使用 ArkTS 开发,支持个性化定制和无缝智能设备连接。代码解析涵盖了声明式 UI、状态管理、事件处理及路由导航等关键概念,帮助开发者快速上手 HarmonyOS 应用开发。通过这段代码,开发者可以了解如何构建交互式界面并实现跨设备协同工作,推动智能生态的发展。
531 10
鸿蒙登录页面好看的样式设计-HarmonyOS应用开发实战与ArkTS代码解析【HarmonyOS 5.0(Next)】
|
8月前
|
SQL Java 数据库连接
如何在 Java 代码中使用 JSqlParser 解析复杂的 SQL 语句?
大家好,我是 V 哥。JSqlParser 是一个用于解析 SQL 语句的 Java 库,可将 SQL 解析为 Java 对象树,支持多种 SQL 类型(如 `SELECT`、`INSERT` 等)。它适用于 SQL 分析、修改、生成和验证等场景。通过 Maven 或 Gradle 安装后,可以方便地在 Java 代码中使用。
2652 11
|
9月前
|
PHP 开发者 容器
PHP命名空间深度解析:避免命名冲突与提升代码组织####
本文深入探讨了PHP中命名空间的概念、用途及最佳实践,揭示其在解决全局命名冲突、提高代码可维护性方面的重要性。通过生动实例和详尽分析,本文将帮助开发者有效利用命名空间来优化大型项目结构,确保代码的清晰与高效。 ####
146 20
|
10月前
|
机器学习/深度学习 存储 人工智能
强化学习与深度强化学习:深入解析与代码实现
本书《强化学习与深度强化学习:深入解析与代码实现》系统地介绍了强化学习的基本概念、经典算法及其在深度学习框架下的应用。从强化学习的基础理论出发,逐步深入到Q学习、SARSA等经典算法,再到DQN、Actor-Critic等深度强化学习方法,结合Python代码示例,帮助读者理解并实践这些先进的算法。书中还探讨了强化学习在无人驾驶、游戏AI等领域的应用及面临的挑战,为读者提供了丰富的理论知识和实战经验。
506 5
|
10月前
|
存储 安全 Java
系统安全架构的深度解析与实践:Java代码实现
【11月更文挑战第1天】系统安全架构是保护信息系统免受各种威胁和攻击的关键。作为系统架构师,设计一套完善的系统安全架构不仅需要对各种安全威胁有深入理解,还需要熟练掌握各种安全技术和工具。
455 10
|
10月前
|
前端开发 JavaScript 开发者
揭秘前端高手的秘密武器:深度解析递归组件与动态组件的奥妙,让你代码效率翻倍!
【10月更文挑战第23天】在Web开发中,组件化已成为主流。本文深入探讨了递归组件与动态组件的概念、应用及实现方式。递归组件通过在组件内部调用自身,适用于处理层级结构数据,如菜单和树形控件。动态组件则根据数据变化动态切换组件显示,适用于不同业务逻辑下的组件展示。通过示例,展示了这两种组件的实现方法及其在实际开发中的应用价值。
153 1

推荐镜像

更多
  • DNS