This tutorial explains how to use the Keras EarlyStopping callback API.
A callback is an object that can perform actions at various stages of training, for example at the start or end of an epoch or batch. Common use cases for callbacks include logging metrics to TensorBoard, periodically saving the model to disk, adjusting the learning rate, and stopping training early when a monitored metric stops improving.
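For illustration, here is a minimal sketch of a custom callback (the class name is made up for this example) that prints the loss at the end of every epoch by overriding on_epoch_end:

import tensorflow as tf

class PrintLossCallback(tf.keras.callbacks.Callback):
    # Keras calls this hook at the end of every epoch during model.fit()
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print(f"Epoch {epoch + 1} ended with loss {logs.get('loss')}")

A callback like this is passed to model.fit() through the callbacks argument, exactly as we will do with EarlyStopping below.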
In the code snippet below we will use the EarlyStopping callback and see its effect on the model.fit() method. The signature and default arguments of tf.keras.callbacks.EarlyStopping are:
tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0,
    patience=0,
    verbose=0,
    mode="auto",
    baseline=None,
    restore_best_weights=False,
)
monitor: Quantity to be monitored.
min_delta: Minimum change in the monitored quantity to qualify as an improvement.
patience: Number of epochs with no improvement after which training will be stopped.
verbose: Verbosity mode.
mode: One of {"auto", "min", "max"}. In "min" mode, training will stop when the quantity monitored has stopped decreasing; in "max" mode it will stop when the quantity monitored has stopped increasing; in "auto" mode, the direction is automatically inferred from the name of the monitored quantity.
baseline: Baseline value for the monitored quantity. Training will stop if the model doesn't show improvement over the baseline.
restore_best_weights: Whether to restore model weights from the epoch with the best value of the monitored quantity.
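For example, these parameters can be combined to stop training when the validation loss fails to improve by at least 0.001 for three consecutive epochs and to roll the weights back to the best epoch; the values here are illustrative, not taken from this tutorial's run:

import tensorflow as tf

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",          # watch validation loss instead of training loss
    min_delta=0.001,             # smaller changes do not count as improvement
    patience=3,                  # allow up to 3 epochs without improvement
    mode="min",                  # lower val_loss is better
    restore_best_weights=True,   # roll back to the weights of the best epoch
    verbose=1,                   # print a message when training stops early
)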
1. Import required modules.
# import required modules
import tensorflow as tf
import tensorflow_datasets as tfds
2. Load MNIST dataset.
# load mnist dataset
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
print(type(ds_train))
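The ds_info object returned by tfds.load carries dataset metadata. As an optional sanity check (not part of the original snippet), the number of examples in each split can be printed:

# MNIST has 60,000 training and 10,000 test examples
print(ds_info.splits['train'].num_examples)
print(ds_info.splits['test'].num_examples)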
3. Define image normalization function.
# define image normalize function
def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255., label
4. Apply normalization and batch on training and test datasets.
# apply normalization and batch on training and test datasets
dataset_train = ds_train.map(normalize_img)
dataset_train = dataset_train.batch(128)
dataset_test = ds_test.map(normalize_img)
dataset_test = dataset_test.batch(128)
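Optionally, the tf.data pipeline can be sped up with caching and prefetching. This is not required for EarlyStopping and is only a common optimization sketch:

# optional performance tweaks on the batched datasets
dataset_train = dataset_train.cache().prefetch(tf.data.AUTOTUNE)
dataset_test = dataset_test.cache().prefetch(tf.data.AUTOTUNE)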
5. Create Keras sequential model.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
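Optionally, model.summary() prints the output shape and parameter count of each layer, which is a quick way to verify the architecture (this call is not part of the original snippet):

# print layer output shapes and parameter counts
model.summary()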
6. Compile the model.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.006),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
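Note that the final Dense layer has no activation, so the loss is configured with from_logits=True. An equivalent alternative sketch, if you prefer the model to output probabilities, is a softmax activation on the last layer together with from_logits=False:

# alternative sketch: softmax output in the model, from_logits=False in the loss
model_alt = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model_alt.compile(
    optimizer=tf.keras.optimizers.Adam(0.006),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)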
7. Instantiate EarlyStopping callback.
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)
8. Train the model for 100 epochs and check how many epochs actually run because of EarlyStopping.
history = model.fit(
    dataset_train,
    epochs=100,
    validation_data=dataset_test,
    callbacks=[callback]
)
print(len(history.history['loss']))
Output
Epoch 1/100
469/469 [==============================] - 9s 19ms/step - loss: 3.2035 - sparse_categorical_accuracy: 0.8107 - val_loss: 0.4739 - val_sparse_categorical_accuracy: 0.8911
Epoch 2/100
469/469 [==============================] - 2s 3ms/step - loss: 0.4325 - sparse_categorical_accuracy: 0.8966 - val_loss: 0.3721 - val_sparse_categorical_accuracy: 0.9147
Epoch 3/100
469/469 [==============================] - 2s 3ms/step - loss: 0.3718 - sparse_categorical_accuracy: 0.9117 - val_loss: 0.4223 - val_sparse_categorical_accuracy: 0.9026
Epoch 4/100
469/469 [==============================] - 2s 3ms/step - loss: 0.3774 - sparse_categorical_accuracy: 0.9109 - val_loss: 0.4949 - val_sparse_categorical_accuracy: 0.9073
Epoch 5/100
469/469 [==============================] - 2s 3ms/step - loss: 0.4109 - sparse_categorical_accuracy: 0.9045 - val_loss: 0.4240 - val_sparse_categorical_accuracy: 0.9076
As we can see from the above output, only 5 of the 100 epochs ran because the callback was created with patience=2 and the monitored training loss did not improve in Epoch 4 and Epoch 5, i.e. for two consecutive epochs after the best value in Epoch 3.
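As a follow-up (not part of the original snippet), the EarlyStopping instance also exposes the epoch at which training stopped, and the trained model can be evaluated on the batched test dataset; the printed values will vary from run to run:

# 0-based index of the epoch at which training stopped (0 if training never stopped early)
print(callback.stopped_epoch)

# evaluate the trained model on the test dataset
test_loss, test_acc = model.evaluate(dataset_test)
print(f"test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}")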