GAN网络的代码实现(学习ing)

简介: GAN网络的代码实现(学习ing)

感悟:在学习某个网络的时候,一味的看概念越看越不理解,此时我们可以找一份源代码,试着写一遍,也许就会明白了。

#----------引入需要的库----------
import argparse #配置超参数的库
import os#用来创建文件夹
import numpy as np
#对数据进行一些操作
import torchvision.transforms as transforms
#保存图片
from torchvision.utils import save_image
#数据加载器
from torch.utils.data import DataLoader
#加载数据
from torchvision import datasets
#在旧版本的PyTorch中,Variable 类被用于包装张量,
#并自动跟踪对该张量的所有操作,从而支持自动计算梯度。
#它是PyTorch自动微分系统的核心组件。
from torch.autograd import Variable
import torch.nn as nn
import torch
#----------创建文件夹和配置一些参数----------
# 创建文件夹
os.makedirs("./images/gan/",exist_ok=True)
os.makedirs("./save/gan/",exist_ok=True)
os.makedirs("./datasets/mnist/",exist_ok=True)

#超参数配置-->超参数就是我们可以设置的
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs",type = int,default = 50,help = "number of epochs of training" )
parser.add_argument("--batch_size",type = int,default = 2,help="size of the batches")
parser.add_argument("--lr",type = float,default = 0.0002,help="adam: learning rate")
parser.add_argument("--b1",type = float,default = 0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2",type = float,default = 0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu",type = int, default = 2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim",type = int,default = 100, help="dimensionality of the latent space")
parser.add_argument("--img_size",type = int,default = 28, help="size of each image dimension")
parser.add_argument("--channels", type = int,default = 1, help="number of image channels")
parser.add_argument("--sample_interval",type = int,default = 500, help="interval betwen image samples")
opt = parser.parse_known_args()[0]

# print(opt)输出结果如下
#Namespace(b1=0.5, b2=0.999, batch_size=2, channels=1, img_size=28, latent_dim=100, lr=0.0002, n_cpu=2, n_epochs=50, sample_interval=500)
#----------下载数据并将数据份数分好----------
#(1,28,28)
img_shape = (opt.channels,opt.img_size,opt.img_size)
print(img_shape)
#计算数组所有元素的乘积
img_area = np.prod(img_shape)
print(img_area)
#cuda 查看cuda是否可用
cuda = True if torch.cuda.is_available() else False

#mnist 数据集下载并对数据做一些处理
mnist = datasets.MNIST(root = "./datasets",train = True, download = True, transform = transforms.Compose(
                        [transforms.Resize(opt.img_size),
                         transforms.ToTensor(),
                         transforms.Normalize([0.5],[0.5])]))
#加载器 分批次加载数据集,
#将数据分成len(dataloader)/batchsize【不是很严谨】份送入网络
dataloader = DataLoader(
    mnist,
    batch_size = opt.batch_size,
    shuffle = True)
#----------判别器----------
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area,512),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(256,1),
            nn.Sigmoid()
            
        )
    def forward(self,img):
        img_flat = img.view(img.size(0),-1)
        validity = self.model(img_flat)
        print(validity)
        return validity
#----------生成器----------
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__( )
        def block(in_feat,out_feat,normalize = True):
            layers = [nn.Linear(in_feat,out_feat)]
            if  normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers
        self.model = nn.Sequential(
          # *操作符被称为解包操作符
          # 这里可以理解成将block中的层一个一个的写到了Sequential中
            *block(opt.latent_dim,128,normalize = False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,img_area),
            nn.Tanh()
        )
    def forward(self,z):
        imgs = self.model(z)
        #imgs.view的意思,
        #(2, batch_size * channels * height * width)
        # imgs.size()  torch.Size([2, 1, 28, 28])
        imgs = imgs.view(imgs.size(0),*img_shape)
        #imgs.size(0)就是2
        #将imgs从一维向量重新reshape为合适的图像形状img_shape
        return imgs 
        
# 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()
#loss
criterion = torch.nn.BCELoss()
#youhua
optimizer_G = torch.optim.Adam(generator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))

#有cuda就在cuda上运行
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()
# training
for epoch in range(opt.n_epochs):
    # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    for i,(imgs,_) in enumerate(dataloader):
        imgs = imgs.view(imgs.size(0),-1)
        #使用cuda,就在后边加个.cuda()
        #real_img = Variable(imgs).cuda()
        real_img = Variable(imgs)
        # imgs.size()  torch.Size([2, 1, 28, 28])
        #2个三维数组,1个二维数组,28个一维数组
        #real_label = Variable(torch.ones(imgs.size(0),1)).cuda()#全1
        #fake_label = Variable(torch.zeros(imgs.size(0),1)).cuda()#全0
        #-----------------重点-----------------
        #-为什么real_label是全1呢?fake_label全为0呢?-
        #------------这就是gan的原理啦------------
        #在生成对抗网络(GAN)中,判别器的目标是将真实样本判别为1,
        #将生成的假样本判别为0。
        real_label = Variable(torch.ones(imgs.size(0),1))
        fake_label = Variable(torch.zeros(imgs.size(0),1))
        ## ---------------------
        ##  Train Discriminator
        ## 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        ## ---------------------
        ## 计算真实图片的损失
        real_out = discriminator(real_img)
        #print(real_img.view(real_img.size(0),-1))
        loss_real_D = criterion(real_out,real_label)
        real_scores = real_out
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        #  detach还没很理解
        #z = Variable(torch.randn(imgs.size(0),opt.latent_dim)).cuda()
        z = Variable(torch.randn(imgs.size(0),opt.latent_dim))
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out,fake_label)
        fake_scores = fake_out
        #损失函数和优化
        loss_D = loss_real_D + loss_fake_D
        
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        # ---------------------
        # Train Generator
        # ---------------------
        #z = Variable(torch.randn(imgs.size(0),opt.latent_dim)).cuda()
        z = Variable(torch.randn(imgs.size(0),opt.latent_dim))
        fake_img = generator(z)
        output = discriminator(fake_img)
        #损失函数和优化
        loss_G = criterion(output,real_label)
        
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if(i+1) % 100 == 0:
            print(
                "[Epoch %d/%d]  [Batch %d/%d]  [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                %(epoch,opt.n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),real_scores.data.mean(),fake_scores.data.mean())
            )
        # 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(fake_img.data[:25],"./images/gan/%d.png"%batches_done,nrow = 5,normalize = True)
# 保存模型
#将生成器和判别器的网络权重保存起来
torch.save(generator.state_dict(),"./save/gan/generator.pth")
torch.save(discriminator.state_dict(),"./save/gan/discriminator.pth")


相关文章
用MASM32按Time Protocol(RFC868)协议编写网络对时程序中的一些有用的函数代码
用MASM32按Time Protocol(RFC868)协议编写网络对时程序中的一些有用的函数代码
|
1月前
|
机器学习/深度学习 网络架构 计算机视觉
目标检测笔记(一):不同模型的网络架构介绍和代码
这篇文章介绍了ShuffleNetV2网络架构及其代码实现,包括模型结构、代码细节和不同版本的模型。ShuffleNetV2是一个高效的卷积神经网络,适用于深度学习中的目标检测任务。
70 1
目标检测笔记(一):不同模型的网络架构介绍和代码
|
18天前
|
编解码 安全 Linux
网络空间安全之一个WH的超前沿全栈技术深入学习之路(10-2):保姆级别教会你如何搭建白帽黑客渗透测试系统环境Kali——Liinux-Debian:就怕你学成黑客啦!)作者——LJS
保姆级别教会你如何搭建白帽黑客渗透测试系统环境Kali以及常见的报错及对应解决方案、常用Kali功能简便化以及详解如何具体实现
|
18天前
|
安全 网络协议 算法
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-1):主动信息收集之ping、Nmap 就怕你学成黑客啦!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-1):主动信息收集之ping、Nmap 就怕你学成黑客啦!
|
18天前
|
网络协议 安全 NoSQL
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-2):scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练、就怕你学成黑客啦!
scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练等具体操作详解步骤;精典图示举例说明、注意点及常见报错问题所对应的解决方法IKUN和I原们你这要是学不会我直接退出江湖;好吧!!!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-2):scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练、就怕你学成黑客啦!
|
18天前
|
网络协议 安全 算法
网络空间安全之一个WH的超前沿全栈技术深入学习之路(9):WireShark 简介和抓包原理及实战过程一条龙全线分析——就怕你学成黑客啦!
实战:WireShark 抓包及快速定位数据包技巧、使用 WireShark 对常用协议抓包并分析原理 、WireShark 抓包解决服务器被黑上不了网等具体操作详解步骤;精典图示举例说明、注意点及常见报错问题所对应的解决方法IKUN和I原们你这要是学不会我直接退出江湖;好吧!!!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(9):WireShark 简介和抓包原理及实战过程一条龙全线分析——就怕你学成黑客啦!
|
2月前
|
监控 网络协议 Linux
网络学习
网络学习
146 68
|
1月前
|
存储 安全 网络安全
浅谈网络安全的认识与学习规划
浅谈网络安全的认识与学习规划
31 6
|
18天前
|
人工智能 安全 Linux
网络空间安全之一个WH的超前沿全栈技术深入学习之路(4-2):渗透测试行业术语扫盲完结:就怕你学成黑客啦!)作者——LJS
网络空间安全之一个WH的超前沿全栈技术深入学习之路(4-2):渗透测试行业术语扫盲完结:就怕你学成黑客啦!)作者——LJS

热门文章

最新文章