使用CNN完成MNIST手写体识别(TensorFlow)
CNN(Convolutional Neural Network,卷积神经网络)是一种比较常见的神经网络模型,它通常被用于图像识别、语音识别等领域。相比于传统的神经网络模型,CNN在处理图像等数据方面有明显的优势,其核心思想是通过卷积、池化等操作提取出图像中的特征,从而实现图像识别等任务。
CNN的基本结构由卷积层(Convolutional Layer)、池化层(Pooling Layer)、全连接层(Fully Connected Layer)等组成。其中,卷积层是CNN最核心的部分,它通过卷积核对输入的图像进行特征提取。每个卷积核都是一个小的矩阵,可以看做是一种特定的滤波器,卷积核在输入的图像上滑动,将每个位置的像素值与卷积核中对应位置的权重做乘积之和,最终得到一个新的特征图。通过多次卷积操作,可以逐步提取出图像中的高级特征。
在卷积层之后通常会添加一个池化层,将特征图进行降维处理。常见的池化方式有最大池化和平均池化两种,最大池化会取出输入的特定区域中的最大值作为池化后的值,而平均池化则是取输入区域的平均值作为池化后的值。通过池化操作,可以减少特征图的维度,降低计算复杂度,同时还可以提高模型的鲁棒性,避免因输入数据中的一些细节变化而影响模型的输出结果。
最后,全连接层将卷积和池化后得到的特征图转化成一个一维向量,通过多个全连接层的组合可以实现复杂的图像分类、目标检测等任务。在训练CNN时,通常会使用反向传播算法(Backpropagation)对网络中的参数进行优化,通过反向传播算法可以计算出损失函数对各个参数的梯度,从而进行参数更新。
总的来说,CNN是一种十分有效的神经网络模型,在图像处理等领域有着广泛的应用。不过,对于初学者而言,搭建和训练一个CNN模型需要比较高的数学和编程技能,需要有一定的基础才能掌握。
1. 导入TensorFlow库
# Tensorflow提供了一个类来处理MNIST数据 # 导入相关库 from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf import time import warnings warnings.filterwarnings('ignore')
2. 数据集
# 载入数据集 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) mnist
WARNING:tensorflow:From <ipython-input-2-574fb576d2f2>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use alternatives such as official/mnist/dataset.py from tensorflow/models. WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version. Instructions for updating: Please write your own downloading logic. WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use tf.data to implement this functionality. Extracting MNIST_data/train-images-idx3-ubyte.gz WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use tf.data to implement this functionality. Extracting MNIST_data/train-labels-idx1-ubyte.gz WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use tf.one_hot on tensors. Extracting MNIST_data/t10k-images-idx3-ubyte.gz Extracting MNIST_data/t10k-labels-idx1-ubyte.gz WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use alternatives such as official/mnist/dataset.py from tensorflow/models. Datasets(train=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f05f902e2b0>, validation=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f05f902e358>, test=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f05f902e320>)
# 设置批次大小 batch_size = 50 # 计算一共有多少个批次 n_batch = mnist.train.num_examples // batch_size n_batch
1100
# 定义初始化权值函数 def weight_variable(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial)
# 定义初始化偏置函数 def bias_variable(shape): initial = tf.constant(0.1, shape=shape) return tf.Variable(initial)
""" strides=[b,h,w,c] b表示在样本上的步长默认为1,也就是每一个样本都会进行运算。 h表示在高度上的默认移动步长为1,这个可以自己设定,根据网络的结构合理调节。 w表示在宽度上的默认移动步长为1,这个同上可以自己设定。 c表示在通道上的默认移动步长为1,这个表示每一个通道都会进行运算 """ # 卷积层 def conv2d(input, filter): return tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
""" ksize=[b,h,w,c],通常为[1,2,2,1] b表示在样本上的步长默认为1,也就是每一个样本都会进行运算。 h表示在高度上的默认移动步长为1,这个可以自己设定,根据网络的结构合理调节。 w表示在宽度上的默认移动步长为1,这个同上可以自己设定。 c表示在通道上的默认移动步长为1,这个表示每一个通道都会进行运算 """ # 池化层 def max_pool_2x2(value): return tf.nn.max_pool(value, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# 输入层 # 定义两个placeholder x = tf.placeholder(tf.float32, [None, 784]) # 28*28 y = tf.placeholder(tf.float32, [None, 10])
# 改变x的格式转为4维的向量[batch,in_hight,in_width,in_channels] x_image = tf.reshape(x, [-1, 28, 28, 1])
3. 卷积、激励、池化操作
# 初始化第一个卷积层的权值和偏置 # MNIST使用的是灰度图像,每个像素点只需要一个数值,因此这里通道数为1 # 5*5的采样窗口,32个卷积核从1个平面抽取特征 w_conv1 = weight_variable([5, 5, 1, 32]) # 每一个卷积核一个偏置值 b_conv1 = bias_variable([32]) # 把x_image和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数 h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) # 进行max_pooling 池化层 14*14*32 h_pool1 = max_pool_2x2(h_conv1)
# 初始化第二个卷积层的权值和偏置 # 5*5的采样窗口,64个卷积核从32个平面抽取特征 w_conv2 = weight_variable([5, 5, 32, 64]) b_conv2 = bias_variable([64]) # 把第一个池化层结果和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数 h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) # 池化层 7*7*64 h_pool2 = max_pool_2x2(h_conv2)
28x28的图片第一次卷积后还是28x28,第一次池化后变为14x14
第二次卷积后为14x14,第二次池化后变为了7x7
经过上面操作后得到64张7x7的平面
4. 全连接层
# 初始化第一个全连接层的权值 # 经过池化层后有7*7*64个神经元,全连接层有128个神经元 w_fc1 = weight_variable([7 * 7 * 64, 128]) # 128个节点 b_fc1 = bias_variable([128]) # 把池化层2的输出扁平化为1维 h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) # 求第一个全连接层的输出 h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
h_pool2_flat.shape, h_fc1.shape
(TensorShape([Dimension(None), Dimension(3136)]), TensorShape([Dimension(None), Dimension(128)]))
# keep_prob: float类型,每个元素被保留下来的概率,设置神经元被选中的概率,在初始化时keep_prob是一个占位符 keep_prob = tf.placeholder(tf.float32) h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
WARNING:tensorflow:From <ipython-input-14-a22383db216d>:3: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version. Instructions for updating: Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
# 初始化第二个全连接层 W_fc2 = weight_variable([128, 10]) b_fc2 = bias_variable([10])
5. 输出层
# 计算输出 prediction = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
# 交叉熵代价函数 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction)) # 使用AdamOptimizer进行优化 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) # 结果存放在一个布尔列表中(argmax函数返回一维张量中最大的值所在的位置) correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1)) # 求准确率(tf.cast将布尔值转换为float型) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
WARNING:tensorflow:From <ipython-input-17-ef3c12a7f7c4>:2: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version. Instructions for updating: Future major versions of TensorFlow will allow gradients to flow into the labels input on backprop by default. See `tf.nn.softmax_cross_entropy_with_logits_v2`.
6. 训练模型
# 创建会话 with tf.Session() as sess: start_time = time.clock() # 初始化变量 sess.run(tf.global_variables_initializer()) print('开始训练 ----------') # 训练10次 for epoch in range(10): print("Test" + str(epoch) + " :") for batch in range(n_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) # 进行迭代训练 sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7}) print('第' + str(batch) + '批训练') # 测试数据计算出准确率 acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0}) print('Iter' + str(epoch) + ',Testing Accuracy=' + str(acc)) end_time = time.clock() # 输出运行时间 print('Running time:%s Second' % (end_time - start_time))
开始训练 ---------- Test0 : 第0批训练 第1批训练 第2批训练 第3批训练 第4批训练 第5批训练 第6批训练 第7批训练 第8批训练 第9批训练 第10批训练 第11批训练 第12批训练 第13批训练 第14批训练 第15批训练 第16批训练 第17批训练 第18批训练 第19批训练 第20批训练 第21批训练 第22批训练 第23批训练 第24批训练 第25批训练 第26批训练 第27批训练 第28批训练 第29批训练 第30批训练 第31批训练 第32批训练 第33批训练 第34批训练 第35批训练 第36批训练 第37批训练 第38批训练 第39批训练 第40批训练 第41批训练 第42批训练 第43批训练 第44批训练 第45批训练 第46批训练 第47批训练 第48批训练 第49批训练 第50批训练 第51批训练 第52批训练 第53批训练 第54批训练 第55批训练 第56批训练 第57批训练 第58批训练 第59批训练 第60批训练 第61批训练 第62批训练 第63批训练 第64批训练 第65批训练 ……