自编码神经网络是一种无监督学习算法,目标是让输出值与输入值尽可能相似
1.网络结构
简单的自编码器模型是一种三层神经网络模型,包含了输入层、隐藏层、输出重构层。
在实际中,我们往往设计成两层模型
- 编码层
- 解码层
编码层负责将数据读取,并且进行一系列线性变换,把输入样本压缩到隐藏层中。
解码层需要将复杂的网络结构进行还原,尽可能地把还原后的值与输入值做到相似
2.功能
自编码神经网络,目标是实现输出的结果和输入的结果尽可能相似,并且有很重要的一个步骤就是将输入的数据进行了“压缩”,也就是把输入的特征映射到隐藏层,而隐藏层其实样本的维度是大大小于输入样本的维度的,这就是自编码器的第一个功能实现特征降维: 它总会学习到最主要的特征,从而为后续的解码操作做好铺垫,所以我们在编码层之后就可以得到输入数据的主要特征向量
第二个功能就是它的设计初衷,实现近似输出。模型学会了编码和解码之后,我们可以先进行编码,得到输入的主要特征,然后再进行解码得到一个近似于原始输入的输出。那我们也可以自己设置一些编码后的特征,然后使用解码器进行解码,可以得到一些“令人惊喜”的输出。讲到这里,其实就跟后面的生成对抗网络可以联系起来,当我们训练好一个GAN的时候,我们可以利用生成模型将输入的随机噪音生成图片;而特别是在styleGAN中,我们输入潜在因子,通过映射网络得到中间潜在空间,并且由于映射网络和AdaIN,我们甚至可以不用在意初始输入,只需要在卷积、AdaIN之前添加一些随机噪音就可以控制生成图片的特征。
3.代码部分
这次demo代码是基于pytorch写的,以后用空补上tensorflow版本的
import pandas as pd import numpy as np import matplotlib.pyplot as plt import torch from torch import nn,optim from torch.autograd import Variable from torch.utils.data import DataLoader from torchvision import datasets,transforms from torchvision.utils import save_image import seaborn as sns import os import warnings plt.rcParams['font.sans-serif']='SimHei' plt.rcParams['axes.unicode_minus']=False warnings.filterwarnings('ignore') %matplotlib inline 复制代码
# 设置参数 batch_size = 100 learning_rate=1e-2 num_epoches=3 # 导入数据集 train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True) test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor()) train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) 复制代码
# 定义网络 class autoencoder(nn.Module): def __init__(self): super().__init__() # 编码器 self.encoder = nn.Sequential( nn.Linear(28*28, 128), nn.Tanh(), nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 12), nn.Tanh(), nn.Linear(12, 3), ) # 解码器 self.decoder = nn.Sequential( nn.Linear(3, 12), nn.Tanh(), nn.Linear(12, 64), nn.Tanh(), nn.Linear(64, 128), nn.Tanh(), nn.Linear(128, 28*28), nn.Sigmoid(), ) def forward(self, x): encode = self.encoder(x) decode = self.decoder(encode) return encode, decode 复制代码
# 实例化 net = autoencoder().cuda() # 损失函数以及优化函数 loss_func = nn.MSELoss().cuda() optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate) # 训练过程可视化 def list_img(i, img, title): img = img.reshape(28, 28) plt.subplot(2, 5, i+1) plt.imshow(img) plt.title('%s' % (title)) def generate_test(inputs, title=''): plt.figure(figsize=(15, 6)) for i in range(len(inputs)): img = inputs[i].view(-1, 28*28).cuda() hidden, outputs = net(img) list_img(i, outputs.cpu().detach().numpy(), title) plt.show() 复制代码
# 训练部分 result = [] test_inputs = [] hiddens=[] plt.figure(figsize=(15, 6)) for i, (img, _) in enumerate(test_loader): if i > 4 : break test_inputs.append(img) list_img(i, img.numpy(), 'truth') plt.show() for e in range(num_epoches): for i, (inputs, _) in enumerate(train_loader): inputs = inputs.view(-1, 28*28).cuda() optimizer.zero_grad() hidden, outputs = net(inputs) hiddens.append(hidden) loss = loss_func(outputs, inputs) loss.backward() optimizer.step() if i % 100 == 0: result.append(float(loss)) if i % 500 == 0: generate_test(test_inputs, 'generation') 复制代码
到最后可以看出模型训练生成的图片以及和输入的真实图片十分接近了
误差也只有0.03
from matplotlib import cm from mpl_toolkits.mplot3d import Axes3D %matplotlib inline # 可视化结果 view_data = Variable((train_dataset.train_data[:500].type(torch.FloatTensor).view(-1, 28*28) / 255. - 0.5) / 0.5).cuda() encode, _ = net(view_data) # 提取压缩的特征值 fig = plt.figure(2) ax = Axes3D(fig) # 3D 图 # x, y, z 的数据值 X = encode.data[:, 0].cpu().numpy() Y = encode.data[:, 1].cpu().numpy() Z = encode.data[:, 2].cpu().numpy() values = train_dataset.train_labels[:500].numpy() # 标签值 for x, y, z, s in zip(X, Y, Z, values): c = cm.rainbow(int(255*s/9)) # 上色 ax.text(x, y, z, s, backgroundcolor=c) # 标位子 ax.set_xlim(X.min(), X.max()) ax.set_ylim(Y.min(), Y.max()) ax.set_zlim(Z.min(), Z.max()) plt.show() 复制代码
在三维空间中,各个数字的分布如下
最后,根据图上的位置,自己写一些随机的输入特征给解码器,看它能否得到我们想要的图像
# 看图上4的位置大概在0附近,-0.5附近,-0.5-0之间 code = Variable(torch.FloatTensor([[0.02,-0.543,-0.012]])).cuda() decode = net.decoder(code) decode_img = decode.data.reshape(28,28).cpu().numpy() * 255 plt.imshow(decode_img.astype('uint8')) # 生成图片 plt.show() 复制代码
我自己随机设置的输入已经很好的满足了我的预期,如果我设置一些图上没有的数据呢,会怎样?
结果生成了一个问号,挺有意思的