我有一个numpy数组,它有一堆单调递增的值。说,
a = [1,2,3,4,6,10,10,11,14]
a_arr=np.array(a)
也说
thresh = 4
我想创建一个数组,它包含
a_arr
的一个子集的索引,该子集遍历数组,选择元素,但忽略与上一个选择至少间隔thresh
的元素这可能更容易用一种算法来描述:def select_idx(a, thresh):
ret = []
for idx, elt in enumerate(a):
if len(ret) == 0 or elt >= a[ret[-1]] + thresh:
ret.append(idx)
return ret
很明显我可以用这个函数来做,但这看起来很慢有没有办法在纽比把这个矢量化?
谢谢。
p.s.在本例中,选择_idx(a,thresh)=[0,4,5,8]
编辑:这个问题的一个近似版本可能更容易矢量化:将数字行分成大小为
thresh
的桶,我猜从A中的第一个值开始,所以这个例子中的桶分隔符将是0, 4, 8,12, 16,…选择作为其bucket中第一个元素的数字的索引。(是的,我意识到这和我以前写的不一样。) 最佳答案
这里是一个向量化解决你的近似问题:
idx = np.cumsum(np.bincount((a-a[0])/thresh))[:-1]
这将为您提供除始终存在的第一个零以外的所有索引解释如下:
(a-a[0])/thresh
执行整数除法(假设a
具有整数类型)将值按thresh
宽分组。cumsum(bincount(...))
计算每个组的大小并将其转换为索引。注意,如果bucket中没有值,那么bincount
将报告0,因此此数组中可能有重复。最后,我们丢弃最后一个索引,它对应于
a
的大小。或者,如果索引的顺序无关紧要,则可以利用此功能恢复零索引:idx = np.cumsum(np.bincount((a-a[0])/thresh)) % len(a)