Problem description
I have two arrays, one of shape (200000, 28, 28) and the other of shape (10000, 28, 28), so in practice two arrays whose elements are matrices. Now I want to count and retrieve all the elements (as an array of shape (N, 28, 28)) that occur in both arrays. With plain for loops this is far too slow, so I tried NumPy's intersect1d method, but I don't know how to apply it to arrays like these.
Recommended answer
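Using the technique from this question about unique rows: view each matrix as a single opaque void element, so that intersect1d compares whole matrices instead of individual entries.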
import numpy as np

def intersect_along_first_axis(a, b):
    # check that casting to void will create equal-size elements
    assert a.shape[1:] == b.shape[1:]
    assert a.dtype == b.dtype
    # compute dtypes: one void item spanning a whole matrix,
    # and the original sub-array dtype used to convert back
    void_dt = np.dtype((np.void, a.dtype.itemsize * int(np.prod(a.shape[1:]))))
    orig_dt = np.dtype((a.dtype, a.shape[1:]))
    # convert to 1d void arrays (the view needs contiguous memory)
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    a_void = a.reshape(a.shape[0], -1).view(void_dt)
    b_void = b.reshape(b.shape[0], -1).view(void_dt)
    # intersect (the result is sorted and unique), then convert back
    # to an array of shape (N, *a.shape[1:])
    return np.intersect1d(b_void, a_void).view(orig_dt)
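As an alternative sketch (a different technique from the answer above, not the answerer's method), assuming NumPy 1.13 or newer where `np.unique` accepts an `axis` argument: deduplicate each array along the first axis, stack the two results, and keep the matrices that occur twice. The helper name `intersect_rows_via_unique` and the toy arrays are hypothetical; for inputs as large as (200000, 28, 28) it may be slower than the void view, so measure on your own data.

```python
import numpy as np

def intersect_rows_via_unique(a, b):
    # Hypothetical helper: alternative to the void-view approach.
    # Deduplicate each input so every matrix appears at most once per array.
    ua = np.unique(a, axis=0)
    ub = np.unique(b, axis=0)
    # A matrix present in both arrays now appears exactly twice in the stack.
    both = np.concatenate([ua, ub], axis=0)
    vals, counts = np.unique(both, axis=0, return_counts=True)
    return vals[counts > 1]

# Toy data standing in for the (200000, 28, 28) and (10000, 28, 28) arrays.
a = np.arange(24).reshape(6, 2, 2)
b = np.concatenate([a[[3, 1, 1]], np.full((1, 2, 2), -1)])
common = intersect_rows_via_unique(a, b)
print(common.shape)  # (2, 2, 2)
print(len(common))   # 2, i.e. N
```

The count N asked for in the question is simply `len(...)` (or `.shape[0]`) of the returned array, whichever intersection function is used.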
Note that the void view is unsafe for floating-point data: it compares raw bytes rather than numeric values, so -0.0 and 0.0 are treated as different elements.
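A small sketch of that pitfall and one possible workaround, assuming float64 input: -0.0 and 0.0 are equal as values but differ in their sign bit, and adding +0.0 maps -0.0 to +0.0 under IEEE 754 round-to-nearest while leaving every nonzero value unchanged. The helper `canonicalize_zeros` is hypothetical; applying it to both arrays before the void-view intersection should make the two zeros match, but it does nothing about NaN values.

```python
import numpy as np

pos = np.zeros((1, 2, 2))
neg = -pos  # same values, but every entry is -0.0

# Value comparison treats -0.0 and 0.0 as equal...
print(np.array_equal(pos, neg))  # True
# ...but the raw bytes differ (sign bit), which is what the void view compares.
print(pos.tobytes() == neg.tobytes())  # False

def canonicalize_zeros(x):
    # Hypothetical helper: -0.0 + 0.0 == +0.0 in round-to-nearest,
    # and x + 0.0 == x for any nonzero x.
    return x + 0.0

print(canonicalize_zeros(neg).tobytes() == pos.tobytes())  # True
```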