我从 tensorflow 中调整了 cifar10 网络,以解决我自己的分类问题。我已经训练了网络,现在我尝试使用 cifar10_eval.py 评估训练后的模型
top_k_op = tf.nn.in_top_k(logits, labels, 1)
但我收到以下错误。经过进一步调查,目标指数在2,3和4之间变化
tensorflow.python.framework.errors.InvalidArgumentError: targets[3] is out of range
到现在为止,我知道我的标签 Tensor 有问题。它是一个 int32-Tensor,其 shape(50,) 如下所示。
labels = {Tensor} Tensor("batch_processing/Reshape_1:0", shape=(50,), dtype=int32, device=/device:CPU:0)
我的数据集只有 2 个类/标签。也许这可能是问题所在。有谁知道,问题是什么?
最佳答案
总而言之,函数 tf.nn.in_top_k(predictions, targets, k)
(参见 doc )有参数:
[batch_size, num_classes]
,类型 float32 [batch_size]
,类型 int32 或 int64 当元素
InvalidArgumentError: targets[i] is out of range
超出 targets[i]
的范围时,该函数会引发错误 predictions[i]
。例如,有 2 个类 (
num_classes=2
) 和 targets=[1, 3]
。对于这些目标,您将看到错误
InvalidArgumentError: targets[1] is out of range
,因为 targets[1] = 3
超出了只有形状 2 的 predictions[1]
的范围。要检查您的
labels
是否正确,您可以打印它们的最大值:labels = ...
labels_max = tf.reduce_max(labels)
sess = tf.Session()
print sess.run(labels_max)
如果打印的值优于
num_classes
,则您有问题。关于tensorflow - tf.nn.in_top_k : targets out of range,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/37587622/