This question already has answers here:
numpy matrix, setting 0 to values by sorting each row

(2 个回答)


3年前关闭。




我有一个 numpy 数据数组,我只需要保留 n 最高值,其他所有内容都为零。

我目前的解决方案:
import numpy as np
np.random.seed(30)

# keep only the n highest values
n = 3

# Simple 2x5 data field for this example, real life application will be exteremely large
data = np.random.random((2,5))
#[[ 0.64414354  0.38074849  0.66304791  0.16365073  0.96260781]
# [ 0.34666184  0.99175099  0.2350579   0.58569427  0.4066901 ]]


# find indices of the n highest values per row
idx = np.argsort(data)[:,-n:]
#[[0 2 4]
# [4 3 1]]


# put those values back in a blank array
data_ = np.zeros(data.shape) # blank slate
for i in xrange(data.shape[0]):
    data_[i,idx[i]] = data[i,idx[i]]

# Each row contains only the 3 highest values per row or the original data
#[[ 0.64414354  0.          0.66304791  0.          0.96260781]
# [ 0.          0.99175099  0.          0.58569427  0.4066901 ]]

在上面的代码中,data_ 具有 n 最高值,其他所有内容都归零。即使 data.shape[1] 小于 n ,这也能很好地工作。但唯一的问题是 for loop ,它很慢,因为我的实际用例是在非常非常大的数组上。

是否有可能摆脱 for 循环?

最佳答案

您可以对 np.argsort 的结果进行操作 - np.argsort 两次,第一个获取索引顺序,第二个获取排名 - 以矢量化方式,然后使用 np.where 或简单地乘法将其他所有内容归零:

In [116]: np.argsort(data)
Out[116]:
array([[3, 1, 0, 2, 4],
       [2, 0, 4, 3, 1]])

In [117]: np.argsort(np.argsort(data))  # these are the ranks
Out[117]:
array([[2, 1, 3, 0, 4],
       [1, 4, 0, 3, 2]])

In [118]: np.argsort(np.argsort(data)) >= data.shape[1] - 3
Out[118]:
array([[ True, False,  True, False,  True],
       [False,  True, False,  True,  True]], dtype=bool)

In [119]: data * (np.argsort(np.argsort(data)) >= data.shape[1] - 3)
Out[119]:
array([[ 0.64414354,  0.        ,  0.66304791,  0.        ,  0.96260781],
       [ 0.        ,  0.99175099,  0.        ,  0.58569427,  0.4066901 ]])

In [120]: np.where(np.argsort(np.argsort(data)) >= data.shape[1]-3, data, 0)
Out[120]:
array([[ 0.64414354,  0.        ,  0.66304791,  0.        ,  0.96260781],
       [ 0.        ,  0.99175099,  0.        ,  0.58569427,  0.4066901 ]])

关于python - 保留 numpy 数组每一行的 n 个最高值,并将其他所有内容归零,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/47106112/

10-11 22:39