我想在numpy中执行一些相对简单的操作:
如果行中有一个,则返回包含1+1的列的索引。
如果行中有零个或多个,则返回0。
但是我得到了一个非常复杂的代码:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

one_count = (predictions == 1).sum(1)
valid_rows_idx = np.where(one_count==1)

result = np.zeros(predictions.shape[0])
for idx in valid_rows_idx:
    result[idx] = np.where(predictions[idx,:]==1)[1] + 1

如果我打印result,程序将打印
[ 1. 0. 4. 0.]这是期望的结果。
我想知道是否有一种更简单的方法来使用numpy写最后一行。

最佳答案

我不确定这是否更好,但你可以尝试使用argmax来实现这一点。此外,不需要使用for loop和np.where来获取有效索引:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

idx = (predictions == 1).sum(1) == 1
result = np.zeros(predictions.shape[0])
result[idx] = (predictions[idx]==1).argmax(axis=1) + 1

In [55]: result
Out[55]: array([ 1.,  0.,  4.,  0.])

或者你可以使用np.whereargmax一行完成所有这些工作:
predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

In [72]: np.where((predictions==1).sum(1)==1, (predictions==1).argmax(axis=1)+1, 0)
Out[72]: array([1, 0, 4, 0])

09-25 16:15