Closed. This question does not meet Stack Overflow guidelines。它当前不接受答案。












想改善这个问题吗?更新问题,以使为on-topic

在11个月前关闭。



Improve this question





阅读API DOC之后,我也无法理解SessionRunHook的用法。例如,SessionRunHook的成员顺序是什么
要调用的功能?是after_create_session -> before_run -> begin -> after_run -> end吗?
而且我找不到包含详细示例的教程,是否有更详细的说明?

最佳答案

您可以找到一个很长的教程here,但是可以跳过构建网络的部分。或者,您可以根据我的经验阅读下面的小摘要。

首先,应使用MonitoredSession代替普通的Session


SessionRunHook扩展了对session.run()MonitoredSession调用。


然后可以在here中找到一些常见的SessionRunHook类。一个简单的是LoggingTensorHook,但是您可能要在导入后添加以下行,以便在运行时查看日志:

tf.logging.set_verbosity(tf.logging.INFO)


或者,您可以选择实现自己的SessionRunHook类。一个简单的是来自cifar10 tutorial

class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()

  def before_run(self, run_context):
    self._step += 1
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.

  def after_run(self, run_context, run_values):
    if self._step % FLAGS.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time

      loss_value = run_values.results
      examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
      sec_per_batch = float(duration / FLAGS.log_frequency)

      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print (format_str % (datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))


其中loss是在类外部定义的。该_LoggerHook使用print打印信息,而LoggingTensorHook使用tf.logging.INFO

最后,为了更好地理解其工作方式,执行顺序由带有MonitoredSession here的伪代码表示:

  call hooks.begin()
  sess = tf.Session()
  call hooks.after_create_session()
  while not stop is requested:  # py code: while not mon_sess.should_stop():
    call hooks.before_run()
    try:
      results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
      break
    call hooks.after_run()
  call hooks.end()
  sess.close()


希望这可以帮助。

08-20 02:04