【Pytorch】(九)生成对抗网络(GAN)

简介:

@[toc]

生成对抗网络(GAN)

在这里插入图片描述
图片来源:

生成序列[2]

实现:输入[0.5],输出[1, 0, 1, 1, 0, 0, 1, 0]

from torch import nn
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

torch.manual_seed(1)
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(8, 4),
            nn.Sigmoid(),
            nn.Linear(4, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.model(inputs)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 4),
            nn.Sigmoid(),
            nn.Linear(4, 8),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.model(inputs)


def train():
    D = Discriminator()
    G = Generator()
    loss_func = nn.BCELoss() # - [p * log(q) + (1-p) * log(1-q)]
    optimiser_D = torch.optim.Adam(D.parameters(), lr=1e-3)
    optimiser_G = torch.optim.Adam(G.parameters(), lr=1e-3)
    output_list = []
    for i in range(1000):
        # 1 训练判别器
        # 1.1 真实数据real_data输入D,得到d_real
        real_data = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0], dtype=torch.float)
        d_real = D(real_data)
        
        # 1.2 生成数据的输出fake_data输入D,得到d_fake
        fake_data = G(torch.tensor([0.5])).detach()  # detach:只更新判别器的参数
        d_fake = D(fake_data)
        
        # 1.3 计算损失值 ,判别器学习使得d_real->1、d_fake->0
        loss_d_real = loss_func(d_real, torch.tensor([1.0]))
        loss_d_fack = loss_func(d_fake, torch.tensor([0.]))
        loss_d = loss_d_real + loss_d_fack
        
        # 1.4 反向传播更新参数
        optimiser_D.zero_grad()
        loss_d.backward()
        optimiser_D.step()

        # 2 训练生成器
        # 2.1 G的输出fake_data输入D,得到d_g_fake
        fake_data = G(torch.tensor([0.5]))  
        d_g_fake = D(fake_data)
        # 2.2 计算损失值,生成器学习使得d_g_fake->1
        loss_g = loss_func(d_g_fake, torch.tensor([1.0]))
        
        # 2.3 反向传播,优化
        optimiser_G.zero_grad()
        loss_g.backward()
        optimiser_G.step()
        
        if i % 100 == 0:
            output_list.append(fake_data.detach().numpy())
            print(fake_data.detach().numpy())
            
    plt.imshow(np.array(output_list), cmap='Blues')
    plt.colorbar()
    plt.xticks([])
    plt.ylabel('迭代次数X100')
    plt.show()


if __name__ == '__main__':
    train()


在这里插入图片描述

[1, 0, 1, 1, 0, 0, 1, 0]

GAN生成服从正态分布的数据3

实现:输入大小为[1,16]的随机数,生成大小为[1,512]且服从正态分布的数据。

import numpy as np
from matplotlib import pyplot as plt
import matplotlib
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(1)
np.random.seed(1)
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(16, 128),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(128, 256),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(256, 512)
        )

    def forward(self, inputs):
        return self.model(inputs)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(512, 256),
            nn.Tanh(),
            nn.Linear(256, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.model(inputs)


def normal_pdf(x, mu, sigma):
    ''' 正态分布概率密度函数'''
    pdf = np.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))
    return pdf


def draw(G, epoch, g_input_size):
    '''画目标分布和生成分布'''
    plt.clf()
    # 画出目标分布
    x = np.arange(-3, 9, 0.2)
    y = normal_pdf(x, 3, 1)
    plt.plot(x, y, 'r', linewidth=2)

    # 画出生成的分布
    test_data = torch.rand(1, g_input_size)
    data = G(test_data).detach().numpy()
    mean = data.mean()
    std = data.std()
    x = np.arange(np.floor(data.min()) - 5, np.ceil(data.max()) + 5, 0.2)
    y = normal_pdf(x, mean, std)
    plt.plot(x, y, 'orange', linewidth=2)
    plt.hist(data.flatten(), bins=20, color='y', alpha=0.5, rwidth=0.9, density=True)


    plt.legend(['目标分布', '生成分布'])
    plt.title('GAN:epoch' + str(epoch))
    plt.show()
    plt.pause(0.1)


def train():
    
    G_mean = []
    G_std = [] # 用于记录生成器生成的数据的均值和方差
    data_mean = 3
    data_std = 1  # 目标分布的均值和方差
    batch_size = 64
    g_input_size = 16
    g_output_size = 512
    
    epochs = 1001
    d_epoch = 1  # 每个epoch判别器的训练轮数

    # 初始化网络
    D = Discriminator()
    G = Generator()

    # 初始化优化器和损失函数
    d_learning_rate = 0.01
    g_learning_rate = 0.001
    loss_func = nn.BCELoss()  # - [p * log(q) + (1-p) * log(1-q)]
    optimiser_D = optim.Adam(D.parameters(), lr=d_learning_rate)
    optimiser_G = optim.Adam(G.parameters(), lr=g_learning_rate)

    plt.ion()
    for epoch in range(epochs):
        G.train()
        # 1 训练判别器d_steps次
        for _ in range(d_epoch):
            # 1.1 真实数据real_data输入D,得到d_real
            real_data = torch.tensor(np.random.normal(data_mean, data_std, (batch_size, g_output_size)),
                                     dtype=torch.float)
            d_real = D(real_data)
            # 1.2 生成数据的输出fake_data输入D,得到d_fake
            g_input = torch.rand(batch_size, g_input_size)
            fake_data = G(g_input).detach()  # detach:只更新判别器的参数
            d_fake = D(fake_data)

            # 1.3 计算损失值 ,判别器学习使得d_real->1、d_fake->0
            loss_d_real = loss_func(d_real, torch.ones([batch_size, 1]))
            loss_d_fake = loss_func(d_fake, torch.zeros([batch_size, 1]))
            d_loss = loss_d_real + loss_d_fake

            # 1.4 反向传播,优化
            optimiser_D.zero_grad()
            d_loss.backward()
            optimiser_D.step()

        # 2 训练生成器
        # 2.1 G输入g_input,输出fake_data。fake_data输入D,得到d_g_fake
        g_input = torch.rand(batch_size, g_input_size)
        fake_data = G(g_input)
        d_g_fake = D(fake_data)

        # 2.2 计算损失值,生成器学习使得d_g_fake->1
        loss_G = loss_func(d_g_fake, torch.ones([batch_size, 1]))

        # 2.3 反向传播,优化
        optimiser_G.zero_grad()
        loss_G.backward()
        optimiser_G.step()
        # 2.4 记录生成器输出的均值和方差
        G_mean.append(fake_data.mean().item())
        G_std.append(fake_data.std().item())

        if epoch % 10 == 0:
            print("Epoch: {}, 生成数据的均值: {}, 生成数据的标准差: {}".format(epoch, G_mean[-1], G_std[-1]))
            print('-' * 10)
            G.eval()
            draw(G, epoch, g_input_size)

    plt.ioff()
    plt.show()
    plt.plot(G_mean)
    plt.title('均值')
    plt.savefig('gan_mean.jpg')
    plt.show()

    plt.plot(G_std)
    plt.title('标准差')
    plt.savefig('gan_std.jpg')
    plt.show()

if __name__ == '__main__':
    train()

https://www.bilibili.com/video/av383651076/?zw

参考:

[1] Generative adversarial nets, https://arxiv.org/abs/1406.2661

[2]PyTorch生成对抗网络编程。作者:Tariq Rashid;译者:韩江雷;出版社:人民邮电出版社

[3]https://mathpretty.com/10808.html

[4]https://www.jianshu.com/p/7d3a17f00312

[5] 视频BGM(人工智能AIVA作曲):Octet No. 1 in D Major, Op. 3: A Little Chamber Music

重点在【Pytorch】(十),未完待续。

相关文章
|
6月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
397 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络
|
4月前
|
机器学习/深度学习 算法 PyTorch
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
深度学习近年来在多个领域取得了显著进展,但其核心组件——人工神经元和反向传播算法自提出以来鲜有根本性突破。穿孔反向传播(Perforated Backpropagation)技术通过引入“树突”机制,模仿生物神经元的计算能力,实现了对传统神经元的增强。该技术利用基于协方差的损失函数训练树突节点,使其能够识别神经元分类中的异常模式,从而提升整体网络性能。实验表明,该方法不仅可提高模型精度(如BERT模型准确率提升3%-17%),还能实现高效模型压缩(参数减少44%而无性能损失)。这一革新为深度学习的基础构建模块带来了新的可能性,尤其适用于边缘设备和大规模模型优化场景。
124 16
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
|
4月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。
|
4月前
|
机器学习/深度学习 自然语言处理 PyTorch
基于Pytorch Gemotric在昇腾上实现GAT图神经网络
本实验基于昇腾平台,使用PyTorch实现图神经网络GAT(Graph Attention Networks)在Pubmed数据集上的分类任务。内容涵盖GAT网络的创新点分析、图注意力机制原理、多头注意力机制详解以及模型代码实战。实验通过两层GAT网络对Pubmed数据集进行训练,验证模型性能,并展示NPU上的内存使用情况。最终,模型在测试集上达到约36.60%的准确率。
|
7月前
|
机器学习/深度学习 数据可视化 算法
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
神经常微分方程(Neural ODEs)是深度学习领域的创新模型,将神经网络的离散变换扩展为连续时间动力系统。本文基于Torchdyn库介绍Neural ODE的实现与训练方法,涵盖数据集构建、模型构建、基于PyTorch Lightning的训练及实验结果可视化等内容。Torchdyn支持多种数值求解算法和高级特性,适用于生成模型、时间序列分析等领域。
346 77
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
|
4月前
|
算法 PyTorch 算法框架/工具
PyTorch 实现FCN网络用于图像语义分割
本文详细讲解了在昇腾平台上使用PyTorch实现FCN(Fully Convolutional Networks)网络在VOC2012数据集上的训练过程。内容涵盖FCN的创新点分析、网络架构解析、代码实现以及端到端训练流程。重点包括全卷积结构替换全连接层、多尺度特征融合、跳跃连接和反卷积操作等技术细节。通过定义VOCSegDataset类处理数据集,构建FCN8s模型并完成训练与测试。实验结果展示了模型在图像分割任务中的应用效果,同时提供了内存使用优化的参考。
|
4月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本实验基于PyTorch Geometric,在昇腾平台上实现GraphSAGE图神经网络,使用CiteSeer数据集进行分类训练。内容涵盖GraphSAGE的创新点、算法原理、网络架构及实战分析。GraphSAGE通过采样和聚合节点邻居特征,支持归纳式学习,适用于未见节点的表征生成。实验包括模型搭建、训练与验证,并在NPU上运行,最终测试准确率达0.665。
|
8月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
415 66
|
6月前
|
机器学习/深度学习 数据采集 编解码
基于DeepSeek的生成对抗网络(GAN)在图像生成中的应用
生成对抗网络(GAN)通过生成器和判别器的对抗训练,生成高质量的合成数据,在图像生成等领域展现巨大潜力。DeepSeek作为高效深度学习框架,提供便捷API支持GAN快速实现和优化。本文详细介绍基于DeepSeek的GAN技术,涵盖基本原理、实现步骤及代码示例,展示其在图像生成中的应用,并探讨优化与改进方法,如WGAN、CGAN等,解决模式崩溃、训练不稳定等问题。最后,总结GAN在艺术创作、数据增强、图像修复等场景的应用前景。
664 16

热门文章

最新文章

推荐镜像

更多