本文介绍了多维张量的前 K 个指数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我有一个二维张量,我想获得前 k 个值的索引.我知道 pytorch 的 topk 功能.pytorch 的 topk 函数的问题在于,它计算某个维度上的 topk 值.我想获得两个维度的 topk 值.
I have a 2D tensor and I want to get the indices of the top k values. I know about pytorch's topk function. The problem with pytorch's topk function is, it computes the topk values over some dimension. I want to get topk values over both dimensions.
例如下面的张量
a = torch.tensor([[4, 9, 7, 4, 0],
[8, 1, 3, 1, 0],
[9, 8, 4, 4, 8],
[0, 9, 4, 7, 8],
[8, 8, 0, 1, 4]])
pytorch 的 topk 函数会给我以下信息.
pytorch's topk function will give me the following.
values, indices = torch.topk(a, 3)
print(indices)
# tensor([[1, 2, 0],
# [0, 2, 1],
# [0, 1, 4],
# [1, 4, 3],
# [1, 0, 4]])
但我想得到以下内容
tensor([[0, 1],
[2, 0],
[3, 1]])
这是二维张量中 9 的索引.
This is the indices of 9 in the 2D tensor.
是否有任何方法可以使用 pytorch 实现此目的?
Is there any approach to achieve this using pytorch?
推荐答案
v, i = torch.topk(a.flatten(), 3)
print (np.array(np.unravel_index(i.numpy(), a.shape)).T)
输出:
[[3 1]
[2 0]
[0 1]]
- 展平并找到顶部 k
- 使用
unravel_index
将一维索引转换为二维索引
- Flatten and find top k
- Convert 1D indices to 2D using
unravel_index
这篇关于多维张量的前 K 个指数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!