沿具有给定索引的维度对张量进行切片

沿具有给定索引的维度对张量进行切片

本文介绍了沿具有给定索引的维度对张量进行切片的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有一个张量:

tensor = tf.constant(
  [[[0.05340263, 0.27248233, 0.49127685, 0.07926575, 0.96054204],
    [0.50013988, 0.05903472, 0.43025479, 0.41379231, 0.86508251],
    [0.02033722, 0.11996034, 0.57675261, 0.12049974, 0.65760677],
    [0.71859089, 0.22825203, 0.64064407, 0.47443116, 0.64108334]],

   [[0.18813498, 0.29462021, 0.09433628, 0.97393446, 0.33451445],
    [0.01657461, 0.28126666, 0.64016929, 0.48365073, 0.26672697],
    [0.9379696 , 0.44648103, 0.39463243, 0.51797975, 0.4173626 ],
    [0.89788558, 0.31063058, 0.05492096, 0.86904097, 0.21696292]],

   [[0.07279436, 0.94773635, 0.34173115, 0.7228713 , 0.46553334],
    [0.61199848, 0.88508141, 0.97019517, 0.61465985, 0.48971128],
    [0.53037002, 0.70782324, 0.32158754, 0.2793538 , 0.62661128],
    [0.52787814, 0.17085317, 0.83711126, 0.40567032, 0.71386498]]])

形状为 (3, 4, 5)

which is of shape (3, 4, 5)

我想对它进行切片以返回形状为 (3,5) 的新张量,使用给定的一维张量,其值指示要检索的位置,例如:

I want to slice it to return a new tensor of shape (3,5), with a given 1D tensor whose value indicates which position to retrieve, for example:

index_tensor = tf.constant([2,1,3])

这会产生一个新的张量,如下所示:

which results in a new tensor which looks like this:

[[0.02033722, 0.11996034, 0.57675261, 0.12049974, 0.65760677],
 [0.01657461, 0.28126666, 0.64016929, 0.48365073, 0.26672697],
 [0.52787814, 0.17085317, 0.83711126, 0.40567032, 0.71386498]]

也就是说,沿着第二个维度,从索引 2、1 和 3 中获取项目.它类似于:

that is , along the second dimension, take items from index 2, 1, and 3.It is similar to do:

tensor[:,x,:]

除了这只会给我沿维度索引x"处的项目,我希望它是灵活的.

except this will only give me item at index 'x' along the dimension, and I want it to be flexible.

这能做到吗?

推荐答案

您可以使用 tf.one_hot() 来屏蔽 index_tensor.

You can use tf.one_hot() to mask index_tensor.

index = tf.one_hot(index_tensor,tensor.shape[1])

[[0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]]

然后通过 tf.boolean_mask() 得到你的结果.

Then get your result by tf.boolean_mask().

result = tf.boolean_mask(tensor,index)

[[0.02033722 0.11996034 0.57675261 0.12049974 0.65760677]
 [0.01657461 0.28126666 0.64016929 0.48365073 0.26672697]
 [0.52787814 0.17085317 0.83711126 0.40567032 0.71386498]]

这篇关于沿具有给定索引的维度对张量进行切片的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-11 13:25