我有这个张量:

tf_a1 = [[0 3 1 22]
        [3 5 2  2]
        [2 6 3 13]
        [1 7 0 3 ]
        [4 9 11 10]]


我想要做的是找到在所有列中重复的唯一值而不是threshold

例如,在这里,34 columns中重复。 02 columns中重复。 23 columns中重复,依此类推。

我希望我的输出是这样的(假设阈值是2,所以重复超过2次的索引将被屏蔽)。

[[F T F F]
 [T F T T]
 [T F T F]
 [F F F T]
 [F F F F]]


这是我所做的:

y, idx, count = tf.unique_with_counts(tf_a1)
tf.where(tf.where(count, tf_a1, tf.zeros_like(tf_a1)))


但这会引发错误:


  tensorflow.python.framework.errors_impl.InvalidArgumentError:唯一
  期望一维向量。 [Op:UniqueWithCounts]


谢谢。

最佳答案

当前,支持axis的unique_with_counts API尚未公开。如果要使用多维度张量unique_with_counts,可以在tf 1.14中这样称呼它:

from tensorflow.python.ops import gen_array_ops

# tensor 'x' is [[1, 0, 0],
#                [1, 0, 0],
#                [2, 0, 0]]
y, idx, count = gen_array_ops.unique_with_counts_v2(x, [0])
y ==> [[1, 0, 0],
       [2, 0, 0]]
idx ==> [0, 0, 1]
count ==> [2, 1]


由于我无法评论@sariii帖子。它应该是tf.greater(count,2)。另外,我认为这种解决方案不能满足问题的要求。例如,如果第一个列都是2,而其余列都没有2。根据您的要求,应将2计为1倍。但是在您的解决方案中,如果您首先将形状调整为1d,那么您将丢失此信息。如果我错了,请纠正我。

issue#16503

09-11 18:59