(train_images, train_labels), (val_images, val_labels) = tf.keras.datasets.mnist.load_data() train_loader = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(32) val_loader = tf.data.Dataset.from_tensor_slices((val_images, val_labels)).batch(32)
首先加载Mnist数据集,得到train_images的维度为(60000,32,32),然后适用from_tensor_slices()
将数据集和标签进行打包形成一个迭代器,迭代器每个批次有32个数据