Problem description
When I define a custom model_fn with TF 1.0, I want to stop training when the loss is NaN. I tried the code below in model_fn:
return model_fn_lib.ModelFnOps(
    mode=mode,
    predictions=predictions_dict,
    loss=loss,
    train_op=train_op,
    eval_metric_ops=eval_metric_ops,
    training_hooks=[tf.train.NanTensorHook(loss, fail_on_nan_loss=False)])
However, fail_on_nan_loss=False still raises an exception. I expected it to log a warning message and stop that training run without raising an exception.
Any suggestions on how to use tf.train.NanTensorHook correctly?
Recommended answer
While looking for a solution, I found one workaround that might help: I copied the NanTensorHook class from basic_session_run_hooks.py and made my own version to call inside my model_fn, as below:
import numpy as np
import tensorflow as tf
# The two imports below follow the TF 1.x source layout (an assumption):
# the logger and the error class are the ones used by basic_session_run_hooks.py.
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError


class NanTensorHook2(tf.train.SessionRunHook):
  """NaN Loss monitor by Lei.

  Monitors loss and stops training if loss is NaN.
  Can either fail with exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes NanLoss monitor.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    # Ask the session to also fetch the loss on every training step.
    return tf.train.SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    # run_values.results holds the fetched loss value for this step.
    if np.isnan(run_values.results) or np.isinf(run_values.results):
      failure_message = "Model diverged with loss = NaN or Inf."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()
I then used NanTensorHook2 instead, and it started working.
Note that I added np.isinf(run_values.results) because I believe loss = Inf should also be checked here.
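To make the two behaviours easier to see without a TensorFlow session, here is a minimal sketch of the same check using only NumPy. FakeRunContext, loss_diverged and check_loss are hypothetical names made up for this illustration, and RuntimeError stands in for NanLossDuringTrainingError. np.isfinite is used as a shorthand that covers NaN, +Inf and -Inf in one call.

```python
import numpy as np


class FakeRunContext(object):
    """Hypothetical stand-in for the run_context passed to after_run."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        # The real context signals the training loop to end; here we only set a flag.
        self.stop_requested = True


def loss_diverged(loss_value):
    """Return True if the loss is NaN, +Inf or -Inf."""
    return not bool(np.isfinite(loss_value))


def check_loss(loss_value, run_context, fail_on_nan_loss=True):
    """Mirror the hook's after_run logic: raise, or request a stop quietly."""
    if loss_diverged(loss_value):
        if fail_on_nan_loss:
            # Stand-in for NanLossDuringTrainingError (assumption for this sketch).
            raise RuntimeError("Model diverged with loss = NaN or Inf.")
        run_context.request_stop()


ctx = FakeRunContext()
check_loss(0.25, ctx, fail_on_nan_loss=False)
print(ctx.stop_requested)  # False: a finite loss leaves training running
check_loss(float("inf"), ctx, fail_on_nan_loss=False)
print(ctx.stop_requested)  # True: Inf requests a stop without raising
```

With fail_on_nan_loss=True the same call raises instead of setting the flag, which is the behaviour the question wanted to avoid.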
Does anyone have a better solution?