How to select rows from a 3-D tensor in TensorFlow?

Problem Description


I have a tensor logits with the dimensions [batch_size, num_rows, num_coordinates] (i.e. each logit in the batch is a matrix). In my case the batch size is 2, and there are 4 rows and 4 coordinates.

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0],
                      [12.0, 10.0, 10.0, 20.0],
                      [13.0, 10.0, 10.0, 20.0]],
                     [[14.0, 11.0, 21.0, 31.0],
                      [15.0, 11.0, 11.0, 21.0],
                      [16.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

I want to select the first and second row of the first batch and the second and fourth row of the second batch.

indices = tf.constant([[0, 1], [1, 3]])

So the desired output would be

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0]],
                     [[15.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

How do I do this using TensorFlow? I tried using tf.gather(logits, indices) but it did not return what I expected. Thanks!

Solution

This is possible in TensorFlow, but slightly inconvenient, because tf.gather() currently only works with one-dimensional indices, and only selects slices from the 0th dimension of a tensor. However, it is still possible to solve your problem efficiently, by transforming the arguments so that they can be passed to tf.gather():

logits = ... # [2 x 4 x 4] tensor
indices = tf.constant([[0, 1], [1, 3]])

# Use tf.shape() to make this work with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]

# Offset to add to each row in indices. We use `tf.expand_dims()` to make 
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)

# Convert indices and logits into appropriate form for `tf.gather()`. 
flattened_indices = tf.reshape(indices + offset, [-1])
flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]]))

selected_rows = tf.gather(flattened_logits, flattened_indices)

result = tf.reshape(selected_rows,
                    tf.concat(0, [tf.pack([batch_size, indices_per_batch]),
                                  tf.shape(logits)[2:]]))

Note that, since this uses tf.reshape() and not tf.transpose(), it doesn't need to modify the (potentially large) data in the logits tensor, so it should be fairly efficient.
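Two caveats on the snippet above. First, it uses the pre-1.0 TensorFlow API: `tf.pack()` was later renamed `tf.stack()`, and `tf.concat()` now takes its arguments in the order `tf.concat(values, axis)`. Second, newer TensorFlow versions support `tf.gather(logits, indices, batch_dims=1)`, which performs this batched gather directly. To make the index arithmetic itself easy to check without a TensorFlow installation, here is a pure-Python sketch of the same flatten-offset-gather trick, using nested lists in place of tensors; the variable names mirror the TensorFlow code and the data is the example from the question.

```python
logits = [
    [[10.0, 10.0, 20.0, 20.0],
     [11.0, 10.0, 10.0, 30.0],
     [12.0, 10.0, 10.0, 20.0],
     [13.0, 10.0, 10.0, 20.0]],
    [[14.0, 11.0, 21.0, 31.0],
     [15.0, 11.0, 11.0, 21.0],
     [16.0, 11.0, 11.0, 21.0],
     [17.0, 11.0, 11.0, 21.0]],
]
indices = [[0, 1], [1, 3]]

rows_per_batch = len(logits[0])
indices_per_batch = len(indices[0])

# Flatten the batch dimension away: 8 rows instead of 2 x 4.
flattened_logits = [row for batch in logits for row in batch]

# Offset each batch's indices into the flattened row space, then gather.
# Batch b's row i lives at position b * rows_per_batch + i.
flattened_indices = [b * rows_per_batch + i
                     for b, idx in enumerate(indices) for i in idx]
selected_rows = [flattened_logits[j] for j in flattened_indices]

# Reshape back to [batch_size, indices_per_batch, num_coordinates].
result = [selected_rows[k * indices_per_batch:(k + 1) * indices_per_batch]
          for k in range(len(indices))]
print(flattened_indices)  # [0, 1, 5, 7]
print(result)
```

Running this reproduces the desired output from the question: batch 0 keeps rows 0 and 1, and batch 1's indices 1 and 3 map to flattened positions 5 and 7, selecting the rows starting with 15.0 and 17.0.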

