GAN网络的代码实现(学习ing)

简介: GAN网络的代码实现(学习ing)

感悟:在学习某个网络的时候,一味的看概念越看越不理解,此时我们可以找一份源代码,试着写一遍,也许就会明白了。

#----------引入需要的库----------
import argparse #配置超参数的库
import os#用来创建文件夹
import numpy as np
#对数据进行一些操作
import torchvision.transforms as transforms
#保存图片
from torchvision.utils import save_image
#数据加载器
from torch.utils.data import DataLoader
#加载数据
from torchvision import datasets
#在旧版本的PyTorch中,Variable 类被用于包装张量,
#并自动跟踪对该张量的所有操作,从而支持自动计算梯度。
#它是PyTorch自动微分系统的核心组件。
from torch.autograd import Variable
import torch.nn as nn
import torch
#----------创建文件夹和配置一些参数----------
# 创建文件夹
os.makedirs("./images/gan/",exist_ok=True)
os.makedirs("./save/gan/",exist_ok=True)
os.makedirs("./datasets/mnist/",exist_ok=True)

#超参数配置-->超参数就是我们可以设置的
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs",type = int,default = 50,help = "number of epochs of training" )
parser.add_argument("--batch_size",type = int,default = 2,help="size of the batches")
parser.add_argument("--lr",type = float,default = 0.0002,help="adam: learning rate")
parser.add_argument("--b1",type = float,default = 0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2",type = float,default = 0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu",type = int, default = 2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim",type = int,default = 100, help="dimensionality of the latent space")
parser.add_argument("--img_size",type = int,default = 28, help="size of each image dimension")
parser.add_argument("--channels", type = int,default = 1, help="number of image channels")
parser.add_argument("--sample_interval",type = int,default = 500, help="interval betwen image samples")
opt = parser.parse_known_args()[0]

# print(opt)输出结果如下
#Namespace(b1=0.5, b2=0.999, batch_size=2, channels=1, img_size=28, latent_dim=100, lr=0.0002, n_cpu=2, n_epochs=50, sample_interval=500)
#----------下载数据并将数据份数分好----------
#(1,28,28)
img_shape = (opt.channels,opt.img_size,opt.img_size)
print(img_shape)
#计算数组所有元素的乘积
img_area = np.prod(img_shape)
print(img_area)
#cuda 查看cuda是否可用
cuda = True if torch.cuda.is_available() else False

#mnist 数据集下载并对数据做一些处理
mnist = datasets.MNIST(root = "./datasets",train = True, download = True, transform = transforms.Compose(
                        [transforms.Resize(opt.img_size),
                         transforms.ToTensor(),
                         transforms.Normalize([0.5],[0.5])]))
#加载器 分批次加载数据集,
#将数据分成len(dataloader)/batchsize【不是很严谨】份送入网络
dataloader = DataLoader(
    mnist,
    batch_size = opt.batch_size,
    shuffle = True)
#----------判别器----------
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area,512),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(256,1),
            nn.Sigmoid()
            
        )
    def forward(self,img):
        img_flat = img.view(img.size(0),-1)
        validity = self.model(img_flat)
        print(validity)
        return validity
#----------生成器----------
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__( )
        def block(in_feat,out_feat,normalize = True):
            layers = [nn.Linear(in_feat,out_feat)]
            if  normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers
        self.model = nn.Sequential(
          # *操作符被称为解包操作符
          # 这里可以理解成将block中的层一个一个的写到了Sequential中
            *block(opt.latent_dim,128,normalize = False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,img_area),
            nn.Tanh()
        )
    def forward(self,z):
        imgs = self.model(z)
        #imgs.view的意思,
        #(2, batch_size * channels * height * width)
        # imgs.size()  torch.Size([2, 1, 28, 28])
        imgs = imgs.view(imgs.size(0),*img_shape)
        #imgs.size(0)就是2
        #将imgs从一维向量重新reshape为合适的图像形状img_shape
        return imgs 
        
# 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()
#loss
criterion = torch.nn.BCELoss()
#youhua
optimizer_G = torch.optim.Adam(generator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))

#有cuda就在cuda上运行
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()
# training
for epoch in range(opt.n_epochs):
    # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    for i,(imgs,_) in enumerate(dataloader):
        imgs = imgs.view(imgs.size(0),-1)
        #使用cuda,就在后边加个.cuda()
        #real_img = Variable(imgs).cuda()
        real_img = Variable(imgs)
        # imgs.size()  torch.Size([2, 1, 28, 28])
        #2个三维数组,1个二维数组,28个一维数组
        #real_label = Variable(torch.ones(imgs.size(0),1)).cuda()#全1
        #fake_label = Variable(torch.zeros(imgs.size(0),1)).cuda()#全0
        #-----------------重点-----------------
        #-为什么real_label是全1呢?fake_label全为0呢?-
        #------------这就是gan的原理啦------------
        #在生成对抗网络(GAN)中,判别器的目标是将真实样本判别为1,
        #将生成的假样本判别为0。
        real_label = Variable(torch.ones(imgs.size(0),1))
        fake_label = Variable(torch.zeros(imgs.size(0),1))
        ## ---------------------
        ##  Train Discriminator
        ## 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        ## ---------------------
        ## 计算真实图片的损失
        real_out = discriminator(real_img)
        #print(real_img.view(real_img.size(0),-1))
        loss_real_D = criterion(real_out,real_label)
        real_scores = real_out
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        #  detach还没很理解
        #z = Variable(torch.randn(imgs.size(0),opt.latent_dim)).cuda()
        z = Variable(torch.randn(imgs.size(0),opt.latent_dim))
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out,fake_label)
        fake_scores = fake_out
        #损失函数和优化
        loss_D = loss_real_D + loss_fake_D
        
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        # ---------------------
        # Train Generator
        # ---------------------
        #z = Variable(torch.randn(imgs.size(0),opt.latent_dim)).cuda()
        z = Variable(torch.randn(imgs.size(0),opt.latent_dim))
        fake_img = generator(z)
        output = discriminator(fake_img)
        #损失函数和优化
        loss_G = criterion(output,real_label)
        
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if(i+1) % 100 == 0:
            print(
                "[Epoch %d/%d]  [Batch %d/%d]  [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                %(epoch,opt.n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),real_scores.data.mean(),fake_scores.data.mean())
            )
        # 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(fake_img.data[:25],"./images/gan/%d.png"%batches_done,nrow = 5,normalize = True)
# 保存模型
#将生成器和判别器的网络权重保存起来
torch.save(generator.state_dict(),"./save/gan/generator.pth")
torch.save(discriminator.state_dict(),"./save/gan/discriminator.pth")


相关文章
用MASM32按Time Protocol(RFC868)协议编写网络对时程序中的一些有用的函数代码
用MASM32按Time Protocol(RFC868)协议编写网络对时程序中的一些有用的函数代码
|
2月前
|
机器学习/深度学习 存储 算法
回声状态网络(Echo State Networks,ESN)详细原理讲解及Python代码实现
本文详细介绍了回声状态网络(Echo State Networks, ESN)的基本概念、优点、缺点、储层计算范式,并提供了ESN的Python代码实现,包括不考虑和考虑超参数的两种ESN实现方式,以及使用ESN进行时间序列预测的示例。
80 4
回声状态网络(Echo State Networks,ESN)详细原理讲解及Python代码实现
|
1月前
|
监控 网络协议 Linux
网络学习
网络学习
132 68
|
5天前
|
网络协议 网络架构
网络协议介绍与学习
网络协议介绍与学习
18 4
|
5天前
|
网络协议 网络安全 数据安全/隐私保护
网络基础知识学习
如果你打算深入学习网络技术,建议从上述基础知识入手,并逐渐扩展到更高级的主题,如网络编程、网络安全、网络管理等。同时,实践是学习网络技术的关键,可以通过搭建自己的小型网络环境来进行实验和探索。
10 2
|
6天前
|
安全 C#
某网络硬盘网站被植入传播Trojan.DL.Inject.xz等的代码
某网络硬盘网站被植入传播Trojan.DL.Inject.xz等的代码
完成切换网络+修改网络连接图标提示的代码框架
完成切换网络+修改网络连接图标提示的代码框架
|
2月前
|
安全 网络安全 开发者
探索Python中的装饰器:简化代码,增强功能网络安全与信息安全:从漏洞到防护
【8月更文挑战第30天】本文通过深入浅出的方式介绍了Python中装饰器的概念、用法和高级应用。我们将从基础的装饰器定义开始,逐步深入到如何利用装饰器来改进代码结构,最后探讨其在Web框架中的应用。适合有一定Python基础的开发者阅读,旨在帮助读者更好地理解并运用装饰器来优化他们的代码。
|
1月前
|
网络协议 安全 网络安全
网络基础知识学习
【9月更文挑战第1天】
47 0
|
2月前
|
前端开发 算法 网络协议
如何学习计算机基础知识,打好前端和网络安全的基础
如何学习计算机基础知识,打好前端和网络安全的基础
38 4
下一篇
无影云桌面