本文介绍了Tensorflow tf.gather 与轴参数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我正在使用 tensorflow 的 tf.gather
从多维数组中获取元素,如下所示:
I am using tensorflow's tf.gather
to get elements from a multidimensional array like this:
import tensorflow as tf
indices = tf.constant([0, 1, 1])
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
result = tf.gather(x, indices, axis=1)
with tf.Session() as sess:
selection = sess.run(result)
print(selection)
导致:
[[1 2 2]
[4 5 5]
[7 8 8]]
我想要的是:
[1
5
8]
如何使用 tf.gather
在指定轴上应用单个索引?(结果与此答案中指定的解决方法相同:https://stackoverflow.com/a/41845855/9763766)
how can I use tf.gather
to apply the single indices on the specified axis?(Same result as the workaround specified in this answer: https://stackoverflow.com/a/41845855/9763766)
推荐答案
您需要将indices
转换为full索引
,并使用gather_nd
>.可以通过以下方式实现:
You need to convert the indices
to full indices
, and using gather_nd
. Can be achieved by doing:
result = tf.squeeze(tf.gather_nd(x,tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices[...,tf.newaxis]], axis=2)))
这篇关于Tensorflow tf.gather 与轴参数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!