【Pytorch】(十)生成对抗网络之WGAN,WGAN-GP

简介:

WGAN,WGAN-GP

原理

GAN有多种解释,这里我总结一下:

原始论文解读

https://zhuanlan.zhihu.com/p/25071913


(苏神专场)
互怼的艺术:从零直达WGAN-GP

https://spaces.ac.cn/archives/4439

从Wasserstein距离、对偶理论到WGAN

https://spaces.ac.cn/archives/6280

动力学角度

https://spaces.ac.cn/archives/6583

能量视角下的GAN模型

https://kexue.fm/archives/6316

https://kexue.fm/archives/6331

https://kexue.fm/archives/6612


几何角度

A Geometric View of Optimal Transportation and Generative Model ,https://arxiv.org/abs/1710.05488

我之前尝试着看懂这篇论文,发现需要懂最优传输理论。

然后我就找了一些最优传输的资料(感兴趣的可以在公众号后台回复CHSH获取):


Computational Optimal Transport 这本书华东师范大学的王祥丰老师正在翻译。https://zhuanlan.zhihu.com/p/499401130

又发现没有学过测度论很难读懂。( 限制人学习自由的永远是数学,划线以下是另一个境界:
在这里插入图片描述

不过,我找到了一篇不用测度论解析的论文(工科生狂喜):

https://sci-hub.st/10.1109/msp.2017.2695801

看完能大概知道最优传输是干什么的,以及这个理论的奠基者蒙日(Monge)和康托罗维奇(Kantorovich)做了什么。

Pytorch实现:生成正态分布数据

理论很难,实现倒是不难。

WGAN

在这里插入图片描述
图片来源:[1]

import numpy as np
from matplotlib import pyplot as plt
import matplotlib
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(1)
np.random.seed(1)
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(16, 128),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(128, 256),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(256, 512)
        )

    def forward(self, inputs):
        return self.model(inputs)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(512, 256),
            nn.Tanh(),
            nn.Linear(256, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
            # nn.Sigmoid()  # 去掉
        )

    def forward(self, inputs):
        return self.model(inputs)


def normal_pdf(x, mu, sigma):
    '''# 正态分布,概率密度函数'''
    pdf = np.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))
    return pdf


def draw(G, epoch, g_input_size):
    '''画目标分布和生成分布'''
    plt.clf()
    # 画出目标分布
    x = np.arange(-3, 9, 0.2)
    y = normal_pdf(x, 3, 1)
    plt.plot(x, y, 'r', linewidth=2)

    # 画出生成的分布
    test_data = torch.rand(1, g_input_size)
    data = G(test_data).detach().numpy()
    mean = data.mean()
    std = data.std()
    x = np.arange(np.floor(data.min()) - 5, np.ceil(data.max()) + 5, 0.2)
    y = normal_pdf(x, mean, std)
    plt.plot(x, y, 'orange', linewidth=2)
    plt.hist(data.flatten(), bins=20, color='y', alpha=0.5, rwidth=0.9, density=True)

    # 坐标图设置
    plt.legend(['目标分布', '生成分布'])
    plt.show()
    plt.pause(0.1)


def train():
    # 用于记录生成器生成的数据的均值和方差
    G_mean = []
    G_std = []

    # 目标分布的均值和方差
    data_mean = 3
    data_std = 1

    feature_num = 512
    batch_size = 64
    g_input_size = 16
    epochs = 1001
    d_epoch = 1  # 判别器的训练轮数

    # 初始化网络
    D = Discriminator()
    G = Generator()

    # 初始化优化器
    d_learning_rate = 0.01
    g_learning_rate = 0.001
    # loss_func = nn.BCELoss()

    optimiser_D = optim.RMSprop(D.parameters(), lr=d_learning_rate)
    optimiser_G = optim.RMSprop(G.parameters(), lr=g_learning_rate)

    clip_value = 0.01
    plt.ion()
    for epoch in range(epochs):
        G.train()
        # 1 训练判别器d_steps次
        for _ in range(d_epoch):
            # 1.1 真实数据real_data输入D,得到d_real
            real_data = torch.tensor(np.random.normal(data_mean, data_std, (batch_size, feature_num)),
                                     dtype=torch.float)
            d_real = D(real_data)
            # 1.2 生成数据的输出fake_data输入D,得到d_fake
            g_input = torch.rand(batch_size, g_input_size)
            fake_data = G(g_input).detach()  # detach:只更新判别器的参数
            d_fake = D(fake_data)

            # 1.3 计算损失值,最大化EM距离
            d_loss = -(d_real.mean() - d_fake.mean())

            # 1.4 反向传播,优化
            optimiser_D.zero_grad()
            d_loss.backward()
            optimiser_D.step()

            # 1.5 截断
            for p in D.parameters():
                p.data.clamp_(-clip_value, clip_value)

        # 2 训练生成器
        # 2.1 G输入g_input,输出fake_data。fake_data输入D,得到d_g_fake
        g_input = torch.rand(batch_size, g_input_size)
        fake_data = G(g_input)
        d_g_fake = D(fake_data)

        # 2.2 计算损失值,最小化EM距离
        g_loss = -d_g_fake.mean()

        # 2.3 反向传播,优化
        optimiser_G.zero_grad()
        g_loss.backward()
        optimiser_G.step()

        # 2.4 记录生成器输出的均值和方差
        G_mean.append(fake_data.mean().item())
        G_std.append(fake_data.std().item())

        if epoch % 10 == 0:
            print("Epoch: {}, 生成数据的均值: {}, 生成数据的标准差: {}".format(epoch, G_mean[-1], G_std[-1]))
            print('-' * 10)
            G.eval()
            draw(G, epoch, g_input_size)

    plt.ioff()
    plt.show()
    plt.plot(G_mean)
    plt.title('均值')
    plt.savefig('wgan_mean')
    plt.show()

    plt.plot(G_std)
    plt.title('标准差')
    plt.savefig('wgan_std')
    plt.show()


if __name__ == '__main__':
    train()


WGAN-GP

在这里插入图片描述
图片来源[2]

import numpy as np
from matplotlib import pyplot as plt
import matplotlib
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(1)
np.random.seed(1)
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(16, 128),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(128, 256),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(256, 512)
        )

    def forward(self, inputs):
        return self.model(inputs)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(512, 256),
            nn.Tanh(),
            nn.Linear(256, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
        )

    def forward(self, inputs):
        return self.model(inputs)


def cal_gradient_penalty(D, real, fake):
    # 每一个样本对应一个sigma。样本个数为64,特征数为512:[64,512]
    sigma = torch.rand(real.size(0), 1)  # [64,1]
    sigma = sigma.expand(real.size())  # [64, 512]
    # 按公式计算x_hat
    x_hat = sigma * real + (torch.tensor(1.) - sigma) * fake
    x_hat.requires_grad = True
    # 为得到梯度先计算y
    d_x_hat = D(x_hat)

    # 计算梯度,autograd.grad返回的是一个元组(梯度值,)
    gradients = torch.autograd.grad(outputs=d_x_hat, inputs=x_hat,
                                    grad_outputs=torch.ones(d_x_hat.size()),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    # 利用梯度计算出gradient penalty
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


def normal_pdf(x, mu, sigma):
    '''正态分布的概率密度函数'''
    pdf = np.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))
    return pdf



def draw(G, epoch, g_input_size):
    '''画目标分布和生成分布'''
    plt.clf()
    # 画出目标分布
    x = np.arange(-3, 9, 0.2)
    y = normal_pdf(x, 3, 1)
    plt.plot(x, y, 'r', linewidth=2)

    # 画出生成的分布
    test_data = torch.rand(1, g_input_size)
    data = G(test_data).detach().numpy()
    mean = data.mean()
    std = data.std()
    x = np.arange(np.floor(data.min()) - 5, np.ceil(data.max()) + 5, 0.2)
    y = normal_pdf(x, mean, std)
    plt.plot(x, y, 'orange', linewidth=2)
    plt.hist(data.flatten(), bins=20, color='y', alpha=0.5, rwidth=0.9, density=True)

    # 坐标图设置
    plt.legend(['目标分布', '生成分布'])

    plt.show()
    plt.pause(0.1)


def train():
    G_mean = []
    G_std = []  # 用于记录生成器生成的数据的均值和方差
    data_mean = 3
    data_std = 1  # 目标分布的均值和方差
    batch_size = 64
    g_input_size = 16
    g_output_size = 512
    epochs = 1001
    d_epoch = 1  # 每个epoch判别器的训练轮数

    # 初始化网络
    D = Discriminator()
    G = Generator()

    # 初始化优化器
    d_learning_rate = 0.01
    g_learning_rate = 0.001
    # loss_func = nn.BCELoss()
    optimiser_D = optim.Adam(D.parameters(), lr=d_learning_rate)
    optimiser_G = optim.Adam(G.parameters(), lr=g_learning_rate)

    # clip_value = 0.01
    plt.ion()
    for epoch in range(epochs):
        G.train()
        # 1 训练判别器d_steps次
        for _ in range(d_epoch):
            # 1.1 真实数据real_data输入D,得到d_real
            real_data = torch.tensor(np.random.normal(data_mean, data_std, (batch_size, g_output_size)),
                                     dtype=torch.float)
            d_real = D(real_data)
            # 1.2 生成数据的输出fake_data输入D,得到d_fake
            g_input = torch.rand(batch_size, g_input_size)
            fake_data = G(g_input).detach()  # detach:只更新判别器的参数
            d_fake = D(fake_data)

            # 1.3 计算损失值,最大化EM距离
            d_loss = -(d_real.mean() - d_fake.mean())
            gradient_penalty = cal_gradient_penalty(D, real_data, fake_data)
            d_loss = d_loss + gradient_penalty * 0.5  # lambda=0.5,这个参数对效果影响很大

            # 1.4 反向传播,优化
            optimiser_D.zero_grad()
            d_loss.backward()
            optimiser_D.step()

        # 2 训练生成器
        # 2.1 G输入g_input,输出fake_data。fake_data输入D,得到d_g_fake
        g_input = torch.rand(batch_size, g_input_size)
        fake_data = G(g_input)
        d_g_fake = D(fake_data)

        # 2.2 计算损失值,最小化EM距离
        g_loss = -d_g_fake.mean()

        # 2.3 反向传播,优化
        optimiser_G.zero_grad()
        g_loss.backward()
        optimiser_G.step()

        # 2.4 记录生成器输出的均值和方差
        G_mean.append(fake_data.mean().item())
        G_std.append(fake_data.std().item())

        if epoch % 10 == 0:
            print("Epoch: {}, 生成数据的均值: {}, 生成数据的标准差: {}".format(epoch, G_mean[-1], G_std[-1]))
            print('-' * 10)
            G.eval()
            draw(G, epoch, g_input_size)

    plt.ioff()
    plt.show()
    plt.plot(G_mean)
    plt.title('均值')
    plt.show()

    plt.plot(G_std)
    plt.title('标准差')
    plt.show()


if __name__ == '__main__':
    train()

结果对比

GAN WGAN WGAN-GP
均值
标准差

[video(video-oiH52xgA-1651718515999)(type-bilibili)(url-https://player.bilibili.com/player.html?aid=213792565)(image-https://img-blog.csdnimg.cn/img_convert/ba879105de09ad59af1af9d19231fba0.png)(title-WGAN、WGAN-GP生成正态分布)]

[1] Wasserstein GAN, https://arxiv.org/abs/1701.07875

[2]Improved Training of Wasserstein GANs,https://arxiv.org/abs/1704.00028v3

相关文章
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络
|
4月前
|
机器学习/深度学习 算法 PyTorch
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
深度学习近年来在多个领域取得了显著进展,但其核心组件——人工神经元和反向传播算法自提出以来鲜有根本性突破。穿孔反向传播(Perforated Backpropagation)技术通过引入“树突”机制,模仿生物神经元的计算能力,实现了对传统神经元的增强。该技术利用基于协方差的损失函数训练树突节点,使其能够识别神经元分类中的异常模式,从而提升整体网络性能。实验表明,该方法不仅可提高模型精度(如BERT模型准确率提升3%-17%),还能实现高效模型压缩(参数减少44%而无性能损失)。这一革新为深度学习的基础构建模块带来了新的可能性,尤其适用于边缘设备和大规模模型优化场景。
124 16
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
|
4月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。
|
4月前
|
机器学习/深度学习 自然语言处理 PyTorch
基于Pytorch Gemotric在昇腾上实现GAT图神经网络
本实验基于昇腾平台,使用PyTorch实现图神经网络GAT(Graph Attention Networks)在Pubmed数据集上的分类任务。内容涵盖GAT网络的创新点分析、图注意力机制原理、多头注意力机制详解以及模型代码实战。实验通过两层GAT网络对Pubmed数据集进行训练,验证模型性能,并展示NPU上的内存使用情况。最终,模型在测试集上达到约36.60%的准确率。
|
7月前
|
机器学习/深度学习 数据可视化 算法
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
神经常微分方程(Neural ODEs)是深度学习领域的创新模型,将神经网络的离散变换扩展为连续时间动力系统。本文基于Torchdyn库介绍Neural ODE的实现与训练方法,涵盖数据集构建、模型构建、基于PyTorch Lightning的训练及实验结果可视化等内容。Torchdyn支持多种数值求解算法和高级特性,适用于生成模型、时间序列分析等领域。
346 77
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
|
4月前
|
算法 PyTorch 算法框架/工具
PyTorch 实现FCN网络用于图像语义分割
本文详细讲解了在昇腾平台上使用PyTorch实现FCN(Fully Convolutional Networks)网络在VOC2012数据集上的训练过程。内容涵盖FCN的创新点分析、网络架构解析、代码实现以及端到端训练流程。重点包括全卷积结构替换全连接层、多尺度特征融合、跳跃连接和反卷积操作等技术细节。通过定义VOCSegDataset类处理数据集,构建FCN8s模型并完成训练与测试。实验结果展示了模型在图像分割任务中的应用效果,同时提供了内存使用优化的参考。
|
4月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本实验基于PyTorch Geometric,在昇腾平台上实现GraphSAGE图神经网络,使用CiteSeer数据集进行分类训练。内容涵盖GraphSAGE的创新点、算法原理、网络架构及实战分析。GraphSAGE通过采样和聚合节点邻居特征,支持归纳式学习,适用于未见节点的表征生成。实验包括模型搭建、训练与验证,并在NPU上运行,最终测试准确率达0.665。
|
9月前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
215 17
|
9月前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
167 10
|
9月前
|
存储 SQL 安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。

热门文章

最新文章

推荐镜像

更多