本文介绍了如何从 tf.estimator.Estimator 获取最后一个 global_step的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如何在 train(...) 完成后从 tf.estimator.Estimator 获取最后一个 global_step?例如,一个典型的基于 Estimator 的训练程序可能是这样设置的:n_epochs = 10model_dir = '/path/to/model_dir'

How can I obtain the last global_step from a tf.estimator.Estimator after train(...) finishes? For instance, a typical Estimator-based training routine might be set up like this: n_epochs = 10 model_dir = '/path/to/model_dir'

def model_fn(features, labels, mode, params):
    # some code to build the model
    pass

def input_fn():
    ds = tf.data.Dataset()  # obviously with specifying a data source
    # manipulate the dataset
    return ds

run_config = tf.estimator.RunConfig(model_dir=model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

for epoch in range(n_epochs):
    estimator.train(input_fn=input_fn)
    # Now I want to do something which requires to know the last global step, how to get it?
    my_custom_eval_method(global_step)

只有 evaluate() 方法返回一个包含 global_step 作为字段的字典.如何获得global_step,如果由于某种原因,我不能或不想使用这种方法?

Only the evaluate() method returns a dictionary containing the global_step as a field. How can I get the global_step, if for some reason, I can't have or don't want to use this method?

推荐答案

最近,我发现 estimator 有 api get_variable_value

recently, I found estimator has the api get_variable_value

global_step = estimator.get_variable_value("global_step")

这篇关于如何从 tf.estimator.Estimator 获取最后一个 global_step的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

10-29 07:32