数据可视化
绘制前9张图片
代码:
plt.figure()
for i in range(9):
plt.subplot(3,3,i+1)
plt.imshow(x_train_raw[i])
#plt.ylabel(y[i].numpy())
plt.axis('off')
plt.show()
输出:
数据处理,因为我们构建的是全连接网络所以输出应该是向量的形式,而非现在图像的矩阵形式。因此我们需要把图像整理成向量。
代码:
将2828的图像展开成7841的向量
x_train = x_train_raw.reshape(60000, 784)
x_test = x_test_raw.reshape(10000, 784)
现在像素点的动态范围为0到255。处理图形像素值时,我们通常会把图像像素点归一化到0到1的范围内。
代码:
代码:
将图像像素值归一化
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255