我有一个numpy数组,它有一堆单调递增的值。说,

a = [1,2,3,4,6,10,10,11,14]
a_arr=np.array(a)

也说
thresh = 4

我想创建一个数组,它包含a_arr的一个子集的索引,该子集遍历数组,选择元素,但忽略与上一个选择至少间隔thresh的元素这可能更容易用一种算法来描述:
def select_idx(a, thresh):
    ret = []
    for idx, elt in enumerate(a):
        if len(ret) == 0 or elt >= a[ret[-1]] + thresh:
            ret.append(idx)
    return ret

很明显我可以用这个函数来做,但这看起来很慢有没有办法在纽比把这个矢量化?
谢谢。
p.s.在本例中,选择_idx(a,thresh)=[0,4,5,8]
编辑:这个问题的一个近似版本可能更容易矢量化:将数字行分成大小为thresh的桶,我猜从A中的第一个值开始,所以这个例子中的桶分隔符将是0, 4, 8,12, 16,…选择作为其bucket中第一个元素的数字的索引。(是的,我意识到这和我以前写的不一样。)

最佳答案

这里是一个向量化解决你的近似问题:

idx = np.cumsum(np.bincount((a-a[0])/thresh))[:-1]

这将为您提供除始终存在的第一个零以外的所有索引解释如下:
(a-a[0])/thresh执行整数除法(假设a具有整数类型)将值按thresh宽分组。
cumsum(bincount(...))计算每个组的大小并将其转换为索引。注意,如果bucket中没有值,那么bincount将报告0,因此此数组中可能有重复。
最后,我们丢弃最后一个索引,它对应于a的大小。或者,如果索引的顺序无关紧要,则可以利用此功能恢复零索引:
idx = np.cumsum(np.bincount((a-a[0])/thresh)) % len(a)

09-25 16:16