我有这个张量:
tf_a1 = [[0 3 1 22]
[3 5 2 2]
[2 6 3 13]
[1 7 0 3 ]
[4 9 11 10]]
我想要做的是找到在所有列中重复的唯一值而不是
threshold
。例如,在这里,
3
在4 columns
中重复。 0
在2 columns
中重复。 2
在3 columns
中重复,依此类推。我希望我的输出是这样的(假设阈值是
2
,所以重复超过2次的索引将被屏蔽)。[[F T F F]
[T F T T]
[T F T F]
[F F F T]
[F F F F]]
这是我所做的:
y, idx, count = tf.unique_with_counts(tf_a1)
tf.where(tf.where(count, tf_a1, tf.zeros_like(tf_a1)))
但这会引发错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError:唯一
期望一维向量。 [Op:UniqueWithCounts]
谢谢。
最佳答案
当前,支持axis的unique_with_counts
API尚未公开。如果要使用多维度张量unique_with_counts
,可以在tf 1.14中这样称呼它:
from tensorflow.python.ops import gen_array_ops
# tensor 'x' is [[1, 0, 0],
# [1, 0, 0],
# [2, 0, 0]]
y, idx, count = gen_array_ops.unique_with_counts_v2(x, [0])
y ==> [[1, 0, 0],
[2, 0, 0]]
idx ==> [0, 0, 1]
count ==> [2, 1]
由于我无法评论@sariii帖子。它应该是
tf.greater(count,2)
。另外,我认为这种解决方案不能满足问题的要求。例如,如果第一个列都是2,而其余列都没有2。根据您的要求,应将2计为1倍。但是在您的解决方案中,如果您首先将形状调整为1d,那么您将丢失此信息。如果我错了,请纠正我。issue#16503