使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别

简介: 生成对抗网络(GAN)是一种用于生成新的照片,文本或音频的模型。它由两部分组成:生成器和判别器。生成器的作用是生成新的样本,而判别器的作用是识别这些样本是真实的还是假的。两个模型相互博弈,通过不断调整自己的参数来提高自己的能力。生成器希望判别器错误地认为其生成的样本是真实的,而判别器希望能正确地识别生成器生成的样本是假的。最终,生成器会学到如何生成逼真的样本,而判别器会学到如何区分真假样本。

前面的博客讲了如何基于PyTorch使用神经网络识别手写数字

使用PyTorch构建神经网络

下面在此基础上构建一个生成对抗网络,生成对抗网络可以模拟出新的手写数字数据集。

1 生成对抗网络基本概念


生成对抗网络(GAN)是一种用于生成新的照片,文本或音频的模型。它由两部分组成:生成器和判别器。生成器的作用是生成新的样本,而判别器的作用是识别这些样本是真实的还是假的。两个模型相互博弈,通过不断调整自己的参数来提高自己的能力。生成器希望判别器错误地认为其生成的样本是真实的,而判别器希望能正确地识别生成器生成的样本是假的。最终,生成器会学到如何生成逼真的样本,而判别器会学到如何区分真假样本。


一个非常形象的例子,目前的数据集是人民币,生成器是造假币的,判别器是银行。刚开始造假币的只是粗略模仿人民币的印制,银行由于没有经验也分辨不好真钱还是假币。但随着时间推移,银行对鉴别假币越来越有经验,造假币的水平也变得越来越逼真,二者不断进步,这就是GAN网络。



ba3197766bff4d128812e1fd582432ae.png



2 生成对抗网络建模


2.1 建立MnistDataset类

对于非GAN独有的建模部分,讲解不会细化到每一行代码,如有阅读困难可参考本博客使用PyTorch构建神经网络部分的文章。但基本上具备Python的基础知识即可顺利阅读本篇文章。

与神经网络建模相同,我们首先构建一个MnistDataset类,这个类具备getitem功能,可以返回每条数据相应的数据标签label,image_values, target。这些变量的含义分别是:


label:获得了指定数据的第一个数值,也就是这个数据的标签;

target:制作了一个维度为10的张量,标签对应的项是1,其他是0。比如,某个手写数据的标签是2,则这个张量是[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]。

image_values:像素输入的值是0-255,这里对像素数据做了标准化,是值位于0-1之间。

同样,我们定义了一个绘制的功能,这个功能在建模中并没有实际作用,但是会很方便我们快速查看数据是否成功导入。MnistDataset类的全部代码如下:


class MnistDataset:
    def __init__(self, csv_file):
        self.data = pandas.read_csv(csv_file)
        pass
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        # 预期输出的张量制作
        label = self.data.iloc[index, 0]
        target = torch.zeros(10)
        target[label] = 1.0
        # 图像数据标准化
        image_values = torch.FloatTensor(self.data.iloc[index, 1:].values) / 255.0
        return label, image_values, target
    # 制图
    def plot_image(self, index):
        arr = self.data.iloc[index, 1:].values.reshape(28, 28)
        plt.title("label=" + str(self.data.iloc[index, 0]))
        plt.imshow(arr, interpolation='none', cmap='Blues')
        plt.show()


2.2 建立鉴别器

此处的鉴别器与基于PyTorch建立神经网络一文中的鉴别器基本相同。主要不同的是网络的输出层:本鉴别器的的网格为784-200-1。网格的输出层只有一个节点,这是因为鉴别器只需要判断这是真实数据还是虚假数据即可。真实数据为1,虚假数据为0。

鉴别器的主要函数包括:


# 鉴别器类
class Discriminator(nn.Module):
    def __init__(self):
        # 初始化父类
        super().__init__()
        # 定义神经网络
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )
        # 创造损失函数
        self.loss_function = nn.MSELoss()
        # 创造优化器
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
        # 创造进程计数器
        self.counter = 0
        self.progress = []



对类的初始化中:继承父类nn.Module的初始化属性;并建立784-200-1的神经网络,神经网络的激活函数使用最经典的Sigmoid函数;建立损失函数与优化器,损失函数选择MSE方法(均方误差)。


 

def forward(self, inputs):
        # 执行模型
        return self.model(inputs)



简单的执行功能,能够基于input输出预测结果,即0或1。

def train(self, inputs, targets):
        # 计算输出
        outputs = self.forward(inputs)
        # 计算损失
        loss = self.loss_function(outputs, targets)
        # 赋值进程计数器
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 10000 == 0:
            print("counter = ", self.counter)
        # 计算损失梯度,优化权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()


训练模块,可以实现基于模型实际输出与与其输出,不断更新网络的权重。并每隔10次训练计算此时模型的损失,每隔10000次训练打印一次训练次数,方便掌握训练进度。


# 绘制损失与训练过程的关系
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))


对前面每10条保存一次的模型损失函数结果进行绘图。


2.3 测试鉴别器

此处我们还没有编写生成器,但是可以创建一个随机数据集,看看鉴别器是否可以分辨出真实的mnist数据和随机数据。

首先建立一个用于生成随机数据的生成器,size是生成数据的特征数。

def generate_random(size):
    random_data = torch.rand(size)
    return random_data


接下来我们用真是数据与随机数据训练模型

for label, image_data_tensor, target_tensor in mnist_dataset:
    # 真实数据
    D.train(image_data_tensor, torch.FloatTensor([1.0]))
    # 随机数据
    D.train(generate_random(784), torch.FloatTensor([0.0]))


其中真是数据我们希望输出节点的数据输出是1,而随机数据我们希望的输出是0。

在训练完成后,可以使用我们在鉴别器类中定义的绘图功能,查看模型损失的变化情况。同时,也可以再传入4组随机真假数据,来更清晰的查看此时模型的训练情况。


for i in range(4):
  image_data_tensor = mnist_dataset[random.randint(0,60000)][1]
  print( D.forward( image_data_tensor ).item() )
  pass
for i in range(4):
  print( D.forward( generate_random(784) ).item() )
  pass


基于这个运行结果也可以判断出,模型是可以有效的区分真实数据与随机数据的。





daee20399a694c6da186049e3cba3337.png



2.4 Mnist生成器制作

生成器与判别器都是神经网络模型,所以代码基本相同,这里主要讲一下不同的地方。与判别器相比,Mnist生成器应该与判别器的网格结构刚好相反。因为判别器是输入图像输出判别结果,而生成器应该是输入判别结果,输出图像。所以网络的结构可以是1-200-784。事实上,此处我们只要保证输出的格式是784个数据即可,为了让输出的数据更加多元,我们也可以增加输入层的节点数量。这里节点数量使用1,10,甚至是100都是可以的。此处我们以100个输入节点为例。


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 784),
            nn.Sigmoid()
        )
1


除此之外,生成器的训练过程也稍有不同。在使用生成器生成数据后,我们需要将这个数据传入判别器,并使用判别器返回的损失作为这个生成器的损失。在Python中,在一个类调用类一个类的功能是完全可以的因此这一步骤变得简单了很多。


class Generator(nn.Module):
    def train(self, D, inputs, targets):
        # 生成器生成数据
        g_output = self.forward(inputs)
        # 将生成的数据传入判别器
        d_output = D.forward(g_output)
        loss = D.loss_function(d_output, targets)
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()



除了以上两项,这个生成器都与鉴别器完全相同,大家按此更改或者直接在文末下载完整版代码均可。

在训练GAN之前,可以检查一下生成器的输出是否正确。方法还是让生成器生成一个数据,然后使用plt包绘制出来


G = Generator()
output = G.forward(generate_random(100))
img = output.detach().numpy().reshape(28,28)
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()



aeb9bdd783d7480bbfa07dd9f714e53c.png



现在,我们的模型中就具备了生成对抗网络的三要素:真实数据、生成器与对抗器。


3 模型的训练

对于生成器,其输入是由我们使用随机数据生成器来产生的。之前我们使用torch.rand进行随机数据的生成,这次可以尝试使用torch.randn。两者的区别是:randn是从标准正态分布中返回一个或多个样本值。


# 生成器使用的随即输入
def generate_random_seed(size):
    random_data = torch.randn(size)
    return random_data


同2.3的过程一样,在训练过程中,我们将真是数据与生成器产出的数据交替传入鉴别器,只是此处增加了对生成器的训练。


for label, image_data_tensor, target_tensor in mnist_dataset:
    # 使用真实数据训练判别器
    D.train(image_data_tensor, torch.FloatTensor([1.0]))
    # 使用生成器数据训练判别器
    # 使用 detach() 截断梯度计算
    D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]))
    # 训练生成器
    G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))



值得注意的是,在这里使用生成器数据训练判别器时,我们使用detach进行了截断。这个作用是在计算梯度时,对下图红叉所示地方进行切断,使梯度计算到这里就截止了,也就是此次计算只对生成器有效。这一操作的功能是降低模型的计算量。




bdd8d3a6962c445291d8c54d8cbe031b.png



同样此处也可以引入time模块对训练进行计时。


4 模型表现的判断

前面在定义类时,我们已经内置好了绘制损失随训练变化的功能,这里直接调用即可。


D.plot_progress()
G.plot_progress()



3f3dc34031dd4e56a1acb587b969e05e.png




鉴别器的损失基本看不到明显的变化,这是因为尽管鉴别器的能力不断提升,生成器的能力却也在不断提升。





c4832b99fc74401f92c689706716bcb2.png


生成器稍有不同,在前期出现了下降的趋势,在一定程度上骗过了鉴别器,但后期随着鉴别器能力的提升,生成器的随时也趋于稳定。


我们也可以依据生成器输出的图像,来更直观的判断生成器的表现。


f, axarr = plt.subplots(2,3, figsize=(16,8))
for i in range(2):
    for j in range(3):
        output = G.forward(generate_random_seed(100))
        img = output.detach().numpy().reshape(28,28)
        axarr[i,j].imshow(img, interpolation='none', cmap='Blues')
plt.show()






4472441cdb9c4348ab06c306db58632d.png

我们使用plt建立了一个2行3列的画布,并向生成器传入了随机参数,可以看到生成器的输出已经和手写图像很像了。


以上内容的全部代码,可以直接打包下载


相关文章
|
7月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
303 0
|
7月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
196 1
|
8月前
|
机器学习/深度学习 人工智能 算法
AI 基础知识从 0.6 到 0.7—— 彻底拆解深度神经网络训练的五大核心步骤
本文以一个经典的PyTorch手写数字识别代码示例为引子,深入剖析了简洁代码背后隐藏的深度神经网络(DNN)训练全过程。
1294 56
|
11月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。
|
11月前
|
算法 PyTorch 算法框架/工具
PyTorch 实现FCN网络用于图像语义分割
本文详细讲解了在昇腾平台上使用PyTorch实现FCN(Fully Convolutional Networks)网络在VOC2012数据集上的训练过程。内容涵盖FCN的创新点分析、网络架构解析、代码实现以及端到端训练流程。重点包括全卷积结构替换全连接层、多尺度特征融合、跳跃连接和反卷积操作等技术细节。通过定义VOCSegDataset类处理数据集,构建FCN8s模型并完成训练与测试。实验结果展示了模型在图像分割任务中的应用效果,同时提供了内存使用优化的参考。
|
11月前
|
机器学习/深度学习 自然语言处理 PyTorch
基于Pytorch Gemotric在昇腾上实现GAT图神经网络
本实验基于昇腾平台,使用PyTorch实现图神经网络GAT(Graph Attention Networks)在Pubmed数据集上的分类任务。内容涵盖GAT网络的创新点分析、图注意力机制原理、多头注意力机制详解以及模型代码实战。实验通过两层GAT网络对Pubmed数据集进行训练,验证模型性能,并展示NPU上的内存使用情况。最终,模型在测试集上达到约36.60%的准确率。
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络
|
11月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本实验基于PyTorch Geometric,在昇腾平台上实现GraphSAGE图神经网络,使用CiteSeer数据集进行分类训练。内容涵盖GraphSAGE的创新点、算法原理、网络架构及实战分析。GraphSAGE通过采样和聚合节点邻居特征,支持归纳式学习,适用于未见节点的表征生成。实验包括模型搭建、训练与验证,并在NPU上运行,最终测试准确率达0.665。
|
11月前
|
机器学习/深度学习 算法 PyTorch
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
深度学习近年来在多个领域取得了显著进展,但其核心组件——人工神经元和反向传播算法自提出以来鲜有根本性突破。穿孔反向传播(Perforated Backpropagation)技术通过引入“树突”机制,模仿生物神经元的计算能力,实现了对传统神经元的增强。该技术利用基于协方差的损失函数训练树突节点,使其能够识别神经元分类中的异常模式,从而提升整体网络性能。实验表明,该方法不仅可提高模型精度(如BERT模型准确率提升3%-17%),还能实现高效模型压缩(参数减少44%而无性能损失)。这一革新为深度学习的基础构建模块带来了新的可能性,尤其适用于边缘设备和大规模模型优化场景。
441 16
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
1068 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体

热门文章

最新文章

推荐镜像

更多
下一篇
开通oss服务