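I have a fairly simple block of code whose performance I would like to improve. It consists of a for block that uses np.where() to find the indexes of integers in an array.

The code below works, but I feel that using a for loop to add elements to an empty list is not the best way to approach this.

The block is used by an MCMC, so it is executed millions of times. Small improvements add up to large ones. Can this be made more efficient?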

import numpy as np

N = 20
# Integers from 1 to N - 1 (np.random.randint excludes the upper bound)
ran_indexes = np.random.randint(1, N, 1000)
# Number of integers to remove
rm_number = np.random.randint(0, 100, N)

# Better performance for this block?
# For each integer from 1 to N, keep only 'd' indexes of 'ran_indexes' that
# contain that integer, where 'd' is the ith element in 'rm_number'
new_indexes = []
for i, d in enumerate(rm_number):
    new_indexes += list(np.where(ran_indexes == i + 1)[0][:d])

Best answer

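List concatenation with += is slow-ish here: on every iteration the np.where() result is first converted to a Python list, and then each of its elements is copied into new_indexes. When building up a collection iteratively, it is more common to use list append, which works in place and adds just one item (here, one array) to the list each time.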

In [45]:
    ...: new_indexes = []
    ...: for i, d in enumerate(rm_number):
    ...:     new_indexes.append(np.where(ran_indexes == i + 1)[0][:d])
    ...:
In [46]: new_indexes
Out[46]:
[array([  5,  96, 143, 150, 154, 175]),
 array([ 14,  22,  26,  28,  32,  38,  46,  54,  70, 205, 218, 242, 248,
        254, 271, 318, 344, 352, 357, 393, 419, 437, 448, 472, 473, 503,
        521, 548, 558, 629, 631, 654, 661, 685, 699, 743, 755]),
 array([ 24,  34,  72,  97, 120, 140, 173, 181, 193, 199, 200, 225, 239,
        251, 265, 296, 350, 386, 411, 422, 465, 476, 506, 533, 609, 628,
        680, 694, 713, 759]),
 ....

With this construction, each array (each where result) has a different length, with the upper bound coming from rm_number:
In [89]: [len(i) for i in new_indexes]-rm_number
Out[89]:
array([  0,   0,   0,   0,   0,   0,   0,  -2, -24, -40,   0,  -3, -40,
         0, -15,  -5,   0,   0,   0, -96])

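Variable-length arrays/lists like these are a good indicator that you cannot use super-fast "vectorized" (whole-array) operations, at least not without significant cleverness.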
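For illustration, one possible "clever" route is sketched below. It is not part of the original answer and has not been timed here. It assumes every value in ran_indexes lies between 1 and len(rm_number); the function name keep_first_d is hypothetical. The idea is to group positions by value with a stable sort, compute each position's rank within its group, and keep the ranks below the per-value limit.

```python
import numpy as np


def keep_first_d(ran_indexes, rm_number):
    """Return, grouped by value 1..n, the first rm_number[v-1] positions of v.

    Assumes all values of ran_indexes are integers in 1..len(rm_number).
    """
    n = len(rm_number)
    # A stable sort groups positions by value and keeps them in ascending
    # order inside each group, like np.where() does for a single value.
    order = np.argsort(ran_indexes, kind="stable")
    sorted_vals = ran_indexes[order]
    # Number of occurrences of each value 1..n.
    counts = np.bincount(sorted_vals, minlength=n + 1)[1:n + 1]
    # Offset in `order` at which each value's group starts.
    starts = np.cumsum(counts) - counts
    # 0-based rank of every element inside its own group.
    rank = np.arange(len(order)) - np.repeat(starts, counts)
    # Keep an element only if its rank is below the limit for its value.
    keep = rank < rm_number[sorted_vals - 1]
    return order[keep]


# Same setup as in the question; the seed is arbitrary.
np.random.seed(0)
N = 20
ran_indexes = np.random.randint(1, N, 1000)
rm_number = np.random.randint(0, 100, N)
flat_indexes = keep_first_d(ran_indexes, rm_number)
```

The result is a flat array in the same order as the concatenated loop output. Whether it is actually faster for arrays of this size would have to be measured.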

We can get the flat list that your code produces:
In [50]: np.concatenate(new_indexes).shape
Out[50]: (626,)

Some timings:
In [53]: %%timeit
    ...: new_indexes = []
    ...: for i, d in enumerate(rm_number):
    ...:     new_indexes += list(np.where(ran_indexes == i + 1)[0][:d])
    ...:

320 µs ± 7.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [54]:
In [54]: %%timeit
    ...: new_indexes = []
    ...: for i, d in enumerate(rm_number):
    ...:     new_indexes.append(np.where(ran_indexes == i + 1)[0][:d])
    ...:

184 µs ± 268 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [55]:
In [55]: %%timeit
    ...: new_indexes = []
    ...: for i, d in enumerate(rm_number):
    ...:     new_indexes.append(np.where(ran_indexes == i + 1)[0][:d])
    ...: new_indexes=np.concatenate(new_indexes)
    ...:
    ...:
193 µs ± 622 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [79]: timeit f2()  # Lukas
291 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
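To reproduce this kind of comparison outside IPython, a minimal script using the standard timeit module might look like the sketch below. The function names and the seed are assumptions made for the example. Absolute numbers depend on the machine and the NumPy version, so none are quoted.

```python
import timeit

import numpy as np

# Same setup as in the question; the seed is arbitrary.
np.random.seed(0)
N = 20
ran_indexes = np.random.randint(1, N, 1000)
rm_number = np.random.randint(0, 100, N)


def with_list_concat():
    # Original approach: convert each result to a list and concatenate.
    new_indexes = []
    for i, d in enumerate(rm_number):
        new_indexes += list(np.where(ran_indexes == i + 1)[0][:d])
    return new_indexes


def with_append_concatenate():
    # Collect the arrays with append, then flatten once at the end.
    new_indexes = []
    for i, d in enumerate(rm_number):
        new_indexes.append(np.where(ran_indexes == i + 1)[0][:d])
    return np.concatenate(new_indexes)


# Both variants should select exactly the same indexes.
same = np.array_equal(np.array(with_list_concat()), with_append_concatenate())

for fn in (with_list_concat, with_append_concatenate):
    calls = 200
    # Best of three repeats, reported per call.
    best = min(timeit.repeat(fn, number=calls, repeat=3)) / calls
    print(f"{fn.__name__}: {best * 1e6:.1f} us per call")
```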

===
temp = ran_indexes[:,None]==np.arange(1,21)

This finds all matches; np.where(temp)[0] gives the indices. But it does not apply your rm_number bounds.
np.where(temp.T)[1]    # without the `rm_number` truncation

np.where(temp[:,i])[0][:d]
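As a possible way to apply the rm_number bounds to the broadcast comparison without a Python loop, a running count down each column can be used. This is a sketch that goes beyond the original answer and has not been benchmarked; the names running and mask are introduced for the example only.

```python
import numpy as np

# Same setup as in the question; the seed is arbitrary.
np.random.seed(0)
N = 20
ran_indexes = np.random.randint(1, N, 1000)
rm_number = np.random.randint(0, 100, N)

# temp[r, c] is True where ran_indexes[r] == c + 1; shape (1000, N).
temp = ran_indexes[:, None] == np.arange(1, N + 1)

# Running number of matches seen so far in each column. At a matching
# row this is the 1-based rank of that match for the column's value.
running = np.cumsum(temp, axis=0)

# Keep a match only while its rank does not exceed that column's limit.
mask = temp & (running <= rm_number)

# Transpose so the indices come out grouped by value, as in the loop.
flat_indexes = np.where(mask.T)[1]

# Reference result from the loop version, for comparison.
loop_result = np.concatenate(
    [np.where(temp[:, i])[0][:d] for i, d in enumerate(rm_number)]
)
```

The intermediate arrays have one column per integer, so memory use grows with both the length of ran_indexes and N.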

Regarding python - improving the performance of an index search and removal, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/59176659/
