问题描述
我正在寻找一种 TensorFlow 方式来实现类似于 Python 的 list.index() 函数.
I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.
给定一个矩阵和一个要查找的值,我想知道该值在矩阵的每一行中第一次出现的位置.
Given a matrix and a value to find, I want to know the first occurrence of the value in each row of the matrix.
例如
m is a <batch_size, 100> matrix of integers
val = 23
result = [0] * batch_size
for i, row_elems in enumerate(m):
result[i] = row_elems.index(val)
我不能假设 'val' 在每一行中只出现一次,否则我会使用 tf.argmax(m == val) 实现它.就我而言,重要的是获取 第一次 'val' 出现的索引,而不是任何索引.
I cannot assume that 'val' appears only once in each row, otherwise I would have implemented it using tf.argmax(m == val). In my case, it is important to get the index of the first occurrence of 'val' and not any.
推荐答案
看起来 tf.argmax
的工作方式类似于 np.argmax
(根据 the test),当最大值出现多次时将返回第一个索引.你可以使用 tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1)
来得到你想要的.但是,当前 tf.argmax
的行为在多次出现最大值的情况下是未定义的.
It seems that tf.argmax
works like np.argmax
(according to the test), which will return the first index when there are multiple occurrences of the max value.You can use tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1)
to get what you want. However, currently the behavior of tf.argmax
is undefined in case of multiple occurrences of the max value.
如果您担心未定义的行为,您可以按照@Igor Tsvetkov 的建议在 tf.where
的返回值上应用 tf.argmin
.例如,
If you are worried about undefined behavior, you can apply tf.argmin
on the return value of tf.where
as @Igor Tsvetkov suggested.For example,
# test with tensorflow r1.0
import tensorflow as tf
val = 3
m = tf.placeholder(tf.int32)
m_feed = [[0 , 0, val, 0, val],
[val, 0, val, val, 0],
[0 , val, 0, 0, 0]]
tmp_indices = tf.where(tf.equal(m, val))
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])
with tf.Session() as sess:
print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1]
请注意,tf.segment_min
将在某些行不包含 val
时引发 InvalidArgumentError
.在您的代码中,当 row_elems
不包含 val
时,row_elems.index(val)
也会引发异常.
Note that tf.segment_min
will raise InvalidArgumentError
when there is some row containing no val
. In your code row_elems.index(val)
will raise exception too when row_elems
don't contain val
.
这篇关于如何在 TensorFlow 中找到第一个匹配元素的索引的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!