【13】变分自编码器(VAE)的原理介绍与pytorch实现

本文涉及的产品
注册配置 MSE Nacos/ZooKeeper,118元/月
云原生网关 MSE Higress,422元/月
服务治理 MSE Sentinel/OpenSergo,Agent数量 不受限
简介: 【13】变分自编码器(VAE)的原理介绍与pytorch实现

1.VAE的设计思路


VAE作为一个生成模型,其基本思路是很容易理解的:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个自编码器模型。那VAE(变分自编码器)就是在自编码器模型上做进一步变分处理,使得编码器的输出结果能对应到目标分布的均值和方差,如下图所示,具体的方法和思想在后文会介绍:

image.png


VAE最想解决的问题是什么?当然是如何构造编码器和解码器,使得图片能够编码成易于表示的形态,并且这一形态能够尽可能无损地解码回原真实图像。


这似乎听起来与PCA(主成分分析)有些相似,而PCA本身是用来做矩阵降维的:

image.png

如图,X本身是一个矩阵,通过一个变换W变成了一个低维矩阵c,因为这一过程是线性的,所以再通过一个变换就能还原出一个,现在我们要找到一种变换W,使得矩阵X与能够尽可能地一致,这就是PCA做的事情。在PCA中找这个变换W用到的方法是SVD(奇异值分解)算法,这是一个纯数学方法,不再细述,因为在VAE中不再需要使用SVD,直接用神经网络代替。


回顾上述介绍,我们会发现PCA与我们想要构造的自编码器的相似之处是在于,如果把矩阵X视作输入图像,W视作一个编码器,低维矩阵c视作图像的编码,然后和分别视作解码器和生成图像,PCA就变成了一个自编码器网络模型的雏形。

image.png

现在我们需要对这一雏形进行改进。首先一个最明显能改进的地方是用神经网络代替W变换和变换,就得到了如下Deep Auto-Encoder模型:

image.png

这一替换的明显好处是,引入了神经网络强大的拟合能力,使得编码(Code)的维度能够比原始图像(X)的维度低非常多。在一个手写数字图像的生成模型中,Deep Auto-Encoder能够把一个784维的向量(28*28图像)压缩到只有30维,并且解码回的图像具备清楚的辨认度(如下图)。

image.png


至此我们构造出了一个重构图像比较清晰的自编码模型,但是这并没有达到我们真正想要构造的生成模型的标准,因为,对于一个生成模型而言,解码器部分应该是单独能够提取出来的,并且对于在规定维度下任意采样的一个编码,都应该能通过解码器产生一张清晰且真实的图片。


我们先来分析一下现有模型无法达到这一标准的原因。

image.png


如上图所示,假设有两张训练图片,一张是全月图,一张是半月图,经过训练我们的自编码器模型已经能无损地还原这两张图片。接下来,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。一个比较合理的解释是,因为编码和解码的过程使用了深度神经网络,这是一个非线性的变换过程,所以在code空间上点与点之间的迁移是非常没有规律的。


如何解决这个问题呢?我们可以引入噪声,使得图片的编码区域得到扩大,从而掩盖掉失真的空白编码点。

image.png


如上图所示,现在在给两张图片编码的时候加上一点噪音,使得每张图片的编码点出现在绿色箭头所示范围内,于是在训练模型的时候,绿色箭头范围内的点都有可能被采样到,这样解码器在训练时会把绿色范围内的点都尽可能还原成和原图相似的图片。然后我们可以关注之前那个失真点,现在它处于全月图和半月图编码的交界上,于是解码器希望它既要尽量相似于全月图,又要尽量相似于半月图,于是它的还原结果就是两种图的折中(3/4全月图)。


由此我们发现,给编码器增添一些噪音,可以有效覆盖失真区域。不过这还并不充分,因为在上图的距离训练区域很远的黄色点处,它依然不会被覆盖到,仍是个失真点。为了解决这个问题,我们可以试图把噪音无限拉长,使得对于每一个样本,它的编码会覆盖整个编码空间,不过我们得保证,在原编码附近编码的概率最高,离原编码点越远,编码概率越低。在这种情况下,图像的编码就由原先离散的编码点变成了一条连续的编码分布曲线,如下图所示。

image.png


那么上述的这种将图像编码由离散变为连续的方法,就是变分自编码的核心思想。下面就会介绍VAE的模型架构,以及解释VAE是如何实现上述构思的。


2.VAE的模型架构


上面这张图就是VAE的模型架构,我们先粗略地领会一下这个模型的设计思想。

image.png

在auto-encoder中,编码器是直接产生一个编码的,但是在VAE中,为了给编码添加合适的噪音,编码器会输出两个编码,一个是原有编码(m1,m2,m3),另外一个是控制噪音干扰程度的编码(σ1,σ2,σ3),第二个编码其实很好理解,就是为随机噪音码(e1,e2,e3)分配权重,然后加上exp(σi)的目的是为了保证这个分配的权重是个正值,最后将原编码与噪音编码相加,就得到了VAE在code层的输出结果(c1,c2,c3)。其它网络架构都与Deep Auto-encoder无异。


损失函数方面,除了必要的重构损失外,VAE还增添了一个损失函数(见上图Minimize2内容),这同样是必要的部分,因为如果不加的话,整个模型就会出现问题:为了保证生成图片的质量越高,编码器肯定希望噪音对自身生成图片的干扰越小,于是分配给噪音的权重越小,这样只需要将(σ1,σ2,σ3)赋为接近负无穷大的值就好了。所以,第二个损失函数就有限制编码器走这样极端路径的作用,这也从直观上就能看出来,exp(σi)-(1+σi)在σi=0处取得最小值,于是(σ1,σ2,σ3)就会避免被赋值为负无穷大。


上述我们只是粗略地理解了VAE的构造机理,但是还有一些更深的原理需要挖掘,例如第二个损失函数为何选用这样的表达式,以及VAE是否真的能实现我们的预期设想,即“图片能够编码成易于表示的形态,并且这一形态能够尽可能无损地解码回原真实图像,是否有相应的理论依据。


下面我们会从理论上深入地分析一下VAE的构造依据以及作用原理。


3.VAE的作用原理


我们知道,对于生成模型而言,主流的理论模型可以分为隐马尔可夫模型HMM、朴素贝叶斯模型NB和高斯混合模型GMM,而VAE的理论基础就是高斯混合模型。


什么是高斯混合模型呢?就是说,任何一个数据的分布,都可以看作是若干高斯分布的叠加。

image.png


如图所示,如果P(X)代表一种分布的话,存在一种拆分方法能让它表示成图中若干浅蓝色曲线对应的高斯分布的叠加。有意思的是,这种拆分方法已经证明出,当拆分的数量达到512时,其叠加的分布相对于原始分布而言,误差是非常非常小的了。


于是我们可以利用这一理论模型去考虑如何给数据进行编码。一种最直接的思路是,直接用每一组高斯分布的参数作为一个编码值实现编码。

image.png


如上图所示,m代表着编码维度上的编号,譬如实现一个512维的编码,m的取值范围就是1,2,3……512。m会服从于一个概率分布P(m)(多项式分布)。现在编码的对应关系是,每采样一个m,其对应到一个小的高斯分布N(μm,∑m),P(X)就可以等价为所有的这些高斯分布的叠加,即:

image.png

其中

image.png

上述的这种编码方式是非常简单粗暴的,它对应的是我们之前提到的离散的、有大量失真区域的编码方式。于是我们需要对目前的编码方式进行改进,使得它成为连续有效的编码。

image.png


现在我们的编码换成一个连续变量z,我们规定z服从正态分布N(0,1)(实际上并不一定要选N(0,1)用,其他的连续分布都是可行的)。每对于一个采样z,会有两个函数μ和σ,分别决定z对应到的高斯分布的均值和方差,然后在积分域上所有的高斯分布的累加就成为了原始分布P(X),即:

image.png

其中

image.png

image.png

image.png

image.png

image.png


image.png

image.png

image.png

image.png


4.VAE的Pytorch实现


1)参考代码

model.py


# 定义变分自编码器VAE
class Variable_AutoEncoder(nn.Module):
    def __init__(self):
        super(Variable_AutoEncoder, self).__init__()
        # 定义编码器
        self.Encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU()
        )
        # 定义解码器
        self.Decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )
        self.fc_m = nn.Linear(64, 20)
        self.fc_sigma = nn.Linear(64, 20)
    def forward(self, input):
        code = input.view(input.size(0), -1)
        code = self.Encoder(code)
        # m, sigma = code.chunk(2, dim=1)
        m = self.fc_m(code)
        sigma = self.fc_sigma(code)
        e = torch.randn_like(sigma)
        c = torch.exp(sigma) * e + m
        # c = sigma * e + m
        output = self.Decoder(c)
        output = output.view(input.size(0), 1, 28, 28)
        return output, m, sigma


train.py


import torch
import torchvision
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from model import Auto_Encoder, Variable_AutoEncoder
import os
# 定义超参数
learning_rate = 1e-3
batch_size = 64
epochsize = 30
root = 'E:/学习/机器学习/数据集/MNIST'
sample_dir = "image5"
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# 图像相关处理操作
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.5], std=[0.5])   # 一定要去掉这句,不需要Normalize操作
])
# 训练集下载
mnist_train = datasets.MNIST(root=root, train=True, transform=transform, download=False)
mnist_train = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
# 测试集下载
mnist_test = datasets.MNIST(root=root, train=False, transform=transform, download=False)
mnist_test = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)
# image,_ = iter(mnist_test).next()
# print("image.shape:",image.shape)   # torch.Size([64, 1, 28, 28])
device = torch.device('cuda')
# 定义并导入网络结构
VAE = Variable_AutoEncoder()
VAE = VAE.to(device)
# VAE.load_state_dict(torch.load('VAE.ckpt'))
criteon = nn.MSELoss()
optimizer = optim.Adam(VAE.parameters(), lr=learning_rate)
print("start train...")
for epoch in range(epochsize):
    # 训练网络
    for batchidx, (realimage, _) in enumerate(mnist_train):
        realimage = realimage.to(device)
        # 生成假图像
        fakeimage, m, sigma = VAE(realimage)
        # 计算KL损失与MSE损失
        # KLD = torch.sum(torch.exp(sigma) - (1 + sigma) + torch.pow(m, 2)) / (input.size(0)*28*28)
        # KLD = torch.sum(torch.exp(sigma) - (1 + sigma) + torch.pow(m, 2))
        # 此公式是直接根据KL Div公式化简,两个分布分别是(0-1)分布与(m,sigma)分布
        # 最后根据像素点与样本批次求平均,既realimage.size(0)*28*28
        KLD = 0.5 * torch.sum(
            torch.pow(m, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (realimage.size(0)*28*28)
        # 计算均方差损失
        # MSE = criteon(fakeimage, realimage)
        MSE = torch.sum(torch.pow(fakeimage - realimage, 2)) / (realimage.size(0)*28*28)
        # 总的损失函数
        loss = MSE + KLD
        # 更新参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batchidx%300 == 0:
            print("epoch:{}/{}, batchidx:{}/{}, loss:{}, MSE:{}, KLD:{}"
                  .format(epoch, epochsize, batchidx, len(mnist_train), loss, MSE, KLD))
    # 生成图像
    realimage, _ = iter(mnist_test).next()
    realimage = realimage.to(device)
    fakeimage, _, _ = VAE(realimage)
    # 真假图像何必成一张
    image = torch.cat([realimage[0:32], fakeimage[0:32]], dim=0)
    # 保存图像
    save_image(image, os.path.join(sample_dir, 'image-{}.png'.format(epoch + 1)), nrow=8, normalize=True)
    torch.save(VAE.state_dict(), 'VAE.ckpt')


2)训练结果展示

image.png

Epoch1生成的图像

image.png

Epoch10生成的图像

image.png

Epoch30生成的图像


一开始的时候VAE的效果比AE的差,但是训练次数多了之后效果会变好。


3)生成结果展示

test.py


import torch
from torchvision.utils import save_image
from model import Variable_AutoEncoder
import os
epochsize = 20
batch_size = 64
sample_dir = "vae_val_result"
#seed = 0
#torch.manual_seed(seed)
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
VAE = Variable_AutoEncoder()
VAE.load_state_dict(torch.load('VAE.ckpt'))
for epoch in range(epochsize):
    z = torch.randn(batch_size, 20)
    # code = VAE.Encoder(z)
    fakeimage = VAE.Decoder(z)
    fakeimage = fakeimage.view(z.size(0), 1, 28, 28)
    # print("fakeimage.shape:",fakeimage.shape)
    save_image(fakeimage, os.path.join(sample_dir, 'image-{}.png'.format(epoch + 1)), nrow=8, normalize=True)
    print("generate success:",epoch)


我们根据之前训练好的VAE网络来随机的从0-1分布中sample一些噪声出来到Decoder中,结果如下,可以看见能够正常是随机生成图像。

image.png

但是,就结果而言,生成的图像比我们之前利用GAN生成出来的图像要模糊。


5.实现VAE中出现的问题


  • 问题1:训练结果中长时间生成的图像只有少量白点或者全部都是同一个模糊图像

image.png

损失函数出现的问题,在这次实验总,根据随机从0-1高斯分布sample出来的噪声与网络生成出来的N(μ,σ),这两者的KL分布直接计算出来,而不是单纯的使用paper给出的公式。

image.png

计算结果:

image.png


代码表示:

# KLD = torch.sum(torch.exp(sigma) - (1 + sigma) + torch.pow(m, 2))
KLD = 0.5 * torch.sum(
    torch.pow(m, 2) +
    torch.pow(sigma, 2) -
    torch.log(1e-8 + torch.pow(sigma, 2)) - 1
) / (realimage.size(0)*28*28)


  • 问题2:训练结过中开始时生成一个近似成功的图像,但是长期训练后只能得到一些类似的结果

image.png

此问题的原因是使用了Normalize操作数据集,不使用即可


transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.5], std=[0.5])   
])


  • 问题3:生成结果时出现了全部为0的状况

对于这种情况,主要是VAE的结构设计得不对,以下为原始结构。


# 定义变分自编码器VAE
class Variable_AutoEncoder(nn.Module):
    def __init__(self):
        super(Variable_AutoEncoder, self).__init__()
        # 定义编码器
        self.Encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # 定义解码器
        self.Decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )
    def forward(self, input):
        code = input.view(input.size(0), -1)
        code = self.Encoder(code)
        m, sigma = code.chunk(2, dim=1)
        e = torch.randn_like(sigma)
        c = torch.exp(sigma) * e + m
        output = self.Decoder(c)
        output = output.view(input.size(0), 1, 28, 28)
        return output, m, sigma


对于这种结构,可以正常训练出一个结果,见下图。

image.png

但是当使用此训练好的网络来随机输出一些服从0-1分布噪声进Decoder的时候,结果全为0。

image.png


个人猜测是code的维度不够,不足以保留过多的信息,随后见上述的参考代码model.py,改进过之后便可以正常的生成图像。


参考资料:

本文的理论部分摘抄至:http://www.gwylab.com/note-vae.html

讲解的十分的详细。


相关实践学习
【文生图】一键部署Stable Diffusion基于函数计算
本实验教你如何在函数计算FC上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。函数计算提供一定的免费额度供用户使用。本实验答疑钉钉群:29290019867
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
目录
相关文章
|
1月前
|
监控 PyTorch 数据处理
通过pin_memory 优化 PyTorch 数据加载和传输:工作原理、使用场景与性能分析
在 PyTorch 中,`pin_memory` 是一个重要的设置,可以显著提高 CPU 与 GPU 之间的数据传输速度。当 `pin_memory=True` 时,数据会被固定在 CPU 的 RAM 中,从而加快传输到 GPU 的速度。这对于处理大规模数据集、实时推理和多 GPU 训练等任务尤为重要。本文详细探讨了 `pin_memory` 的作用、工作原理及最佳实践,帮助你优化数据加载和传输,提升模型性能。
87 4
通过pin_memory 优化 PyTorch 数据加载和传输:工作原理、使用场景与性能分析
|
5月前
|
机器学习/深度学习 PyTorch 编译器
Pytorch的编译新特性TorchDynamo的工作原理和使用示例
PyTorch的TorchDynamo是一个即时编译器,用于优化动态图执行,提高运行效率。它在运行时分析和转换代码,应用优化技术,如操作符融合,然后编译成高效机器码。通过一个包含特征工程、超参数调整、交叉验证的合成数据集示例,展示了TorchDynamo如何减少训练时间并提高模型性能。它易于集成,只需对现有PyTorch代码进行小改动,即可利用其性能提升。TorchDynamo的优化包括动态捕获计算图、应用优化和编译,适用于实时应用和需要快速响应的场景。
93 11
|
5月前
|
资源调度 PyTorch 调度
多任务高斯过程数学原理和Pytorch实现示例
本文探讨了如何使用高斯过程扩展到多任务场景,强调了多任务高斯过程(MTGP)在处理相关输出时的优势。通过独立多任务GP、内在模型(ICM)和线性模型(LMC)的核心区域化方法,MTGP能够捕捉任务间的依赖关系,提高泛化能力。ICM和LMC通过引入核心区域化矩阵来学习任务间的共享结构。在PyTorch中,使用GPyTorch库展示了如何实现ICM模型,包括噪声建模和训练过程。实验比较了MTGP与独立GP,显示了MTGP在预测性能上的提升。
105 7
|
7月前
|
机器学习/深度学习 算法 PyTorch
深入理解PyTorch自动微分:反向传播原理与实现
【4月更文挑战第17天】本文深入解析PyTorch的自动微分机制,重点讨论反向传播的原理和实现。反向传播利用链式法则计算神经网络的梯度,包括前向传播、梯度计算、反向传播及参数更新。PyTorch通过`autograd`模块实现自动微分,使用`Tensor`和计算图记录操作历史以自动计算梯度。通过示例展示了如何在PyTorch中创建张量、定义计算过程及求梯度。掌握这些有助于提升深度学习模型的训练效率。
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch并行与分布式(三)DataParallel原理、源码解析、举例实战
PyTorch并行与分布式(三)DataParallel原理、源码解析、举例实战
860 0
|
机器学习/深度学习 自然语言处理 算法
RNN、CNN、RNN、LSTM、CTC算法原理,pytorch实现LSTM算法
RNN、CNN、RNN、LSTM、CTC算法原理,pytorch实现LSTM算法
329 0
|
机器学习/深度学习 人工智能 算法
部署教程 | ResNet原理+PyTorch复现+ONNX+TensorRT int8量化部署
部署教程 | ResNet原理+PyTorch复现+ONNX+TensorRT int8量化部署
320 0
|
机器学习/深度学习 人工智能 自然语言处理
【Pytorch神经网络理论篇】 11 卷积网络模型+Sobel算子原理
在微积分中,无限细分的条件是,被细分的对象必须是连续的,例如直线可以无限细分为点、但是若干个点则无法进行细分。
542 0
|
机器学习/深度学习 存储 算法
PyTorch中的傅立叶卷积:通过FFT有效计算大核卷积的数学原理和代码实现
PyTorch中的傅立叶卷积:通过FFT有效计算大核卷积的数学原理和代码实现
596 0
PyTorch中的傅立叶卷积:通过FFT有效计算大核卷积的数学原理和代码实现