我有一个2D和1D数组。我正在寻找找到至少包含一次来自1d数组的值的两行,如下所示:

import numpy as np

A = np.array([[0, 3, 1],
           [9, 4, 6],
           [2, 7, 3],
           [1, 8, 9],
           [6, 2, 7],
           [4, 8, 0]])

B = np.array([0,1,2,3])

results = []

for elem in B:
    results.append(np.where(A==elem)[0])


这有效并导致以下数组:

[array([0, 5], dtype=int64),
 array([0, 3], dtype=int64),
 array([2, 4], dtype=int64),
 array([0, 2], dtype=int64)]


但这可能不是最好的进行方法。按照此问题(Search Numpy array with multiple values)中给出的答案,我尝试了以下解决方案:

out1 = np.where(np.in1d(A, B))

num_arr = np.sort(B)
idx = np.searchsorted(B, A)
idx[idx==len(num_arr)] = 0
out2 = A[A == num_arr[idx]]


但是这些给了我不正确的值:

In [36]: out1
Out[36]: (array([ 0,  1,  2,  6,  8,  9, 13, 17], dtype=int64),)

In [37]: out2
Out[37]: array([0, 3, 1, 2, 3, 1, 2, 0])


谢谢你的帮助

最佳答案

由于您要处理2D数组*,因此可以使用广播将B与简化版本的A进行比较。这将使您的索引呈锯齿状。然后,您可以反转结果并使用np.unravel_index获取原始数组中的相应索引。

In [50]: d = np.where(B[:, None] == A.ravel())[1]

In [51]: np.unravel_index(d, A.shape)
Out[51]: (array([0, 5, 0, 3, 2, 4, 0, 2]), array([0, 2, 2, 0, 0, 1, 1, 2]))
                       ^
               # expected result





*来自documentation:对于3维数组,这在代码行方面肯定是有效的,并且对于小型数据集,它在计算上也可能有效。但是,对于大型数据集,创建大型3-d阵列可能会导致性能下降。
同样,广播是一种强大的工具,可用于编写简短且通常直观的代码,从而在C语言中非常高效地执行其计算。但是,在某些情况下,对于特定的算法,广播会不必要地使用大量内存。在这些情况下,最好用Python编写算法的外部循环。这也可能产生更具可读性的代码,因为随着广播中维度的数量增加,使用广播的算法往往变得难以解释。

07-24 09:52
查看更多