Closed. This question does not meet Stack Overflow guidelines。它当前不接受答案。
想改善这个问题吗?更新问题,以使为on-topic。
在11个月前关闭。
Improve this question
阅读API DOC之后,我也无法理解SessionRunHook的用法。例如,SessionRunHook的成员顺序是什么
要调用的功能?是
而且我找不到包含详细示例的教程,是否有更详细的说明?
或者,您可以选择实现自己的
其中
最后,为了更好地理解其工作方式,执行顺序由带有
希望这可以帮助。
想改善这个问题吗?更新问题,以使为on-topic。
在11个月前关闭。
Improve this question
阅读API DOC之后,我也无法理解SessionRunHook的用法。例如,SessionRunHook的成员顺序是什么
要调用的功能?是
after_create_session -> before_run -> begin -> after_run -> end
吗?而且我找不到包含详细示例的教程,是否有更详细的说明?
最佳答案
您可以找到一个很长的教程here,但是可以跳过构建网络的部分。或者,您可以根据我的经验阅读下面的小摘要。
首先,应使用MonitoredSession
代替普通的Session
。
SessionRunHook扩展了对session.run()
的MonitoredSession
调用。
然后可以在here中找到一些常见的SessionRunHook
类。一个简单的是LoggingTensorHook
,但是您可能要在导入后添加以下行,以便在运行时查看日志:
tf.logging.set_verbosity(tf.logging.INFO)
或者,您可以选择实现自己的
SessionRunHook
类。一个简单的是来自cifar10 tutorialclass _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
其中
loss
是在类外部定义的。该_LoggerHook
使用print
打印信息,而LoggingTensorHook
使用tf.logging.INFO
。最后,为了更好地理解其工作方式,执行顺序由带有
MonitoredSession
here的伪代码表示: call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested: # py code: while not mon_sess.should_stop():
call hooks.before_run()
try:
results = sess.run(merged_fetches, feed_dict=merged_feeds)
except (errors.OutOfRangeError, StopIteration):
break
call hooks.after_run()
call hooks.end()
sess.close()
希望这可以帮助。
08-20 02:04