问题描述
假设我有一个大小为[batch_size, 5, 10]
的张量称为my_tensor
.我还有一个大小为[batch_size, 1]
的张量,其中包含索引为selecter
的张量.我想相对于selecter
过滤my_tensor
以产生大小为[batch_size, 10]
的新张量,即仅选择selecter
包含的值.基本上,这是在减小中间尺寸(尺寸为5).
我觉得tf.where
是正确的选择,但不确定.非常感谢您的帮助!
解决方案是与tf.gather_nd
一起使用.
tf.gather_nd(
my_tensor,
tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1))
如果从头开始将selecter
构造为一维,则可以摆脱squeeze
.
Let's say I have a tensor of size [batch_size, 5, 10]
called my_tensor
.I also have an another tensor of size [batch_size, 1]
holding indices called selecter
.
I want to filter my_tensor
with respect to selecter
to produce new tensor of size [batch_size, 10]
, i.e. select only values that selecter
contains. Basically, it's kinda reducing the middle dimension(which has size 5).
I feel like tf.where
is the right choice, but not sure about it.I would really appreciate your help!
The solution is to go with tf.gather_nd
.
tf.gather_nd(
my_tensor,
tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1))
You can get rid of the squeeze
if you construct selecter
to be 1-D from the beginning.
这篇关于如何基于带索引的张量过滤张量流的张量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!