深度学习第5天:GAN生成对抗网络

简介: 深度学习第5天:GAN生成对抗网络

一、GAN

1.基本思想

想象一下,市面上有许多仿制的画作,人们为了辨别这些伪造的画,就会提高自己的鉴别技能,然后仿制者为了躲过鉴别又会提高自己的伪造技能,这样反反复复,两个群体的技能不断得到提高,这就是GAN的基本思想

2.用途

我们知道GAN的全名是生成对抗网络,那么它就是以生成为主要任务,所以可以用在这些方面

  • 生成虚拟数据集,当数据集数量不够时,我们可以用这种方法生成数据
  • 图像清晰化,可以将模糊图片清晰化
  • 文本到图像的生成,可以训练文生图模型

GAN的用途还有很多,可以在学习过程中慢慢发现

3.模型架构

GAN的主要结构包含一个生成器和一个判别器,我们先输入一堆杂乱数据(被称为噪声)给生成器,接着让判别器将生成器生成的数据与真实的数据作对比,看是否能判别出来,以此往复训练

二、具体任务与代码

1.任务介绍

相信很多人都对手写数字数据集不陌生了,那我们就训练一个生成手写数字的GAN,注意:本示例代码需要的运行时间较长,请在高配置设备上运行或者减少训练回合数

2.导入库函数

先导入必要的库函数,包括torch用来处理神经网络方面的任务,numpy用来处理数据

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import numpy as np

3.生成器与判别器

使用torch定义生成器与判别器的基本结构,这里由于任务比较简单,只用定义线性层就行,再给线性层添加相应的激活函数就行了

# 定义生成器(Generator)和判别器(Discriminator)的简单网络结构
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh()
        )
    def forward(self, noise):
        return self.model(noise)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, image):
        return self.model(image)

4.预处理

这一部分定义了模型参数,加载了数据集,定义了损失函数与优化器,这些是神经网络训练时的一些基本参数

# 定义一些参数
batch_size = 100
learning_rate = 0.0002
epochs = 500
# 加载MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)
# 初始化生成器、判别器和优化器
generator = Generator()
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
# 损失函数
criterion = nn.BCELoss()

5.模型训练

这一部分开始训练模型,通过反向传播逐步调整模型的参数,注意模型训练的过程,观察生成器和判别器分别是怎么在训练中互相作用不断提高的

# 训练 GAN
for epoch in range(epochs):
    for data, _ in data_loader:
        data = data.view(data.size(0), -1)
        real_data = Variable(data)
        target_real = Variable(torch.ones(data.size(0), 1))
        target_fake = Variable(torch.zeros(data.size(0), 1))
        # 训练判别器
        optimizer_D.zero_grad()
        output_real = discriminator(real_data)
        loss_real = criterion(output_real, target_real)
        loss_real.backward()
        noise = Variable(torch.randn(data.size(0), 100))
        fake_data = generator(noise)
        output_fake = discriminator(fake_data.detach())
        loss_fake = criterion(output_fake, target_fake)
        loss_fake.backward()
        optimizer_D.step()
        # 训练生成器
        optimizer_G.zero_grad()
        output = discriminator(fake_data)
        loss_G = criterion(output, target_real)
        loss_G.backward()
        optimizer_G.step()
    print(f'Epoch [{epoch+1}/{epochs}], Loss D: {loss_real.item()+loss_fake.item()}, Loss G: {loss_G.item()}')

6.图片生成

这一部分再一次随机生成了一些噪声,并把他们传入生成器生成图片,其中包含一些格式转化过程,再通过matplotlib绘图库显示结果

# 生成一些图片
num_samples = 16
noise = Variable(torch.randn(num_samples, 100))
generated_samples = generator(noise)
generated_samples = generated_samples.view(num_samples, 1, 28, 28).detach()
import matplotlib.pyplot as plt
import torchvision.utils as vutils
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Generated Images")
plt.imshow(
    np.transpose(
        vutils.make_grid(generated_samples, nrow=4, padding=2, normalize=True).cpu(), (1, 2, 0)
    )
)
plt.show()

7.不同训练轮次的结果对比

感谢阅读,觉得有用的话就订阅下《深度学习》专栏吧,有错误也欢迎指出

相关文章
|
1月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
眼疾识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了4种常见的眼疾图像数据集(白内障、糖尿病性视网膜病变、青光眼和正常眼睛) 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,实现用户上传一张眼疾图片识别其名称。
135 5
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
|
2月前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
356 55
|
15天前
|
机器学习/深度学习 数据可视化 算法
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
神经常微分方程(Neural ODEs)是深度学习领域的创新模型,将神经网络的离散变换扩展为连续时间动力系统。本文基于Torchdyn库介绍Neural ODE的实现与训练方法,涵盖数据集构建、模型构建、基于PyTorch Lightning的训练及实验结果可视化等内容。Torchdyn支持多种数值求解算法和高级特性,适用于生成模型、时间序列分析等领域。
162 77
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
|
7天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
51 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
1月前
|
机器学习/深度学习 监控 算法
基于yolov4深度学习网络的排队人数统计系统matlab仿真,带GUI界面
本项目基于YOLOv4深度学习网络,利用MATLAB 2022a实现排队人数统计的算法仿真。通过先进的计算机视觉技术,系统能自动、准确地检测和统计监控画面中的人数,适用于银行、车站等场景,优化资源分配和服务管理。核心程序包含多个回调函数,用于处理用户输入及界面交互,确保系统的高效运行。仿真结果无水印,操作步骤详见配套视频。
54 18
|
2月前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
220 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
2月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于yolov4深度学习网络的公共场所人流密度检测系统matlab仿真,带GUI界面
本项目使用 MATLAB 2022a 进行 YOLOv4 算法仿真,实现公共场所人流密度检测。通过卷积神经网络提取图像特征,将图像划分为多个网格进行目标检测和识别,最终计算人流密度。核心程序包括图像和视频读取、处理和显示功能。仿真结果展示了算法的有效性和准确性。
88 31
|
2月前
|
机器学习/深度学习 算法 信息无障碍
基于GoogleNet深度学习网络的手语识别算法matlab仿真
本项目展示了基于GoogleNet的深度学习手语识别算法,使用Matlab2022a实现。通过卷积神经网络(CNN)识别手语手势,如"How are you"、"I am fine"、"I love you"等。核心在于Inception模块,通过多尺度处理和1x1卷积减少计算量,提高效率。项目附带完整代码及操作视频。
|
2月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于深度学习网络的宝石类型识别算法matlab仿真
本项目利用GoogLeNet深度学习网络进行宝石类型识别,实验包括收集多类宝石图像数据集并按7:1:2比例划分。使用Matlab2022a实现算法,提供含中文注释的完整代码及操作视频。GoogLeNet通过其独特的Inception模块,结合数据增强、学习率调整和正则化等优化手段,有效提升了宝石识别的准确性和效率。
|
2月前
|
机器学习/深度学习 人工智能 自然语言处理
深入理解深度学习中的卷积神经网络(CNN)##
在当今的人工智能领域,深度学习已成为推动技术革新的核心力量之一。其中,卷积神经网络(CNN)作为深度学习的一个重要分支,因其在图像和视频处理方面的卓越性能而备受关注。本文旨在深入探讨CNN的基本原理、结构及其在实际应用中的表现,为读者提供一个全面了解CNN的窗口。 ##

热门文章

最新文章