本文介绍了tf.train.NanTensorHook(loss, fail_on_nan_loss=False) 仍然会在 TF1.0 中引发异常的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

当我使用TF1.0定义自定义model_fn时,我想在损失为Nan时停止训练.我在model_fn中尝试了下面的代码:

When I define a customized model_fn wtih TF1.0, I want to stop the training when loss is Nan. I tried the code below in model_fn:

    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions_dict,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        training_hooks=[tf.train.NanTensorHook(loss, fail_on_nan_loss=False)])

但是 fail_on_nan_loss=False 仍然会引发异常,我希望它会写警告信息并停止特定的训练而不引发异常.

but the fail_on_nan_loss=False will still raise an exception, I expect it will write the warning message and stop the specific training without raising an exception.

关于如何正确使用 tf.train.NanTensorHook 有什么建议吗?

Any suggestion on how to use tf.train.NanTensorHook correctly?

推荐答案

当我探索解决方案时,一种可能的解决方法可能会有所帮助:我从 basic_session_run_hooks.py 复制 NanTensorHook 类,并在我的 model_fn 中创建我自己的调用版本,如下所示

When I explore the solution, one possible work around might helps:I copy the NanTensorHook class from basic_session_run_hooks.py and make my own version of call inside my model_fn as below

        class NanTensorHook2(tf.train.SessionRunHook):
      """NaN Loss monitor by Lei.

      Monitors loss and stops training if loss is NaN.
      Can either fail with exception or just stop training.
      """

      def __init__(self, loss_tensor, fail_on_nan_loss=True):
        """Initializes NanLoss monitor.

        Args:
          loss_tensor: `Tensor`, the loss tensor.
          fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
        """
        self._loss_tensor = loss_tensor
        self._fail_on_nan_loss = fail_on_nan_loss

      def before_run(self, run_context):  # pylint: disable=unused-argument
        return tf.train.SessionRunArgs(self._loss_tensor)

      def after_run(self, run_context, run_values):
        if (np.isnan(run_values.results) or np.isinf(run_values.results)):
          failure_message = "Model diverged with loss = NaN or Inf."
          if self._fail_on_nan_loss:
            logging.error(failure_message)
            raise NanLossDuringTrainingError
          else:
            logging.warning(failure_message)
            # We don't raise an error but we request stop without an exception.
            run_context.request_stop()

然后改用 NanTensorHook2,它就开始工作了.

then use NanTensorHook2 instead, it then started working.

注意我添加了 "np.isinf(run_values.results)" 因为我认为这里也应该检查 loss = inf.

Notes that I added "np.isinf(run_values.results)" as I believe that loss = inf should also be checked here.

哪位专家有更好的解决方案?

any expert has better solution?

这篇关于tf.train.NanTensorHook(loss, fail_on_nan_loss=False) 仍然会在 TF1.0 中引发异常的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

11-02 12:34