使用(n-1)维数组沿给定维访问n维数组的最优雅的方法是什么,如虚拟示例中所示

a = np.random.random_sample((3,4,4))
b = np.random.random_sample((3,4,4))
idx = np.argmax(a, axis=0)

现在我如何使用idx a访问以获得a中的最大值,就像使用了a.max(axis=0)一样?或者如何检索idx中的b指定的值?
我曾考虑过使用np.meshgrid但我认为这是一种过度杀伤力。注意,尺寸axis可以是任何有用的轴(0,1,2),并且事先不知道。有优雅的方法吗?

最佳答案

利用-

m,n = a.shape[1:]
I,J = np.ogrid[:m,:n]
a_max_values = a[idx, I, J]
b_max_values = b[idx, I, J]

对于一般情况:
def argmax_to_max(arr, argmax, axis):
    """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)"""
    new_shape = list(arr.shape)
    del new_shape[axis]

    grid = np.ogrid[tuple(map(slice, new_shape))]
    grid.insert(axis, argmax)

    return arr[tuple(grid)]

不幸的是,比这种自然的操作要尴尬得多。
对于使用advanced-indexing数组对n dim数组进行索引,我们可以稍微简化它,为所有轴提供索引网格,就像这样。-
def all_idx(idx, axis):
    grid = np.ogrid[tuple(map(slice, idx.shape))]
    grid.insert(axis, idx)
    return tuple(grid)

因此,使用它对输入数组进行索引-
axis = 0
a_max_values = a[all_idx(idx, axis=axis)]
b_max_values = b[all_idx(idx, axis=axis)]

关于python - 索引n维数组与(n-1)d数组,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/46103044/

10-12 21:57
查看更多