DCGAN代码解析(一)

本文涉及的产品
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
简介: DCGAN代码解析(一)

DCGAN代码解析


今天我们将对GAN领域中经典的论文DCGAN做一个简单的解析。


论文地址:Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks-ReadPaper论文阅读平台


1 初始化

import argparse
import os
import numpy as np
import torch
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args(args=[])
print(opt)


Namespace(b1=0.5, b2=0.999, batch_size=64, channels=1, img_size=32, latent_dim=100, lr=0.0002, n_cpu=8, n_epochs=200, sample_interval=400)


2 数据加载

加载后的数据为 32 * 32 的灰度图

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
from torch.autograd import Variable
import matplotlib.pyplot as plt
def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # 把channel维度放到最后
        plt.imshow(img[:, :, 0], cmap="gray")
    else:
        plt.imshow(img, cmap="gray")
    plt.show()
mnist = datasets.MNIST("../../data/mnist")
for i in range(3):
    sample = mnist[i][0]
    show_img(np.array(sample), trans=False)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eozTGysh-1664249499185)(test_files/test_6_0.png)]


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kI2wdLrr-1664249499187)(test_files/test_6_1.png)]


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eC4sWl8G-1664249499188)(test_files/test_6_2.png)]

trans_resize = transforms.Resize(opt.img_size)
trans_to_tensor = transforms.ToTensor()
trans_normalize = transforms.Normalize([0.5], [0.5]) # x_n = (x - 0.5) / 0.5
print("shape =", np.array(sample).shape, '\n')
print("data =", np.array(sample), '\n')
sample_resize = trans_resize(sample) 
print("(trans_resize) shape =", np.array(sample_resize).shape, '\n')
sample_tensor = trans_to_tensor(sample_resize)
print("(trans_to_tensor) data =", sample_tensor, '\n')
sample_normalize = trans_normalize(sample_tensor)
print("(trans_normalize) data =", sample_normalize, '\n')
shape = (28, 28) 
data = [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0  67 232  39   0   0   0   0   0]
 [  0   0   0   0  62  81   0   0   0   0   0   0   0   0   0   0   0   0
    0   0 120 180  39   0   0   0   0   0]
 [  0   0   0   0 126 163   0   0   0   0   0   0   0   0   0   0   0   0
    0   2 153 210  40   0   0   0   0   0]
 [  0   0   0   0 220 163   0   0   0   0   0   0   0   0   0   0   0   0
    0  27 254 162   0   0   0   0   0   0]
 [  0   0   0   0 222 163   0   0   0   0   0   0   0   0   0   0   0   0
    0 183 254 125   0   0   0   0   0   0]
 [  0   0   0  46 245 163   0   0   0   0   0   0   0   0   0   0   0   0
    0 198 254  56   0   0   0   0   0   0]
 [  0   0   0 120 254 163   0   0   0   0   0   0   0   0   0   0   0   0
   23 231 254  29   0   0   0   0   0   0]
 [  0   0   0 159 254 120   0   0   0   0   0   0   0   0   0   0   0   0
  163 254 216  16   0   0   0   0   0   0]
 [  0   0   0 159 254  67   0   0   0   0   0   0   0   0   0  14  86 178
  248 254  91   0   0   0   0   0   0   0]
 [  0   0   0 159 254  85   0   0   0  47  49 116 144 150 241 243 234 179
  241 252  40   0   0   0   0   0   0   0]
 [  0   0   0 150 253 237 207 207 207 253 254 250 240 198 143  91  28   5
  233 250   0   0   0   0   0   0   0   0]
 [  0   0   0   0 119 177 177 177 177 177  98  56   0   0   0   0   0 102
  254 220   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254 137   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  57   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  57   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  255  94   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  96   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  255 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  96
  254 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]] 
(trans_resize) shape = (32, 32) 
(trans_to_tensor) data = tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]]) 
(trans_normalize) data = tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]])


3 模型

3.1生成器

包含1个全连接层和3个卷积层,使用LeakyReLU和Tanh激活函数,使用了BatchNorm和Upsample

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-y8xGwXe6-1664249499189)(figures/BN.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FeR2L0uH-1664249499190)(figures/resize.png)]

import torch.nn as nn
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )
    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
generator = Generator()
print(generator)
Generator(
  (l1): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
  )
  (conv_blocks): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2, mode=nearest)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Upsample(scale_factor=2, mode=nearest)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Tanh()
  )
)


3.2判别器

包含4个卷积层和1个全连接层,使用LeakyReLU和Sigmoid激活函数,使用了Dropout和BatchNorm,使用Strided Conv进行下采样

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OlQtOQTB-1664249499191)(figures/strided.png)]

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block
        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
discriminator = Discriminator()
print(discriminator)
Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Sigmoid()
  )
)


3.3初始化

对卷积层和BatchNorm层进行参数初始化

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Sigmoid()
  )
)


目录
相关文章
|
14天前
|
PHP 开发者 容器
PHP命名空间深度解析:避免命名冲突与提升代码组织####
本文深入探讨了PHP中命名空间的概念、用途及最佳实践,揭示其在解决全局命名冲突、提高代码可维护性方面的重要性。通过生动实例和详尽分析,本文将帮助开发者有效利用命名空间来优化大型项目结构,确保代码的清晰与高效。 ####
18 1
|
22天前
|
机器学习/深度学习 存储 人工智能
强化学习与深度强化学习:深入解析与代码实现
本书《强化学习与深度强化学习:深入解析与代码实现》系统地介绍了强化学习的基本概念、经典算法及其在深度学习框架下的应用。从强化学习的基础理论出发,逐步深入到Q学习、SARSA等经典算法,再到DQN、Actor-Critic等深度强化学习方法,结合Python代码示例,帮助读者理解并实践这些先进的算法。书中还探讨了强化学习在无人驾驶、游戏AI等领域的应用及面临的挑战,为读者提供了丰富的理论知识和实战经验。
47 5
|
1月前
|
存储 安全 Java
系统安全架构的深度解析与实践:Java代码实现
【11月更文挑战第1天】系统安全架构是保护信息系统免受各种威胁和攻击的关键。作为系统架构师,设计一套完善的系统安全架构不仅需要对各种安全威胁有深入理解,还需要熟练掌握各种安全技术和工具。
120 10
|
1月前
|
前端开发 JavaScript 开发者
揭秘前端高手的秘密武器:深度解析递归组件与动态组件的奥妙,让你代码效率翻倍!
【10月更文挑战第23天】在Web开发中,组件化已成为主流。本文深入探讨了递归组件与动态组件的概念、应用及实现方式。递归组件通过在组件内部调用自身,适用于处理层级结构数据,如菜单和树形控件。动态组件则根据数据变化动态切换组件显示,适用于不同业务逻辑下的组件展示。通过示例,展示了这两种组件的实现方法及其在实际开发中的应用价值。
36 1
|
2月前
|
机器学习/深度学习 人工智能 算法
揭开深度学习与传统机器学习的神秘面纱:从理论差异到实战代码详解两者间的选择与应用策略全面解析
【10月更文挑战第10天】本文探讨了深度学习与传统机器学习的区别,通过图像识别和语音处理等领域的应用案例,展示了深度学习在自动特征学习和处理大规模数据方面的优势。文中还提供了一个Python代码示例,使用TensorFlow构建多层感知器(MLP)并与Scikit-learn中的逻辑回归模型进行对比,进一步说明了两者的不同特点。
91 2
|
2月前
|
存储 搜索推荐 数据库
运用LangChain赋能企业规章制度制定:深入解析Retrieval-Augmented Generation(RAG)技术如何革新内部管理文件起草流程,实现高效合规与个性化定制的完美结合——实战指南与代码示例全面呈现
【10月更文挑战第3天】构建公司规章制度时,需融合业务实际与管理理论,制定合规且促发展的规则体系。尤其在数字化转型背景下,利用LangChain框架中的RAG技术,可提升规章制定效率与质量。通过Chroma向量数据库存储规章制度文本,并使用OpenAI Embeddings处理文本向量化,将现有文档转换后插入数据库。基于此,构建RAG生成器,根据输入问题检索信息并生成规章制度草案,加快更新速度并确保内容准确,灵活应对法律与业务变化,提高管理效率。此方法结合了先进的人工智能技术,展现了未来规章制度制定的新方向。
42 3
|
2月前
|
SQL 监控 关系型数据库
SQL错误代码1303解析与处理方法
在SQL编程和数据库管理中,遇到错误代码是常有的事,其中错误代码1303在不同数据库系统中可能代表不同的含义
|
2月前
|
SQL 安全 关系型数据库
SQL错误代码1303解析与解决方案:深入理解并应对权限问题
在数据库管理和开发过程中,遇到错误代码是常见的事情,每个错误代码都代表着一种特定的问题
|
3月前
|
敏捷开发 安全 测试技术
软件测试的艺术:从代码到用户体验的全方位解析
本文将深入探讨软件测试的重要性和实施策略,通过分析不同类型的测试方法和工具,展示如何有效地提升软件质量和用户满意度。我们将从单元测试、集成测试到性能测试等多个角度出发,详细解释每种测试方法的实施步骤和最佳实践。此外,文章还将讨论如何通过持续集成和自动化测试来优化测试流程,以及如何建立有效的测试团队来应对快速变化的市场需求。通过实际案例的分析,本文旨在为读者提供一套系统而实用的软件测试策略,帮助读者在软件开发过程中做出更明智的决策。
|
3月前
|
SQL 人工智能 机器人
遇到的代码部份解析
/ 模拟后端返回的数据
20 0

推荐镜像

更多