GAN网络的代码实现(学习ing)

简介: GAN网络的代码实现(学习ing)

感悟:在学习某个网络的时候,一味的看概念越看越不理解,此时我们可以找一份源代码,试着写一遍,也许就会明白了。

#----------引入需要的库----------
import argparse #配置超参数的库
import os#用来创建文件夹
import numpy as np
#对数据进行一些操作
import torchvision.transforms as transforms
#保存图片
from torchvision.utils import save_image
#数据加载器
from torch.utils.data import DataLoader
#加载数据
from torchvision import datasets
#在旧版本的PyTorch中,Variable 类被用于包装张量,
#并自动跟踪对该张量的所有操作,从而支持自动计算梯度。
#它是PyTorch自动微分系统的核心组件。
from torch.autograd import Variable
import torch.nn as nn
import torch
#----------创建文件夹和配置一些参数----------
# 创建文件夹
os.makedirs("./images/gan/",exist_ok=True)
os.makedirs("./save/gan/",exist_ok=True)
os.makedirs("./datasets/mnist/",exist_ok=True)

#超参数配置-->超参数就是我们可以设置的
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs",type = int,default = 50,help = "number of epochs of training" )
parser.add_argument("--batch_size",type = int,default = 2,help="size of the batches")
parser.add_argument("--lr",type = float,default = 0.0002,help="adam: learning rate")
parser.add_argument("--b1",type = float,default = 0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2",type = float,default = 0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu",type = int, default = 2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim",type = int,default = 100, help="dimensionality of the latent space")
parser.add_argument("--img_size",type = int,default = 28, help="size of each image dimension")
parser.add_argument("--channels", type = int,default = 1, help="number of image channels")
parser.add_argument("--sample_interval",type = int,default = 500, help="interval betwen image samples")
opt = parser.parse_known_args()[0]

# print(opt)输出结果如下
#Namespace(b1=0.5, b2=0.999, batch_size=2, channels=1, img_size=28, latent_dim=100, lr=0.0002, n_cpu=2, n_epochs=50, sample_interval=500)
#----------下载数据并将数据份数分好----------
#(1,28,28)
img_shape = (opt.channels,opt.img_size,opt.img_size)
print(img_shape)
#计算数组所有元素的乘积
img_area = np.prod(img_shape)
print(img_area)
#cuda 查看cuda是否可用
cuda = True if torch.cuda.is_available() else False

#mnist 数据集下载并对数据做一些处理
mnist = datasets.MNIST(root = "./datasets",train = True, download = True, transform = transforms.Compose(
                        [transforms.Resize(opt.img_size),
                         transforms.ToTensor(),
                         transforms.Normalize([0.5],[0.5])]))
#加载器 分批次加载数据集,
#将数据分成len(dataloader)/batchsize【不是很严谨】份送入网络
dataloader = DataLoader(
    mnist,
    batch_size = opt.batch_size,
    shuffle = True)
#----------判别器----------
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area,512),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(256,1),
            nn.Sigmoid()
            
        )
    def forward(self,img):
        img_flat = img.view(img.size(0),-1)
        validity = self.model(img_flat)
        print(validity)
        return validity
#----------生成器----------
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__( )
        def block(in_feat,out_feat,normalize = True):
            layers = [nn.Linear(in_feat,out_feat)]
            if  normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers
        self.model = nn.Sequential(
          # *操作符被称为解包操作符
          # 这里可以理解成将block中的层一个一个的写到了Sequential中
            *block(opt.latent_dim,128,normalize = False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,img_area),
            nn.Tanh()
        )
    def forward(self,z):
        imgs = self.model(z)
        #imgs.view的意思,
        #(2, batch_size * channels * height * width)
        # imgs.size()  torch.Size([2, 1, 28, 28])
        imgs = imgs.view(imgs.size(0),*img_shape)
        #imgs.size(0)就是2
        #将imgs从一维向量重新reshape为合适的图像形状img_shape
        return imgs 
        
# 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()
#loss
criterion = torch.nn.BCELoss()
#youhua
optimizer_G = torch.optim.Adam(generator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))

#有cuda就在cuda上运行
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()
# training
for epoch in range(opt.n_epochs):
    # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    for i,(imgs,_) in enumerate(dataloader):
        imgs = imgs.view(imgs.size(0),-1)
        #使用cuda,就在后边加个.cuda()
        #real_img = Variable(imgs).cuda()
        real_img = Variable(imgs)
        # imgs.size()  torch.Size([2, 1, 28, 28])
        #2个三维数组,1个二维数组,28个一维数组
        #real_label = Variable(torch.ones(imgs.size(0),1)).cuda()#全1
        #fake_label = Variable(torch.zeros(imgs.size(0),1)).cuda()#全0
        #-----------------重点-----------------
        #-为什么real_label是全1呢?fake_label全为0呢?-
        #------------这就是gan的原理啦------------
        #在生成对抗网络(GAN)中,判别器的目标是将真实样本判别为1,
        #将生成的假样本判别为0。
        real_label = Variable(torch.ones(imgs.size(0),1))
        fake_label = Variable(torch.zeros(imgs.size(0),1))
        ## ---------------------
        ##  Train Discriminator
        ## 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        ## ---------------------
        ## 计算真实图片的损失
        real_out = discriminator(real_img)
        #print(real_img.view(real_img.size(0),-1))
        loss_real_D = criterion(real_out,real_label)
        real_scores = real_out
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        #  detach还没很理解
        #z = Variable(torch.randn(imgs.size(0),opt.latent_dim)).cuda()
        z = Variable(torch.randn(imgs.size(0),opt.latent_dim))
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out,fake_label)
        fake_scores = fake_out
        #损失函数和优化
        loss_D = loss_real_D + loss_fake_D
        
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        # ---------------------
        # Train Generator
        # ---------------------
        #z = Variable(torch.randn(imgs.size(0),opt.latent_dim)).cuda()
        z = Variable(torch.randn(imgs.size(0),opt.latent_dim))
        fake_img = generator(z)
        output = discriminator(fake_img)
        #损失函数和优化
        loss_G = criterion(output,real_label)
        
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if(i+1) % 100 == 0:
            print(
                "[Epoch %d/%d]  [Batch %d/%d]  [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                %(epoch,opt.n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),real_scores.data.mean(),fake_scores.data.mean())
            )
        # 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(fake_img.data[:25],"./images/gan/%d.png"%batches_done,nrow = 5,normalize = True)
# 保存模型
#将生成器和判别器的网络权重保存起来
torch.save(generator.state_dict(),"./save/gan/generator.pth")
torch.save(discriminator.state_dict(),"./save/gan/discriminator.pth")


相关文章
|
8天前
|
Kubernetes 应用服务中间件 Docker
Kubernetes学习-集群搭建篇(二) 部署Node服务,启动JNI网络插件
Kubernetes学习-集群搭建篇(二) 部署Node服务,启动JNI网络插件
|
8天前
|
存储 算法 Windows
课程视频|R语言bnlearn包:贝叶斯网络的构造及参数学习的原理和实例(下)
课程视频|R语言bnlearn包:贝叶斯网络的构造及参数学习的原理和实例
|
1天前
|
机器学习/深度学习 编解码 算法
YOLOv5改进 | 主干网络 | 用EfficientNet卷积替换backbone【教程+代码 】
在YOLOv5的GFLOPs计算量中,卷积占了其中大多数的比列,为了减少计算量,研究人员提出了用EfficientNet代替backbone。本文给大家带来的教程是**将原来的主干网络替换为EfficientNet。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。
|
4天前
|
机器学习/深度学习 编解码 算法
YOLOv5改进 | 主干网络 | 将backbone替换为MobileNetV3【小白必备教程+附完整代码】
本文介绍了将YOLOv5的backbone替换为MobileNetV3以提升目标检测性能的教程。MobileNetV3采用倒残差结构、Squeeze-and-Excitation模块和Hard-Swish激活函数,实现更高性能和更低计算成本。文中提供了详细的代码实现,包括MobileNetV3的关键组件和YOLOv5的配置修改,便于读者实践。此外,还分享了完整代码链接和进一步的进阶策略,适合深度学习初学者和进阶者学习YOLO系列。
|
7天前
|
机器学习/深度学习 数据可视化 PyTorch
使用Python实现深度学习模型:生成对抗网络(GAN)
使用Python实现深度学习模型:生成对抗网络(GAN)
21 3
|
8天前
|
机器学习/深度学习 算法 网络架构
什么是神经网络学习中的反向传播算法?
什么是神经网络学习中的反向传播算法?
12 2
|
8天前
|
机器学习/深度学习 编解码 自然语言处理
深度学习500问——Chapter07:生成对抗网络(GAN)(3)
深度学习500问——Chapter07:生成对抗网络(GAN)(3)
21 0
|
8天前
|
机器学习/深度学习 JavaScript Linux
深度学习500问——Chapter07:生成对抗网络(GAN)(2)
深度学习500问——Chapter07:生成对抗网络(GAN)(2)
21 0
|
8天前
|
机器学习/深度学习 JavaScript 算法
深度学习500问——Chapter07:生成对抗网络(GAN)(1)
深度学习500问——Chapter07:生成对抗网络(GAN)(1)
31 3
|
8天前
|
网络安全 数据安全/隐私保护 计算机视觉
2024蓝桥杯网络安全-图片隐写-缺失的数据(0基础也能学会-含代码解释)
2024蓝桥杯网络安全-图片隐写-缺失的数据(0基础也能学会-含代码解释)

热门文章

最新文章