我如何在numpy中应用蒙版以获取此输出?

ar2 = np.arange(1,26)[::-1].reshape([5,5]).T
ar3 = np.array([1,1,-1,-1,1])
print ar2, '\n\n',  ar3

[[25 20 15 10  5]
 [24 19 14  9  4]
 [23 18 13  8  3]
 [22 17 12  7  2]
 [21 16 11  6  1]]

[ 1  1 -1 -1  1]


-适用于ar3 = 1的地方:ar2/ar2[:,0][:, np.newaxis]

--apply ar3 = -1时:ar2/ar2[:,4][:, np.newaxis]

我追求的结果是:

[[1 0 0 0 0]
 [1 0 0 0 0]
 [ 7  6  4  2  1]
 [11  8  6  3  1]
 [1 0 0 0 0]]


我试过了np.where()

最佳答案

我不明白为什么np.where在这里不应该工作:

>>> np.where((ar3==1)[:, None],
...          ar2 // ar2[:, [0]],  # where condition is True, divide by first column
...          ar2 // ar2[:, [4]])  # where condition is False, divide by last column
array([[ 1,  0,  0,  0,  0],
       [ 1,  0,  0,  0,  0],
       [ 7,  6,  4,  2,  1],
       [11,  8,  6,  3,  1],
       [ 1,  0,  0,  0,  0]])


我正在使用Python 3,这就是为什么我使用//(底数除法)而不是常规除法(/)的原因,否则结果将包含浮点数。

这会急切地计算数组,因此会为所有值评估ar2 // ar2[:, [0]]ar2 // ar2[:, [4]]。有效地在内存中保存3个大小为ar2的数组(结果和两个临时变量)。如果要提高内存效率,则需要在执行操作之前应用遮罩:

>>> res = np.empty_like(ar2)
>>> mask = ar3 == 1
>>> res[mask] = ar2[mask] // ar2[mask][:, [0]]
>>> res[~mask] = ar2[~mask] // ar2[~mask][:, [4]]
>>> res
array([[ 1,  0,  0,  0,  0],
       [ 1,  0,  0,  0,  0],
       [ 7,  6,  4,  2,  1],
       [11,  8,  6,  3,  1],
       [ 1,  0,  0,  0,  0]])


这只会计算使用较少内存(可能也更快)的必要值。

10-07 22:23