mnist 数据集读取
从tensorflow直接读取数据集,联网下载解压;
代码:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets
from matplotlib import pyplot as plt
import numpy as np
(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = datasets.mnist.load_data()
print(y_train_raw[0])
print(x_train_raw.shape, y_train_raw.shape)
print(x_test_raw.shape, y_test_raw.shape)
将分类标签变为onehot编码
num_classes = 10
y_train = keras.utils.to_categorical(y_train_raw, num_classes)
y_test = keras.utils.to_categorical(y_test_raw, num_classes)
print(y_train[0])
输出:
5
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
在mnist数据集中,images是一个形状为[60000,28,28]的张量,第一个维度数字用来索引图片,第二、三个维度数字用来索引每张图片中的像素点。在此张量里的每一个元素,都表示某张图片里的某个像素的强度值,介于0,255之间。
标签数据是"one-hot vectors",一个one-hot向量除了某一位数字是1之外,其余各维度数字都是0,如标签1可以表示为([0,1,0,0,0,0,0,0,0,0,0]),因此, labels 是一个 [60000, 10] 的数字矩阵。