模型训练
fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None):
x: 输入训练数据;
y: 目标(标签)数据;
batch_size: 每次梯度更新的样本数。如果未指定,默认为 32;
epochs:训练模型迭代轮次;
verbose:0, 1 或 2。日志显示模式。 0 = 不显示, 1 = 进度条, 2 = 每轮显示一行;
callbacks:在训练时使用的回调函数;
validation_split:验证集与训练数据的比例;
validation_data:验证集;这个参数会覆盖validation_split;
shuffle: 是否在每轮迭代之前混洗数据。当steps_per_epoch非None时,这个参数无效;
initial_epoch: 开始训练的轮次,常用于恢复之前的训练权重;
steps_per_epoch:steps_per_epoch = 数据集大小/batch_size;
validation_steps:只有在指定了 steps_per_epoch 时才有用。停止前要验证的总步数(批次样本)。
代码:
import numpy as np
train_x = np.random.random((1000, 36))
train_y = np.random.random((1000, 10))
val_x = np.random.random((200, 36))
val_y = np.random.random((200, 10))
model.fit(train_x, train_y, epochs=10, batch_size=100,
validation_data=(val_x, val_y))
输出:
Train on 1000 samples, validate on 200 samples
Epoch 1/10
1000/1000 [==============================] - 0s 488us/sample - loss: 12.6024 - categorical_accuracy: 0.0960 - val_loss: 12.5787 - val_categorical_accuracy: 0.0850
Epoch 2/10
1000/1000 [==============================] - 0s 23us/sample - loss: 12.6007 - categorical_accuracy: 0.0960 - val_loss: 12.5776 - val_categorical_accuracy: 0.0850
Epoch 3/10
1000/1000 [==============================] - 0s 31us/sample - loss: 12.6002 - categorical_accuracy: 0.0960 - val_loss: 12.5771 - val_categorical_accuracy: 0.0850
…
Epoch 10/10
1000/1000 [==============================] - 0s 24us/sample - loss: 12.5972 - categorical_accuracy: 0.0960 - val_loss: 12.5738 - val_categorical_accuracy: 0.0850
对于大型数据集可以使用tf.data构建训练输入。
代码:
dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
dataset = dataset.batch(32)
dataset = dataset.repeat()
val_dataset = tf.data.Dataset.from_tensor_slices((val_x, val_y))
val_dataset = val_dataset.batch(32)
val_dataset = val_dataset.repeat()
model.fit(dataset, epochs=10, steps_per_epoch=30,
validation_data=val_dataset, validation_steps=3)
输出:
Train for 30 steps, validate for 3 steps
Epoch 1/10
30/30 [==============================] - 0s 15ms/step - loss: 12.6243 - categorical_accuracy: 0.0948 - val_loss: 12.3128 - val_categorical_accuracy: 0.0833
…
30/30 [==============================] - 0s 2ms/step - loss: 12.5797 - categorical_accuracy: 0.0951 - val_loss: 12.3067 - val_categorical_accuracy: 0.0833