我有以下张量:

# (class, index)
obj_class_indexes = tf.constant([(0, 0), (0, 1), (0, 2), (1, 3)])


对于每个值,我正在寻找具有相同类的对象。
目前,我正在尝试以下操作:

same_classes = tf.logical_and(tf.equal(obj_classes_indexes[:, 0], obj_classes_indexes[0][0]), \
                                        obj_classes_indexes[:, 1]  > obj_classes_indexes[0][1])
found_indexes = tf.where(same_classes)

with tf.Session() as sess:
    print(sess.run(same_classes))
    print(sess.run(indexes))


预期的输出将是:

[False  True  True False]
[1, 2]


但这给了我:

[False  True  True False]
[[1], [2]]


我认为logical_and输出实际上不是tf.where函数的正确输入。还是我错过了什么?

谢谢!

最佳答案

输出没有任何错误。 tf.where()预计将输出一个二维张量,如here所示:
“坐标以二维张量返回,其中第一维(行)代表真实元素的数量,第二维(列)代表真实元素的坐标”

如前所述,如果您希望输出为一维张量,则可以在情况下添加一个重塑操作,如下所示:

found_indexes = tf.where(same_classes)
found_indexes = tf.reshape(found_indexes, [-1])


希望这可以帮助!

关于python - Tensorflow,其中(索引)“和”为条件,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53393430/

10-11 11:18