Source: Alibaba Cloud Developer Community, column "Alibaba Cloud Big Data AI Technology" (阿里云大数据AI技术)

[DSW Gallery] Building an AutoEncoder with TensorFlow

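Summary: This article implements an autoencoder with TensorFlow 1.x. The autoencoder is a widely used neural network: it can be used for unsupervised anomaly detection and, in feature engineering, to capture high-order nonlinear relationships between features.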

Use it directly

Open the "Building an AutoEncoder with TensorFlow" notebook and click "Open in DSW" in the upper-right corner.

image.png


Implementing an autoencoder with TensorFlow

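An autoencoder (AutoEncoder, AE) is a kind of artificial neural network (ANN) used in semi-supervised and unsupervised learning. It performs representation learning by using its own input as the learning target. An autoencoder consists of two parts: an encoder and a decoder. By learning paradigm, autoencoders can be divided into undercomplete autoencoders, regularized autoencoders, and variational autoencoders (VAE); the first two are discriminative models and the last is a generative model. By architecture, an autoencoder can be a feedforward or a recurrent neural network. Autoencoders serve the general purposes of a representation learning algorithm and are applied to dimensionality reduction and anomaly detection. Autoencoders built with convolutional layers can be applied to computer vision problems, including image denoising and neural style transfer.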
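For the concrete model built later in this article (two sigmoid layers on each side, 784 → 256 → 128 → 256 → 784, trained on mean squared error), the encoder f, the decoder g, and the training objective can be written as follows. The notation is ours, not the original author's: column vectors, σ is the elementwise sigmoid, N is the batch size, W_1 to W_4 stand for `encoder_h1`, `encoder_h2`, `decoder_h1`, `decoder_h2`, and b_1 to b_4 for the matching biases. The code computes `x @ W`, hence the transposes.

```latex
\begin{aligned}
h &= f(x) = \sigma\!\left(W_2^{\top}\,\sigma\!\left(W_1^{\top} x + b_1\right) + b_2\right) \in \mathbb{R}^{128}, \\
\hat{x} &= g(h) = \sigma\!\left(W_4^{\top}\,\sigma\!\left(W_3^{\top} h + b_3\right) + b_4\right) \in \mathbb{R}^{784}, \\
\mathcal{L} &= \frac{1}{784\,N} \sum_{i=1}^{N} \left\lVert x_i - \hat{x}_i \right\rVert_2^2 .
\end{aligned}
```

Because the target is the input x itself, no labels are needed, which is why the training loop discards the MNIST labels (`batch_x, _ = ...`).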

30-1.png

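  1. Dimensionality reduction
  • Unlike PCA, which is limited to linear projections, an AutoEncoder can perform both linear and nonlinear dimensionality reduction.
  2. Anomaly detection
  • A trained encoder-decoder can be used to detect anomalous samples, typically the ones it reconstructs poorly.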
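A minimal NumPy sketch of both uses, not taken from the original notebook: it treats PCA as a linear autoencoder (the encoder projects onto the top-k principal directions, the decoder maps back), then scores samples by reconstruction error and flags those above a quantile threshold. The helper names, the 99th-percentile threshold, and the synthetic data are assumptions for illustration; with the trained TensorFlow model you would use its reconstructions (`decoder_op`) in place of `pca_reconstruct`.

```python
import numpy as np


def pca_fit(x, k):
    """Return the data mean and the top-k principal directions (as columns)."""
    mean = x.mean(axis=0)
    # Rows of vt are principal directions, sorted by decreasing singular value
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    return mean, vt[:k].T


def pca_reconstruct(x, mean, components):
    """Linear 'autoencoder': encode to k dims, then decode back to d dims."""
    code = (x - mean) @ components          # encoder: d -> k
    return code @ components.T + mean       # decoder: k -> d


def reconstruction_errors(x, x_hat):
    """Per-sample mean squared reconstruction error."""
    return np.mean((x - x_hat) ** 2, axis=1)


def flag_anomalies(errors, quantile=0.99):
    """Mark samples whose error exceeds the given quantile of all errors."""
    threshold = np.quantile(errors, quantile)
    return errors > threshold, threshold


rng = np.random.default_rng(0)
# "Normal" data lies close to a 2-D plane embedded in 10-D space
basis = rng.normal(size=(2, 10))
normal = rng.normal(size=(500, 2)) @ basis + 0.01 * rng.normal(size=(500, 10))
# A few points far off that plane play the role of anomalies
outliers = 3.0 * rng.normal(size=(5, 10))
data = np.vstack([normal, outliers])

mean, comps = pca_fit(normal, k=2)
errors = reconstruction_errors(data, pca_reconstruct(data, mean, comps))
is_anomaly, thr = flag_anomalies(errors, quantile=0.99)
print("threshold:", thr)
print("flagged indices:", np.flatnonzero(is_anomaly))
```

A nonlinear autoencoder can follow curved structure in the data that this linear projection cannot, which is the advantage over PCA mentioned in the list above.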

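This article builds an AutoEncoder model with TensorFlow 1.x. The code relies on `tf.placeholder`, `tf.Session`, and `tensorflow.examples.tutorials.mnist`, which are not available in TensorFlow 2.x without compatibility shims.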

from __future__ import division, print_function, absolute_import

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
  1. Load the dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
WARNING:tensorflow:From <ipython-input-2-c3d55fec490c>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/pai/lib/python3.6/site-packages/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /home/pai/lib/python3.6/site-packages/tensorflow_core/contrib/learn/python/learn/datasets/base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
WARNING:tensorflow:From /home/pai/lib/python3.6/site-packages/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
WARNING:tensorflow:From /home/pai/lib/python3.6/site-packages/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/pai/lib/python3.6/site-packages/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/pai/lib/python3.6/site-packages/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
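The warnings above say `read_data_sets` is deprecated and suggest writing your own download logic. Below is a minimal sketch of the parsing side only, assuming the same gzipped IDX files that were downloaded into `/tmp/data/`. It decodes the IDX header, scales pixels to [0, 1] and flattens each image to 784 values (our understanding of what `input_data` provides by default), and draws shuffled minibatches. `parse_idx`, `load_idx`, `images_to_inputs`, and `MinibatchSampler` are hypothetical names; downloading is left out.

```python
import gzip
import struct

import numpy as np


def parse_idx(raw):
    """Decode an IDX byte string with an unsigned-byte payload into an array."""
    # Header: two zero bytes, a type code (0x08 = uint8), then the number of dims
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("expected an unsigned-byte IDX file")
    # Each dimension size is a big-endian 32-bit unsigned integer
    dims = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    data = np.frombuffer(raw, dtype=np.uint8, offset=4 + 4 * ndim)
    return data.reshape(dims)


def load_idx(path):
    """Read a gzipped IDX file such as train-images-idx3-ubyte.gz."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())


def images_to_inputs(images):
    """Flatten (N, 28, 28) uint8 images into (N, 784) float32 values in [0, 1]."""
    return images.reshape(len(images), -1).astype(np.float32) / 255.0


class MinibatchSampler:
    """Return shuffled minibatches, reshuffling after each full pass."""

    def __init__(self, x, seed=0):
        self.x = x
        self.rng = np.random.default_rng(seed)
        self._order = self.rng.permutation(len(x))
        self._pos = 0

    def next_batch(self, batch_size):
        if self._pos + batch_size > len(self.x):
            # Start a new epoch; the leftover tail of the previous one is dropped
            self._order = self.rng.permutation(len(self.x))
            self._pos = 0
        idx = self._order[self._pos:self._pos + batch_size]
        self._pos += batch_size
        return self.x[idx]
```

Usage might look like `train_x = images_to_inputs(load_idx("/tmp/data/train-images-idx3-ubyte.gz"))`, then `sampler = MinibatchSampler(train_x)` once and `batch_x = sampler.next_batch(batch_size)` inside the loop. Unlike `mnist.train.next_batch`, this sampler returns images only, which is all the autoencoder needs.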
  2. Set the model parameters and define the weights and biases
# Training Parameters
learning_rate = 0.01
num_steps = 30000
batch_size = 256

display_step = 1000
examples_to_show = 10

# Network Parameters
num_hidden_1 = 256 # 1st layer num features
num_hidden_2 = 128 # 2nd layer num features (the latent dim)
num_input = 784 # MNIST data input (img shape: 28*28)

# tf Graph input (only pictures)
X = tf.placeholder("float", [None, num_input])

weights = {
    'encoder_h1': tf.Variable(tf.random_normal([num_input, num_hidden_1])),
    'encoder_h2': tf.Variable(tf.random_normal([num_hidden_1, num_hidden_2])),
    'decoder_h1': tf.Variable(tf.random_normal([num_hidden_2, num_hidden_1])),
    'decoder_h2': tf.Variable(tf.random_normal([num_hidden_1, num_input])),
}
biases = {
    'encoder_b1': tf.Variable(tf.random_normal([num_hidden_1])),
    'encoder_b2': tf.Variable(tf.random_normal([num_hidden_2])),
    'decoder_b1': tf.Variable(tf.random_normal([num_hidden_1])),
    'decoder_b2': tf.Variable(tf.random_normal([num_input])),
}
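The four weight matrices and bias vectors above define a 784 → 256 → 128 → 256 → 784 network. The NumPy sketch below counts its trainable parameters and shows one common alternative to `tf.random_normal` with its default standard deviation of 1.0: Glorot (Xavier) uniform initialization, which keeps initial pre-activations smaller so the sigmoids are less likely to start out saturated. Swapping the initializer is our suggestion, not something the original notebook does (in TF 1.x, `tf.glorot_uniform_initializer()` is one built-in option); `layer_shapes`, `count_params`, and `glorot_uniform` are hypothetical helpers.

```python
import numpy as np


def layer_shapes(num_input=784, num_hidden_1=256, num_hidden_2=128):
    """(fan_in, fan_out) of each dense layer, in encoder -> decoder order."""
    return [
        (num_input, num_hidden_1),     # encoder_h1 / encoder_b1
        (num_hidden_1, num_hidden_2),  # encoder_h2 / encoder_b2
        (num_hidden_2, num_hidden_1),  # decoder_h1 / decoder_b1
        (num_hidden_1, num_input),     # decoder_h2 / decoder_b2
    ]


def count_params(shapes):
    """Weights plus one bias per output unit, summed over all layers."""
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in shapes)


def glorot_uniform(fan_in, fan_out, rng):
    """Draw weights from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out)).astype(np.float32)


shapes = layer_shapes()
print(count_params(shapes))  # 468368

rng = np.random.default_rng(0)
init_weights = [glorot_uniform(fan_in, fan_out, rng) for fan_in, fan_out in shapes]
# Spread of the first layer: about 0.044, versus 1.0 for tf.random_normal's default
print(float(init_weights[0].std()))
```

The count breaks down as 200,960 + 32,896 parameters in the encoder and 33,024 + 201,488 in the decoder.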
  3. Define the computation graph of the encoder and decoder
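def encoder(x):
    # Encoder: 784 -> 256 -> 128, each layer a sigmoid of an affine map
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
                                   biases['encoder_b1']))
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
                                   biases['encoder_b2']))
    return layer_2


def decoder(x):
    # Decoder: 128 -> 256 -> 784, mirroring the encoder
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
                                   biases['decoder_b1']))
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
                                   biases['decoder_b2']))
    return layer_2

# Build the graph: encode the input, then reconstruct it from the latent code
encoder_op = encoder(X)
decoder_op = decoder(encoder_op)

# The reconstruction target is the input itself
y_pred = decoder_op
y_true = X

# Mean squared reconstruction error, minimized with RMSProp
loss = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

# Op that initializes all variables; it is run once after the session starts
init = tf.global_variables_initializer()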
WARNING:tensorflow:From /home/pai/lib/python3.6/site-packages/tensorflow_core/python/ops/math_grad.py:1375: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
WARNING:tensorflow:From /home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/rmsprop.py:119: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
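To check the shapes and the loss reduction without running TensorFlow, the graph above can be mirrored in NumPy. This is an illustrative re-implementation with the same layer sizes, not the TensorFlow graph itself; the function names and the randomly generated weights and inputs are assumptions for the sketch.

```python
import numpy as np


def sigmoid(z):
    """Elementwise logistic function, as tf.nn.sigmoid computes."""
    return 1.0 / (1.0 + np.exp(-z))


def dense_sigmoid(x, w, b):
    """One layer of the notebook's encoder/decoder: sigmoid(x @ w + b)."""
    return sigmoid(x @ w + b)


def forward(x, weights, biases):
    """Encode x to the 128-d latent code, then decode back to 784 pixels."""
    h1 = dense_sigmoid(x, weights["encoder_h1"], biases["encoder_b1"])
    code = dense_sigmoid(h1, weights["encoder_h2"], biases["encoder_b2"])
    d1 = dense_sigmoid(code, weights["decoder_h1"], biases["decoder_b1"])
    recon = dense_sigmoid(d1, weights["decoder_h2"], biases["decoder_b2"])
    return code, recon


def mse_loss(y_true, y_pred):
    """Mean over every element, like tf.reduce_mean(tf.pow(y_true - y_pred, 2))."""
    return float(np.mean((y_true - y_pred) ** 2))


# Random stand-ins for the variables, with the notebook's shapes
rng = np.random.default_rng(0)
shapes = {
    "encoder_h1": (784, 256), "encoder_h2": (256, 128),
    "decoder_h1": (128, 256), "decoder_h2": (256, 784),
}
np_weights = {name: rng.normal(size=shape) for name, shape in shapes.items()}
np_biases = {
    "encoder_b1": rng.normal(size=256), "encoder_b2": rng.normal(size=128),
    "decoder_b1": rng.normal(size=256), "decoder_b2": rng.normal(size=784),
}

x_batch = rng.random((8, 784))  # fake "images" with pixel values in [0, 1)
code, recon = forward(x_batch, np_weights, np_biases)
print(code.shape, recon.shape)  # (8, 128) (8, 784)
print(mse_loss(x_batch, recon))
```

Because the last layer is a sigmoid, every reconstructed pixel lies in [0, 1], the same range as the MNIST inputs the loss compares against.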
  4. Start a session and train the model
sess = tf.Session()

# Run the initializer
sess.run(init)

# Training
for i in range(1, num_steps+1):
    batch_x, _ = mnist.train.next_batch(batch_size)

    # Run optimization op (backprop) and cost op (to get loss value)
    _, l = sess.run([optimizer, loss], feed_dict={X: batch_x})
    # Display logs per step
    if i % display_step == 0 or i == 1:
        print('Step %i: Minibatch Loss: %f' % (i, l))
Step 1: Minibatch Loss: 0.443712
Step 1000: Minibatch Loss: 0.143148
Step 2000: Minibatch Loss: 0.119074
Step 3000: Minibatch Loss: 0.104780
Step 4000: Minibatch Loss: 0.102706
Step 5000: Minibatch Loss: 0.097898
Step 6000: Minibatch Loss: 0.096692
Step 7000: Minibatch Loss: 0.095731
Step 8000: Minibatch Loss: 0.093153
Step 9000: Minibatch Loss: 0.090918
Step 10000: Minibatch Loss: 0.088756
Step 11000: Minibatch Loss: 0.086178
Step 12000: Minibatch Loss: 0.079732
Step 13000: Minibatch Loss: 0.080625
Step 14000: Minibatch Loss: 0.077883
Step 15000: Minibatch Loss: 0.077515
Step 16000: Minibatch Loss: 0.077849
Step 17000: Minibatch Loss: 0.075183
Step 18000: Minibatch Loss: 0.073790
Step 19000: Minibatch Loss: 0.072038
Step 20000: Minibatch Loss: 0.070626
Step 21000: Minibatch Loss: 0.069325
Step 22000: Minibatch Loss: 0.067790
Step 23000: Minibatch Loss: 0.063283
Step 24000: Minibatch Loss: 0.064594
Step 25000: Minibatch Loss: 0.065526
Step 26000: Minibatch Loss: 0.065334
Step 27000: Minibatch Loss: 0.059984
Step 28000: Minibatch Loss: 0.059698
Step 29000: Minibatch Loss: 0.060184
Step 30000: Minibatch Loss: 0.055265
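The loss above is minimized by `tf.train.RMSPropOptimizer(learning_rate)` with its other arguments at their defaults. To our understanding of TF 1.x, those defaults are `decay=0.9`, `momentum=0.0`, and `epsilon=1e-10`, and the running mean of squared gradients starts at ones (consistent with the `Ones.__init__` warning from `rmsprop.py` earlier). The NumPy sketch below applies that update rule to a toy quadratic; it illustrates RMSProp under these assumptions and is not TensorFlow's implementation. `rmsprop_minimize` is a hypothetical name.

```python
import numpy as np


def rmsprop_minimize(grad_fn, w0, learning_rate=0.01, decay=0.9,
                     epsilon=1e-10, num_steps=2000):
    """Plain RMSProp (non-centered, no momentum) applied elementwise."""
    w = np.array(w0, dtype=np.float64)
    # Running average of squared gradients, initialized to ones (assumed TF 1.x behavior)
    ms = np.ones_like(w)
    for _ in range(num_steps):
        g = grad_fn(w)
        ms = decay * ms + (1.0 - decay) * g ** 2
        # Scale each step by the root of the running squared-gradient average
        w = w - learning_rate * g / np.sqrt(ms + epsilon)
    return w


# Toy objective f(w) = sum(w ** 2), whose gradient is 2 * w
w_final = rmsprop_minimize(lambda w: 2.0 * w, w0=[3.0, -2.0])
print(w_final)
```

Since the running average always includes (1 - decay) · g², each step is at most `learning_rate / sqrt(1 - decay)` per coordinate (about 0.032 here), regardless of how large the raw gradient is.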
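  5. Check how well the AutoEncoder reconstructs the original data
  • As the results below show, this AutoEncoder reconstructs the original digits well. Further hyperparameter tuning should improve the results and leave less noise in the reconstructed images.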
n = 4
canvas_orig = np.empty((28 * n, 28 * n))
canvas_recon = np.empty((28 * n, 28 * n))
for i in range(n):
    # MNIST test set
    batch_x, _ = mnist.test.next_batch(n)
    # Encode and decode the digit image
    g = sess.run(decoder_op, feed_dict={X: batch_x})
    
    # Draw the original digits into row i of the first canvas
    for j in range(n):
        canvas_orig[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = batch_x[j].reshape([28, 28])
    # Draw the reconstructed digits into row i of the second canvas
    for j in range(n):
        canvas_recon[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = g[j].reshape([28, 28])

print("Original Images")
plt.figure(figsize=(n, n))
plt.imshow(canvas_orig, origin="upper", cmap="gray")
plt.show()

print("Reconstructed Images")
plt.figure(figsize=(n, n))
plt.imshow(canvas_recon, origin="upper", cmap="gray")
plt.show()
Original Images

30-2.png

Reconstructed Images

30-3.png

30-4.png
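The plotting cell above fills each 28 × 28 block of a (28n) × (28n) canvas by slicing inside a double loop. A vectorized version of that tiling step might look like the NumPy sketch below. `tile_images` is a hypothetical helper and assumes the images arrive as one array of n × n flattened 784-pixel images in row-major order (the original loop instead fetches n fresh test images per row); the loop version is kept for comparison.

```python
import numpy as np


def tile_images(images, n, side=28):
    """Arrange n*n flattened side*side images into one (side*n, side*n) canvas."""
    grid = np.asarray(images).reshape(n, n, side, side)
    # (row, col, y, x) -> (row, y, col, x), then merge axes into a 2-D canvas
    return grid.transpose(0, 2, 1, 3).reshape(n * side, n * side)


def tile_images_loop(images, n, side=28):
    """The notebook's slicing approach, for comparison."""
    canvas = np.empty((side * n, side * n))
    for i in range(n):
        for j in range(n):
            canvas[i * side:(i + 1) * side, j * side:(j + 1) * side] = \
                images[i * n + j].reshape(side, side)
    return canvas


demo = np.random.default_rng(0).random((16, 784))
print(np.array_equal(tile_images(demo, 4), tile_images_loop(demo, 4)))  # True
```

With the notebook's variables, stacking the n test batches into a hypothetical `all_batches` array of shape (n * n, 784) and calling `plt.imshow(tile_images(all_batches, n), cmap="gray")` would draw the same grid.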

