问题描述
我有一个尺寸为 [batch_size, num_rows, num_coordinates]
的张量 logits
(即批次中的每个 logit 都是一个矩阵).在我的例子中,batch size 为 2,有 4 行和 4 个坐标.
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],[11.0, 10.0, 10.0, 30.0],[12.0, 10.0, 10.0, 20.0],[13.0, 10.0, 10.0, 20.0]],[[14.0, 11.0, 21.0, 31.0],[15.0, 11.0, 11.0, 21.0],[16.0, 11.0, 11.0, 21.0],[17.0, 11.0, 11.0, 21.0]]])
我要选择第一批的第一行和第二行以及第二批的第二行和第四行.
indices = tf.constant([[0, 1], [1, 3]])
所以想要的输出是
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],[11.0, 10.0, 10.0, 30.0]],[[15.0, 11.0, 11.0, 21.0],[17.0, 11.0, 11.0, 21.0]]])
如何使用 TensorFlow 执行此操作?我尝试使用 tf.gather(logits, indices)
但它没有返回我预期的结果.谢谢!
这在 TensorFlow 中是可行的,但有点不方便,因为 tf.gather()
当前仅适用于一维索引,并且仅从张量的第 0 维中选择切片.但是,仍然可以通过转换参数以便将它们传递给 tf.gather()
来有效地解决您的问题:
logits = ... # [2 x 4 x 4] 张量指数 = tf.constant([[0, 1], [1, 3]])# 使用 tf.shape() 使其与动态形状一起工作.batch_size = tf.shape(logits)[0]rows_per_batch = tf.shape(logits)[1]index_per_batch = tf.shape(indices)[1]# 偏移量以添加到索引中的每一行.我们使用 `tf.expand_dims()` 来制作# 这个广播适当.偏移量 = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)# 将索引和 logits 转换为适合 `tf.gather()` 的形式.flattened_indices = tf.reshape(indices + offset, [-1])flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]]))selected_rows = tf.gather(flattened_logits,flattened_indices)结果 = tf.reshape(selected_rows,tf.concat(0, [tf.pack([batch_size,indices_per_batch]),tf.shape(logits)[2:]]))
请注意,由于这使用 tf.reshape()
而不是 tf.transpose()
,它不需要修改logits
张量中的(可能很大)数据,所以它应该是相当高效的.>
I have a tensor logits
with the dimensions [batch_size, num_rows, num_coordinates]
(i.e. each logit in the batch is a matrix). In my case batch size is 2, there's 4 rows and 4 coordinates.
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0],
[12.0, 10.0, 10.0, 20.0],
[13.0, 10.0, 10.0, 20.0]],
[[14.0, 11.0, 21.0, 31.0],
[15.0, 11.0, 11.0, 21.0],
[16.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
I want to select the first and second row of the first batch and the second and fourth row of the second batch.
indices = tf.constant([[0, 1], [1, 3]])
So the desired output would be
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0]],
[[15.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
How do I do this using TensorFlow? I tried using tf.gather(logits, indices)
but it did not return what I expected. Thanks!
This is possible in TensorFlow, but slightly inconvenient, because tf.gather()
currently only works with one-dimensional indices, and only selects slices from the 0th dimension of a tensor. However, it is still possible to solve your problem efficiently, by transforming the arguments so that they can be passed to tf.gather()
:
logits = ... # [2 x 4 x 4] tensor
indices = tf.constant([[0, 1], [1, 3]])
# Use tf.shape() to make this work with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]
# Offset to add to each row in indices. We use `tf.expand_dims()` to make
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)
# Convert indices and logits into appropriate form for `tf.gather()`.
flattened_indices = tf.reshape(indices + offset, [-1])
flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]]))
selected_rows = tf.gather(flattened_logits, flattened_indices)
result = tf.reshape(selected_rows,
tf.concat(0, [tf.pack([batch_size, indices_per_batch]),
tf.shape(logits)[2:]]))
Note that, since this uses tf.reshape()
and not tf.transpose()
, it doesn't need to modify the (potentially large) data in the logits
tensor, so it should be fairly efficient.
这篇关于如何从 TensorFlow 的 3-D 张量中选择行?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!