问题描述
我有一个形状为(?,368,5)
的参数张量,以及一个形状为的查询张量(?, 368)
。查询张量存储用于对第一张量进行排序的索引。
I've got a params tensor with shape (?,368,5)
, as well as a query tensor with shape (?,368)
. The query tensor stores indices for sorting the first tensor.
所需的输出形状为:(?,368,5)
。由于我需要它用于神经网络中的损失函数,因此使用的操作应该保持可微。此外,在运行时,第一个轴?
的大小对应于batchsize。
The required output has shape: (?,368,5)
. Since I need it for a loss function in a neural network, the used operations should stay differentiable. Also, at runtime the size of the first axis ?
corresponds to the batchsize.
到目前为止,我试验了 tf.gather
和 tf.gather_nd
,但是
tf.gather(params,查询)
产生一个形状(?,368,368,5)
的张量。
So far I experimented with tf.gather
and tf.gather_nd
, howevertf.gather(params,query)
results in a tensor with shape (?,368,368,5)
.
查询张量是通过以下方式实现的:
The query tensor is achieved by performing:
query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
总的来说,我尝试通过第三轴上的第一个元素对params张量进行排序(对于倒角距离的种类)。最后要提到的是,我使用 Keras
框架。
Overall, I try to sort the params tensor by the first element on the third axis (for kind of a chamfer distance). At last to mention is, that I work with the Keras
framework.
推荐答案
您需要将第一个维度的索引添加到 query
,以便将其与 tf.gather_nd
一起使用。这是一种方法:
You need to add the indices of the first dimension to query
in order to use it with tf.gather_nd
. Here is a way to do it:
import tensorflow as tf
import numpy as np
np.random.seed(100)
with tf.Graph().as_default(), tf.Session() as sess:
params = tf.placeholder(tf.float32, [None, 368, 5])
query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
n = tf.shape(params)[0]
# Make tensor of indices for the first dimension
ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, params.shape[1]))
# Stack indices
idx = tf.stack([ii, query], axis=-1)
# Gather reordered tensor
result = tf.gather_nd(params, idx)
# Test
out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
# Check the order is correct
print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
# True
这篇关于TensorFlow,批量索引(第一维)和排序的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!