本文介绍了从一维张量中提取前 k 个值索引的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

给定 Torch 中的一维张量 (torch.Tensor),其中包含可以比较的值(比如浮点数),我们如何提取 top-k 的索引 张量中的值?

Given a 1-D tensor in Torch (torch.Tensor), containing values which can be compared (say floating point), how can we extract the indices of the top-k values in that tensor?

除了蛮力方法之外,我正在寻找 Torch/lua 提供的一些 API 调用,它可以有效地执行此任务.

Apart from the brute-force method, I am looking for some API call, that Torch/lua provides, which can perform this task efficiently.

推荐答案

截至拉取请求 #496 Torch 现在包含一个名为 torch.topk.示例:

As of pull request #496 Torch now includes a built-in API named torch.topk. Example:

> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}

-- obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
 1
 2
 3
[torch.DoubleTensor of size 3]

-- you can also get the indices in addition
> res, ind = t:topk(3)
> print(ind)
 2
 4
 6
[torch.LongTensor of size 3]

-- alternatively you can obtain the k largest elements as follow
-- (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
 9
 8
 7
[torch.DoubleTensor of size 3]

在撰写本文时,CPU 实现遵循 排序和缩小方法(有计划在未来改进它).也就是说,目前正在审查.

At the time of writing the CPU implementation follows a sort and narrow approach (there are plans to improve it in the future). That being said an optimized GPU implementation for cutorch is currently being reviewed.

这篇关于从一维张量中提取前 k 个值索引的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

09-22 12:27