在 TensorFlow 中,fit() 函数可以直接接收 numpy 类型数据,前提是数据量不大、可以全部加载到内存中;但如果数据量过大,我们就需要将其按批次读取,转化成迭代器的形式,也就是 tf.data.Dataset。
可以将 Dataset 实例直接传递给方法 fit()、evaluate() 和 predict():
如果使用 Dataset,就不需要像 numpy 数据那样在 fit() 中指定 batch_size 了。
完整代码:
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2022/1/2
* 时间: 19:29
* 描述:
"""
import tensorflow as tf
import tensorflow.keras.datasets.mnist
from keras import Input, Model
from keras.layers import Dense
from tensorflow import keras
# Load MNIST via the already-imported `tf` alias (consistent with the rest
# of the script). Arrays come back with shape (num_samples, 28, 28).
(train_images, train_labels), (val_images, val_labels) = tf.keras.datasets.mnist.load_data()

# Scale pixels from [0, 255] to [0, 1] and flatten each 28x28 image into a
# 784-dim vector. Using -1 lets NumPy infer the sample count instead of
# hard-coding 60000/10000, so the code also works on dataset subsets.
train_images = (train_images / 255.0).reshape(-1, 784)
val_images = (val_images / 255.0).reshape(-1, 784)

# Wrap the arrays in tf.data pipelines: shuffle only the training split,
# batch both. Because the Datasets are pre-batched, fit()/evaluate() need
# no batch_size argument.
train_datasets = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_datasets = train_datasets.shuffle(buffer_size=1024).batch(64)
val_datasets = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_datasets = val_datasets.batch(64)
def get_model():
    """Build and compile a single-layer softmax classifier for MNIST.

    Returns:
        A compiled keras Model mapping 784-dim flattened images to
        10 class probabilities.
    """
    inputs = Input(shape=(784,))
    outputs = Dense(10, activation='softmax')(inputs)
    model = Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
        # Labels are integer class ids (not one-hot), hence the sparse loss.
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )
    return model
# Build the model and train directly on the batched Datasets — because the
# pipelines are already batched, no batch_size is passed to fit().
model = get_model()
model.fit(train_datasets, epochs=5, validation_data=val_datasets)