如何基于带索引的张量过滤张量流的张量

如何基于带索引的张量过滤张量流的张量

本文介绍了如何基于带索引的张量过滤张量流的张量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有一个大小为[batch_size, 5, 10]的张量称为my_tensor.我还有一个大小为[batch_size, 1]的张量,其中包含索引为selecter的张量.

我想相对于selecter过滤my_tensor以产生大小为[batch_size, 10]的新张量,即仅选择selecter包含的值.基本上,这是在减小中间尺寸(尺寸为5).

我觉得tf.where是正确的选择,但不确定.非常感谢您的帮助!

解决方案

解决方案是与tf.gather_nd一起使用.

tf.gather_nd(
    my_tensor,
    tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1))

如果从头开始将selecter构造为一维,则可以摆脱squeeze.

Let's say I have a tensor of size [batch_size, 5, 10] called my_tensor.I also have an another tensor of size [batch_size, 1] holding indices called selecter.

I want to filter my_tensor with respect to selecter to produce new tensor of size [batch_size, 10], i.e. select only values that selecter contains. Basically, it's kinda reducing the middle dimension(which has size 5).

I feel like tf.where is the right choice, but not sure about it.I would really appreciate your help!

解决方案

The solution is to go with tf.gather_nd.

tf.gather_nd(
    my_tensor,
    tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1))

You can get rid of the squeeze if you construct selecter to be 1-D from the beginning.

这篇关于如何基于带索引的张量过滤张量流的张量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

09-03 10:07