In TensorFlow, fit() can consume NumPy arrays directly, provided the data is small enough to be loaded into memory all at once. When the data is too large for that, it has to be read batch by batch through an iterator, which is exactly what tf.data.Dataset provides.
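As a minimal sketch of that conversion (the array shapes and names here are made up for illustration, not taken from the code below):

import numpy as np
import tensorflow as tf

# Stand-in arrays; imagine these are too large to hand to fit() in one go
features = np.random.rand(1000, 784).astype('float32')
labels = np.random.randint(0, 10, size=(1000,))

# from_tensor_slices splits the arrays along the first axis into single examples;
# shuffle and batch turn them into a stream of 64-example batches
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.shuffle(buffer_size=1024).batch(64)

for batch_features, batch_labels in dataset.take(1):
    print(batch_features.shape)  # (64, 784)
    print(batch_labels.shape)    # (64,)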
A Dataset instance can be passed directly to fit(), evaluate(), and predict(). When a Dataset is used, there is no need to pass batch_size to fit() the way you do with NumPy arrays: the batch size is already fixed by the .batch() call in the Dataset pipeline.
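The difference between the two calling conventions, as a sketch (using train_images, train_labels, and train_datasets as defined in the complete code below):

# With NumPy arrays, fit() does the batching itself via batch_size
model.fit(train_images, train_labels, batch_size=64, epochs=5)

# With a Dataset, batching is baked into the pipeline by .batch(64),
# so fit() simply iterates over the batches the Dataset yields
model.fit(train_datasets, epochs=5)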
The complete code:
""" * Created with PyCharm * 作者: 阿光 * 日期: 2022/1/2 * 时间: 19:29 * 描述: """ import tensorflow as tf import tensorflow.keras.datasets.mnist from keras import Input, Model from keras.layers import Dense from tensorflow import keras (train_images, train_labels), (val_images, val_labels) = tensorflow.keras.datasets.mnist.load_data() train_images, val_images = train_images / 255.0, val_images / 255.0 train_images = train_images.reshape(60000, 784) val_images = val_images.reshape(10000, 784) train_datasets = tf.data.Dataset.from_tensor_slices((train_images, train_labels)) train_datasets = train_datasets.shuffle(buffer_size=1024).batch(64) val_datasets = tf.data.Dataset.from_tensor_slices((val_images, val_labels)) val_datasets = val_datasets.batch(64) def get_model(): inputs = Input(shape=(784,)) outputs = Dense(10, activation='softmax')(inputs) model = Model(inputs, outputs) model.compile( optimizer=keras.optimizers.RMSprop(learning_rate=1e-3), loss=keras.losses.SparseCategoricalCrossentropy(), metrics=['accuracy'] ) return model model = get_model() model.fit( train_datasets, epochs=5, validation_data=val_datasets )