global_step
完成后,如何从 tf.estimator.Estimator
获取最后一个 train(...)
?例如,一个典型的基于 Estimator 的训练程序可能是这样设置的:
n_epochs = 10
model_dir = '/path/to/model_dir'
def model_fn(features, labels, mode, params):
# some code to build the model
pass
def input_fn():
ds = tf.data.Dataset() # obviously with specifying a data source
# manipulate the dataset
return ds
run_config = tf.estimator.RunConfig(model_dir=model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
for epoch in range(n_epochs):
estimator.train(input_fn=input_fn)
# Now I want to do something which requires to know the last global step, how to get it?
my_custom_eval_method(global_step)
只有
evaluate()
方法返回包含 global_step
作为字段的字典。我怎样才能获得 global_step
,如果由于某种原因,我不能或不想使用这种方法? 最佳答案
最近,我发现 estimator 有 api get_variable_value
global_step = estimator.get_variable_value("global_step")
关于python - 如何从 tf.estimator.Estimator 获取最后一个 global_step,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/51325660/