我正在尝试从张量流网络进行多个顺序预测,但是即使对于CPU,性能似乎也很差(对于2层8x8卷积网络,每个预测约500ms)。我怀疑问题的一部分在于它似乎每次都在重新加载网络参数。下面代码中对classifier.predict的每次调用都会导致下一行输出-因此,我看到了数百次。

INFO:tensorflow:Restoring parameters from /tmp/model_data/model.ckpt-102001

如何重用已经加载的检查点?

(我无法在此处进行批量预测,因为网络的输出是在游戏中玩的一招,在提供新游戏状态之前,需要先将其应用于当前状态。)

这是进行预测的循环。

def rollout(classifier, state):
  while not state.terminated:
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": state.as_nn_input()}, shuffle=False)
    prediction = next(classifier.predict(input_fn=predict_input_fn))
    index = np.random.choice(NUM_ACTIONS, p=prediction["probabilities"]) # Select a move according to the network's output probabilities
    state.apply_move(index)


classifier是用...创建的tf.estimator.Estimator

classifier = tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir=os.path.join(tempfile.gettempdir(), 'model_data'))

最佳答案

Estimator API是高级API。


  tf.estimator框架使其易于构建和训练
  通过其高级Estimator API实现机器学习模型。估算器
  提供可以实例化的类以快速配置通用模型
  类型,例如回归变量和分类符。


Estimator API消除了TensorFlow的许多复杂性,但在此过程中失去了一些通用性。读完代码,很明显,没有每次都不重新加载模型就无法运行多个顺序预测的方法。低级TensorFlow API允许这种行为。但...

Keras是支持此用例的高级框架。简单define the model,然后重复调用predict

def rollout(model, state):
  while not state.terminated:
    predictions = model.predict(state.as_nn_input())
    for _, prediction in enumerate(predictions):
      index = np.random.choice(bt.ACTIONS, p=prediction)
      state.apply_mode(index)


不科学的基准测试表明速度要快100倍左右。

关于python - Tensorflow-停止还原网络参数,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/45804879/

10-12 18:03