假设我有一个 N 维 numpy 数组 x 和一个 (N-1) 维索引数组 m (例如, m = x.argmax(axis=-1) )。我想构造 (N-1) 维数组 y 使得 y[i_1, ..., i_N-1] = x[i_1, ..., i_N-1, m[i_1, ..., i_N-1]] (对于上面的 argmax 示例,它相当于 y = x.max(axis=-1) )。
对于 N=3,我可以通过以下方式实现我想要的

y = x[np.arange(x.shape[0])[:, np.newaxis], np.arange(x.shape[1]), m]

问题是,我如何为任意 N 执行此操作?

最佳答案

这是使用 reshaping linear indexing 处理任意维度的多维数组的一种方法 -

shp = x.shape[:-1]
n_ele = np.prod(shp)
y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)

让我们以 ndarray6 dimensions 为例,假设我们使用 m = x.argmax(axis=-1) 索引到最后一个维度。因此,输出将是 x.max(-1) 。让我们为建议的解决方案验证这一点 -
In [121]: x = np.random.randint(0,9,(4,5,3,3,2,4))

In [122]: m = x.argmax(axis=-1)

In [123]: shp = x.shape[:-1]
     ...: n_ele = np.prod(shp)
     ...: y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
     ...:

In [124]: np.allclose(x.max(-1),y_out)
Out[124]: True

我喜欢 @B. M.'s solution 的优雅。所以,这是一个运行时测试来对这两个进行基准测试 -
def reshape_based(x,m):
    shp = x.shape[:-1]
    n_ele = np.prod(shp)
    return x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)

def indices_based(x,m):  ## @B. M.'s solution
    firstdims=np.indices(x.shape[:-1])
    ind=tuple(firstdims)+(m,)
    return x[ind]

时间——
In [152]: x = np.random.randint(0,9,(4,5,3,3,4,3,6,2,4,2,5))
     ...: m = x.argmax(axis=-1)
     ...:

In [153]: %timeit indices_based(x,m)
10 loops, best of 3: 30.2 ms per loop

In [154]: %timeit reshape_based(x,m)
100 loops, best of 3: 5.14 ms per loop

关于python - numpy 中的索引(与 max/argmax 相关),我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/36315762/

10-12 22:03