给定一个数组“array”和一组索引“index”,如何找到通过矢量化方式沿这些索引拆分数组而形成的子数组的累积和?
为了澄清,假设我有:

>>> array = np.arange(20)
>>> array
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
indices = np.arrray([3, 8, 14])

操作应输出:
array([0, 1, 3, 3, 7, 12, 18, 25, 8, 17, 27, 38, 50, 63, 14, 29, 45, 62, 80, 99])

请注意,数组非常大(100000个元素),因此,我需要一个矢量化的答案。使用任何循环都会大大降低速度。
另外,如果我有同样的问题,但是一个二维数组和相应的索引,我需要对数组中的每一行做同样的事情,我该怎么做?
对于二维版本:
>>>array = np.arange(12).reshape((3,4))
>>>array
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> indices = np.array([[2], [1, 3], [1, 2]])

结果将是:
array([[ 0,  1,  3,  3],
       [ 4,  9,  6, 13],
       [ 8, 17, 10, 11]])

澄清:每一行都将被拆分。

最佳答案

您可以在indices位置引入原始累积和数组的微分,以在这些位置创建类似边界的效果,这样当微分数组被累积和时,就会给出停止累积和输出的索引。乍一看可能会觉得有点做作,但坚持下去,尝试其他样品,希望会有意义这个想法与this other MATLAB solution.中的应用非常相似,因此,遵循这样的哲学,这里有一种方法使用numpy.diffcumulative summation。-

# Get linear indices
n = array.shape[1]
lidx = np.hstack(([id*n+np.array(item) for id,item in enumerate(indices)]))

# Get successive differentiations
diffs = array.cumsum(1).ravel()[lidx] - array.ravel()[lidx]

# Get previous group's offsetted summations for each row at all
# indices positions across the entire 2D array
_,idx = np.unique(lidx/n,return_index=True)
offsetted_diffs = np.diff(np.append(0,diffs))
offsetted_diffs[idx] = diffs[idx]

# Get a copy of input array and place previous group's offsetted summations
# at indices. Then, do cumulative sum which will create a boundary like
# effect with those offsets at indices positions.
arrayc = array.copy()
arrayc.ravel()[lidx] -= offsetted_diffs
out = arrayc.cumsum(1)

这应该是一个几乎矢量化的解决方案,几乎是因为即使我们在循环中计算线性索引,但由于它不是这里计算密集的部分,所以它对整个运行时的影响将是最小的。此外,如果您不关心破坏输入以节省内存,则可以将arrayc替换为array
样本输入、输出-
In [75]: array
Out[75]:
array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23]])

In [76]: indices
Out[76]: array([[3, 6], [4, 7], [5]], dtype=object)

In [77]: out
Out[77]:
array([[ 0,  1,  3,  3,  7, 12,  6, 13],
       [ 8, 17, 27, 38, 12, 25, 39, 15],
       [16, 33, 51, 70, 90, 21, 43, 66]])

07-24 18:05
查看更多