通常，可以定义 Keras 的提前停止（EarlyStopping），在每个时期（epoch）结束后检查监控指标（损失、准确率）是否达到阈值。但在小批量（mini-batch）训练中，每个时期包含多个批次，因而会产生多个损失值。我们怎样才能让 Keras 在一个时期内检查每个批次的准确率（acc）或损失，从而尽早停止训练？
# Question's original setup: ordinary epoch-level callbacks. Keras'
# EarlyStopping is checked once per epoch, so with epochs=1 its
# patience=5 can never be exhausted. That is why a per-batch variant
# (BatchEarlyStopping) is needed.
filepath = "weights.best.hdf5"
checkpoint = ModelCheckpoint(
    filepath,
    save_best_only=True,
    monitor='acc',
    verbose=1,
)
early_stop = EarlyStopping(patience=5, monitor='acc')
callbacks_list = [checkpoint, early_stop]
history = model.fit(
    x,
    y,
    callbacks=callbacks_list,
    epochs=1,
    validation_data=(x_test, y_test),
)
最佳答案
从源码（source）中复制 `EarlyStopping` 的代码，将其中的 `on_epoch_end` 改为 `on_batch_end`，再把其他与 epoch 相关的内容相应地改为按批次（batch）处理，就可以直接使用了：
class BatchEarlyStopping(Callback):
    """Stop training when a monitored quantity has stopped improving.

    Per-batch variant of Keras' ``EarlyStopping``: the check runs in
    ``on_batch_end`` instead of ``on_epoch_end``, so ``patience`` is counted
    in batches. The no-improvement counter is not reset between epochs, so
    consecutive non-improving batches accumulate across epoch boundaries.

    # Arguments
        monitor: quantity to be monitored, looked up by key in the ``logs``
            dict passed to ``on_batch_end``. If the key is missing, a
            RuntimeWarning is issued and that batch is ignored.
            NOTE(review): per-batch logs usually contain training metrics
            only (e.g. 'loss', 'acc'), so the default 'val_loss' may never
            be found; pass a training metric explicitly. Confirm for the
            Keras version in use.
        min_delta: minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than min_delta will count as no
            improvement. Its sign is ignored.
        patience: number of consecutive batches with no improvement
            after which training will be stopped.
        verbose: verbosity mode; when > 0, a message is printed when
            training is stopped and when best weights are restored.
        mode: one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is inferred from the name of the monitored
            quantity (names containing 'acc' are maximized, all others
            minimized).
        baseline: Baseline value for the monitored quantity to reach.
            Training will stop if the model doesn't show improvement
            over the baseline.
        restore_best_weights: whether to restore model weights from
            the batch with the best value of the monitored quantity.
            If False, the model weights obtained at the last step of
            training are used. If no batch ever improved on the initial
            best value (e.g. the baseline was never beaten), there is
            nothing to restore and the current weights are kept.
    """

    def __init__(self,
                 monitor='val_loss',
                 min_delta=0,
                 patience=0,
                 verbose=0,
                 mode='auto',
                 baseline=None,
                 restore_best_weights=False):
        super(BatchEarlyStopping, self).__init__()
        self.monitor = monitor
        self.baseline = baseline
        self.patience = patience
        self.verbose = verbose
        self.wait = 0
        self.stopped_batch = 0
        # Explicit "we requested a stop" flag: stopped_batch alone cannot
        # signal this, because stopping at batch index 0 is legitimate.
        self._stopped = False
        self.restore_best_weights = restore_best_weights
        self.best_weights = None

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('BatchEarlyStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        elif 'acc' in self.monitor:
            self.monitor_op = np.greater
        else:
            self.monitor_op = np.less

        # Store min_delta signed so that `current - min_delta` demands a
        # real improvement in both directions:
        #   max mode: current > best + |min_delta|
        #   min mode: current < best - |min_delta|
        self.min_delta = abs(min_delta)
        if self.monitor_op is np.less:
            self.min_delta = -self.min_delta

    def on_train_begin(self, logs=None):
        # Reset all per-run state so the same instance can be reused
        # across several fit() calls.
        self.wait = 0
        self.stopped_batch = 0
        self._stopped = False
        self.best_weights = None
        if self.baseline is not None:
            self.best = self.baseline
        else:
            # np.inf (lower case): the np.Inf alias was removed in NumPy 2.0.
            self.best = np.inf if self.monitor_op is np.less else -np.inf

    def on_batch_end(self, batch, logs=None):
        current = self.get_monitor_value(logs)
        if current is None:
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
            return

        self.wait += 1
        if self.wait < self.patience:
            return

        self.stopped_batch = batch
        self._stopped = True
        self.model.stop_training = True
        # If no batch ever improved on the initial best (e.g. a baseline
        # that was never reached), no weights were saved; keep the current
        # ones instead of calling set_weights(None).
        if self.restore_best_weights and self.best_weights is not None:
            if self.verbose > 0:
                print('Restoring model weights from the end of '
                      'the best batch')
            self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self._stopped and self.verbose > 0:
            print('Batch %05d: early stopping' % (self.stopped_batch + 1))

    def get_monitor_value(self, logs):
        """Return ``logs[self.monitor]``, or None (with a warning) if absent."""
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            warnings.warn(
                'Early stopping conditioned on metric `%s` '
                'which is not available. Available metrics are: %s' %
                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
            )
        return monitor_value
关于 python - 如何定义 Keras 的 Early Stopping，使其在每个批次之后（而不是每个时期之后）进行检查？我们在 Stack Overflow 上找到一个类似的问题：https://stackoverflow.com/questions/57618220/