MonitoredTrainingSession

MonitoredTrainingSession

我想设置一个分布式 tensorflow 模型,但无法理解 MonitoredTrainingSession 和 StopAtStepHook 如何交互。
在我进行此设置之前:

for epoch in range(training_epochs):
  for i in range(total_batch-1):
    c, p, s = sess.run([cost, prediction, summary_op], feed_dict={x: batch_x, y: batch_y})

现在我有了这个设置(简化):
def run_nn_model(learning_rate, log_param, optimizer, batch_size, layer_config):
  with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % mytaskid,
        cluster=cluster)):

    # [variables...]

    hooks=[tf.train.StopAtStepHook(last_step=100)]
    if myjob == "ps":
        server.join()
    elif myjob == "worker":
        with tf.train.MonitoredTrainingSession(master = server.target,
                                is_chief=(mytaskid==0),
                                checkpoint_dir='/tmp/train_logs',
                                hooks=hooks
                                ) as sess:

          while not sess.should_stop():
            #for epoch in range...[see above]

这是错误的吗?它抛出:
RuntimeError: Run called even after should_stop requested.
Command exited with non-zero status 1

有人可以向我解释 tensorflow 是如何协调的吗?我如何使用步进计数器来跟踪训练? (在我拥有这个方便的时代变量之前)

最佳答案

每次执行 sess.run 时,计数器都会递增。这里的问题是您运行的步骤 (total_batch-1 x training_epochs) 多于钩子(Hook) (200) 中指定的步骤数。

你可以做什么,即使我不认为它是一个干净的语法是定义 last_step = total_batch-1 x training_epochs

关于python - 基本的 StopAtStepHook 和 MonitoredTrainingSession 用法,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/42960304/

10-12 17:40