我有以下张量:
# (class, index)
obj_class_indexes = tf.constant([(0, 0), (0, 1), (0, 2), (1, 3)])
对于每个值,我正在寻找具有相同类的对象。
目前,我正在尝试以下操作:
same_classes = tf.logical_and(tf.equal(obj_classes_indexes[:, 0], obj_classes_indexes[0][0]), \
obj_classes_indexes[:, 1] > obj_classes_indexes[0][1])
found_indexes = tf.where(same_classes)
with tf.Session() as sess:
print(sess.run(same_classes))
print(sess.run(indexes))
预期的输出将是:
[False True True False]
[1, 2]
但这给了我:
[False True True False]
[[1], [2]]
我认为
logical_and
输出实际上不是tf.where
函数的正确输入。还是我错过了什么?谢谢!
最佳答案
输出没有任何错误。 tf.where()预计将输出一个二维张量,如here所示:
“坐标以二维张量返回,其中第一维(行)代表真实元素的数量,第二维(列)代表真实元素的坐标”
如前所述,如果您希望输出为一维张量,则可以在情况下添加一个重塑操作,如下所示:
found_indexes = tf.where(same_classes)
found_indexes = tf.reshape(found_indexes, [-1])
希望这可以帮助!
关于python - Tensorflow,其中(索引)“和”为条件,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53393430/