问题描述
我研究了张量流切片的不同方式,即tf.gather
和tf.gather_nd
.在tf.gather中,它只是在一个维度上切片,在tf.gather_nd
中,它也只接受一个indices
应用于输入张量.
I have looked at different ways of slicing in tensorflow, namely, tf.gather
and tf.gather_nd
.In tf.gather, it just slices over a dimension, and also in tf.gather_nd
it just accepts one indices
to be applied over the input tensor.
我需要的是不同的,我想使用两个不同的张量在输入张量上切片;一个切片在行上,第二切片在列上,它们不一定具有相同的形状.
What I need is different, I want to slice over the input tensor using two different tensor;one slices over the rows the second slices over the column and they are not in the same shape necessarily.
例如:
假设这是我的输入张量,我想在其中提取部分张量.
suppose this is my input tensor in which I want to extract part of it.
input_tf = tf.Variable([ [9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])
第二个是:
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
第三个张量:
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
现在,我想使用rows_tf
和columns_tf
对input_tf
进行切片.行中的索引[1 2 5]
和columns_tf
中的索引[1]
.同样,在columns_tf
中的[1 2 5]
行和[2]
行.
Now, I want to slice over input_tf
using rows_tf
and columns_tf
. index [1 2 5]
in rows and [1]
in columns_tf
. Again, rows [1 2 5]
with [2]
in columns_tf
.
或者,[1 4 6]
和[2]
.
总体而言,rows_tf
中的每个索引与columns_tf
中的相同索引都将提取input_tf
的一部分.
Overall, each index in the rows_tf
, with the same index in columns_tf
will extract part of the input_tf
.
因此,预期输出为:
[[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
例如,此处使用
rows in rows_tf [1,2,5] and column in columns_tf [1](row 1 and column 1, row 2 and column 1 and row 5 and column 1 in the input_tf)
关于张量流切片的问题很多,尽管他们使用了tf.gather
或tf.gather_nd
和tf.stack
,但它们没有给出我想要的输出.
There were a couple of questions regarding slicing in tensorflow, though they used tf.gather
or tf.gather_nd
and tf.stack
which it did not give my desired output.
无需提及,在numpy
中,我们可以通过调用input_tf[rows_tf, columns_tf]
轻松地做到这一点.
No need to mention that in numpy
we can easily do that by calling: input_tf[rows_tf, columns_tf]
.
我还查看了这个高级索引,它试图模拟numpy中可用的高级索引,但是它仍然不像numpy flexible https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb
I also, looked at this advanced indexing which tries to simulate the advanced indexing available in numpy, however it still is not like numpy flexible https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb
这是我尝试过的不正确的方法:
This is what I have tried which is not correct:
tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)
此代码的尺寸输出为(8, 1, 3, 8)
,这完全不正确.
the dimension output of this code is (8, 1, 3, 8)
which is incorrect totally.
提前谢谢!
推荐答案
想法是首先获取稀疏索引(通过将行索引和列索引连接在一起)作为列表.然后,您可以使用gather_nd
检索值.
The idea is to first get the sparse indices (by concatenating row index and column index) as a list. Then you can use gather_nd
to retrieve the values.
tf.reset_default_graph()
input_tf = tf.Variable([ [9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
rows_tf = tf.reshape(rows_tf, shape=[-1, 1])
columns_tf = tf.reshape(
tf.tile(columns_tf, multiples=[1, 3]),
shape=[-1, 1])
sparse_indices = tf.reshape(
tf.concat([rows_tf, columns_tf], axis=-1),
shape=[-1, 2])
v = tf.gather_nd(input_tf, sparse_indices)
v = tf.reshape(v, [-1, 3])
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
#print 'rows\n', sess.run(rows_tf)
#print 'columns\n', sess.run(columns_tf)
print sess.run(v)
结果将是:
[[ 8.3355999 0. 8.45768547]
[ 0. 6.10318184 8.60233688]
[ 8.8973999 7.33056402 0. ]
[ 0. 3.89140368 5.82665682]
[ 8.8973999 0. 8.28397083]
[ 6.10318184 3.06143212 5.82665682]
[ 7.33056402 0. 8.28397083]
[ 6.10318184 3.89140368 0. ]]
这篇关于在tensorflow中的张量对象上进行非连续索引切片(高级索引如numpy)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!