我有一个2D和1D数组。我正在寻找找到至少包含一次来自1d数组的值的两行,如下所示:
import numpy as np
A = np.array([[0, 3, 1],
[9, 4, 6],
[2, 7, 3],
[1, 8, 9],
[6, 2, 7],
[4, 8, 0]])
B = np.array([0,1,2,3])
results = []
for elem in B:
results.append(np.where(A==elem)[0])
这有效并导致以下数组:
[array([0, 5], dtype=int64),
array([0, 3], dtype=int64),
array([2, 4], dtype=int64),
array([0, 2], dtype=int64)]
但这可能不是最好的进行方法。按照此问题(Search Numpy array with multiple values)中给出的答案,我尝试了以下解决方案:
out1 = np.where(np.in1d(A, B))
num_arr = np.sort(B)
idx = np.searchsorted(B, A)
idx[idx==len(num_arr)] = 0
out2 = A[A == num_arr[idx]]
但是这些给了我不正确的值:
In [36]: out1
Out[36]: (array([ 0, 1, 2, 6, 8, 9, 13, 17], dtype=int64),)
In [37]: out2
Out[37]: array([0, 3, 1, 2, 3, 1, 2, 0])
谢谢你的帮助
最佳答案
由于您要处理2D数组*,因此可以使用广播将B
与简化版本的A
进行比较。这将使您的索引呈锯齿状。然后,您可以反转结果并使用np.unravel_index
获取原始数组中的相应索引。
In [50]: d = np.where(B[:, None] == A.ravel())[1]
In [51]: np.unravel_index(d, A.shape)
Out[51]: (array([0, 5, 0, 3, 2, 4, 0, 2]), array([0, 2, 2, 0, 0, 1, 1, 2]))
^
# expected result
*来自documentation:对于3维数组,这在代码行方面肯定是有效的,并且对于小型数据集,它在计算上也可能有效。但是,对于大型数据集,创建大型3-d阵列可能会导致性能下降。
同样,广播是一种强大的工具,可用于编写简短且通常直观的代码,从而在C语言中非常高效地执行其计算。但是,在某些情况下,对于特定的算法,广播会不必要地使用大量内存。在这些情况下,最好用Python编写算法的外部循环。这也可能产生更具可读性的代码,因为随着广播中维度的数量增加,使用广播的算法往往变得难以解释。