np.compress
在内部做什么,使其比布尔索引更快?
在此示例中,compress
快20%,但是节省的时间取决于a
的大小和布尔数组True
中b
值的数量,但在我的机器上compress
始终是快点。
import numpy as np
a = np.random.rand(1000000,4)
b = (a[:,0]>0.5)
%timeit a[b]
#>>> 10 loops, best of 3: 24.7 ms per loop
%timeit a.compress(b, axis=0)
#>>> 10 loops, best of 3: 20 ms per loop
documentation for boolean indexing说
返回的是数据的副本,而不是带有切片的视图
相反,compress docs说
沿给定轴返回数组的选定切片”。
但是,使用method provided here确定两个数组是否共享相同的数据缓冲区表明,这两个方法都没有与其父级
a
共享数据,这意味着两个方法都不返回实际的切片。def get_data_base(arr):
base = arr
while isinstance(base.base, np.ndarray):
base = base.base
return base
def arrays_share_data(x, y):
return get_data_base(x) is get_data_base(y)
arrays_share_data(a, a.compress(b, axis=0))
#>>> False
arrays_share_data(a, a[b])
#>>> False
我只是好奇,因为我在工作中经常执行这些操作。我运行通过Anaconda安装的python 3.5.2,numpy v 1.11.1。
最佳答案
通过对a.compress
numpy
进行的多层函数调用来跟踪github
/numpy/core/src/multiarray/item_selection.c
PyArray_Compress(PyArrayObject *self, PyObject *condition, int axis,
PyArrayObject *out)
# various checks
res = PyArray_Nonzero(cond);
ret = PyArray_TakeFrom(self, PyTuple_GET_ITEM(res, 0), axis,
out, NPY_RAISE);
对于您的样本数组,
compress
与执行where
以获得索引数组,然后执行take
相同:In [135]: a.shape
Out[135]: (1000000, 4)
In [136]: b.shape
Out[136]: (1000000,)
In [137]: a.compress(b, axis=0).shape
Out[137]: (499780, 4)
In [138]: a.take(np.nonzero(b)[0], axis=0).shape
Out[138]: (499780, 4)
In [139]: timeit a.compress(b, axis=0).shape
100 loops, best of 3: 14.3 ms per loop
In [140]: timeit a.take(np.nonzero(b)[0], axis=0).shape
100 loops, best of 3: 14.3 ms per loop
实际上,如果我在[]索引中使用此索引数组,我将获得可比的时间:
In [141]: idx=np.where(b)[0]
In [142]: idx.shape
Out[142]: (499780,)
In [143]: timeit a[idx,:].shape
100 loops, best of 3: 14.6 ms per loop
In [144]: timeit np.take(a,idx, axis=0).shape
100 loops, best of 3: 9.9 ms per loop
np.take
代码更复杂,因为它包含clip
和wrap
模式。[]索引通过各个层转换为
__getitem__
调用。我没有追溯到代码的差异很大,但我认为可以肯定地说compress
(或更确切地说是take
)只是采用了一条更直接的方法来执行任务,因此速度有所提高。 30-50%的速度差异表明编译后的代码细节上的差异,而不是像views
vs copies
或解释vs编译这样的专业。关于python - 为什么np.compress比 bool 索引更快?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/44487889/