本文介绍了沿 NumPy 数组中的轴获取 N 个最大值和索引的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我认为对于有经验的 numpy 用户来说,这是一个简单的问题.

I think this is an easy question for experienced numpy users.

我有一个分数矩阵.原始索引对应样本,列索引对应项目.例如,

I have a score matrix. The raw index corresponds to samples and column index corresponds to items. For example,

score_matrix =
  [[ 1. ,  0.3,  0.4],
   [ 0.2,  0.6,  0.8],
   [ 0.1,  0.3,  0.5]]

我想获得每个样本的前 M 个项目索引.我也想获得前 M 分.例如,

I want to get top-M indices of items for each samples. Also I want to get top-M scores. For example,

top2_ind =
  [[0, 2],
   [2, 1],
   [2, 1]]

top2_score =
  [[1. , 0.4],
   [0,8, 0.6],
   [0.5, 0.3]]

使用 numpy 执行此操作的最佳方法是什么?

What is the best way to do this using numpy?

推荐答案

我会使用 argsort():

top2_ind = score_matrix.argsort()[:,::-1][:,:2]

也就是说,生成一个包含对 score_matrix 进行排序的索引的数组:

That is, produce an array which contains the indices which would sort score_matrix:

array([[1, 2, 0],
       [0, 1, 2],
       [0, 1, 2]])

然后用::-1反转列,然后用:2取前两列:

Then reverse the columns with ::-1, then take the first two columns with :2:

array([[0, 2],
       [2, 1],
       [2, 1]])

然后类似但使用常规 np.sort() 来获取值:

Then similar but with regular np.sort() to get the values:

top2_score = np.sort(score_matrix)[:,::-1][:,:2]

遵循与上述相同的机制,为您提供:

Which following the same mechanics as above, gives you:

array([[ 1. ,  0.4],
       [ 0.8,  0.6],
       [ 0.5,  0.3]])

这篇关于沿 NumPy 数组中的轴获取 N 个最大值和索引的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-13 17:05