本文介绍了二维数组中的前N个值,其中包含要遮罩的重复项的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我有2d numpy数组:
I have 2d numpy array:
arr = np.array([[0.1, 0.1, 0.3, 0.4, 0.5],
[0.06, 0.1, 0.1, 0.1, 0.01],
[0.24, 0.24, 0.24, 0.24, 0.24],
[0.2, 0.25, 0.3, 0.12, 0.02]])
print (arr)
[[0.1 0.1 0.3 0.4 0.5 ]
[0.06 0.1 0.1 0.1 0.01]
[0.24 0.24 0.24 0.24 0.24]
[0.2 0.25 0.3 0.12 0.02]]
我要过滤前N个值,所以我使用argsort
:
I want filter top N values, so I use argsort
:
N = 2
arr1 = np.argsort(-arr, kind='mergesort') < N
print (arr1)
[[False False False True True]
[ True False False True False] <- first top 2 are duplicates
[ True True False False False]
[False True True False False]]
它工作得很好,至少不是顶部重复项,例如第2行.
It working nice, at least not top duplicates, like for row 2.
预期输出:
print (arr1)
[[False False False True True]
[False True True False False]
[ True True False False False]
[False True True False False]]
是否可能有更快的处理方式?
Is possible some faster way for handle it?
推荐答案
切片以获取前N个索引,并使用它们创建最终的掩码-
Slice to get those top N indices and use those to create the final mask -
idx = np.argsort(-arr, kind='mergesort')[:,:N]
mask = np.zeros(arr.shape, dtype=bool)
np.put_along_axis(mask, idx, True, axis=-1)
这篇关于二维数组中的前N个值,其中包含要遮罩的重复项的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!