我正在尝试实现提供高级api的tensorflow,特别是基线分类器。但是,当尝试训练模型时,我得到以下信息
错误:
NotFoundError (see above for traceback): Key baseline/bias not found in checkpoint
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
码:
import tensorflow as tf
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
def digit_cross():
# Number of classes, one class for each of 10 digits.
num_classes = 10
digit = datasets.load_digits()
x = digit.data
y = digit.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=42)
y_train_index = np.arange(y_train.size)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array(x_train)},
y=np.array(y_train),
num_epochs=None,
shuffle=False)
# Build BaselineClassifier
classifier = tf.estimator.BaselineClassifier(n_classes=num_classes,
model_dir="./checkpoints_tutorial17-1/")
# Fit model.
classifier.train(train_input_fn)
digit_cross()
最佳答案
看来您在model_dir="./checkpoints_tutorial17-1/"
中有一个检查点,该检查点来自另一个模型,而不是来自BaselineClassifier。具体来说,该文件夹中有一个检查点文件和model.ckpt- *文件。
正如tensorflow记录的那样:
model_dir:用于保存模型参数,图形等的目录。还可以用于将目录中的检查点加载到估计器中,以继续训练先前保存的模型。如果是PathLike对象,则路径将被解析。如果为None,则设置时将使用config中的model_dir。如果两者都设置,则必须相同。如果两者均为“无”,则将使用一个临时目录。
在这里,BaselineClassifier
首先将建立一个使用baseline/bias
的图形。然后发现model_dir
中存在先前的检查点。它将尝试加载此检查点,并且您应该看到一条信息(如果已完成tf.logging.set_verbosity(tf.logging.INFO)
),例如
"INFO:tensorflow:Restoring parameters from .../checkpoints_tutorial17-1\model.ckpt-..."
因为
model_dir
中的此检查点不是来自BaselineClassifier
,所以不会具有baseline/bias
。 BaselineClassifier
找不到它,因此将引发错误。