给定一个3D数组,例如:
array = np.random.randint(1, 6, (3, 3, 3))
以及沿轴0的最大值数组:
max_array = array.max(axis=0)
是否有一种矢量化的方法可以计算数组中轴0上等于max_array中匹配索引值的元素数量?例如,如果数组在一个轴0位置包含[1、3、3],则输出为2,对于其他8个位置,依此类推,返回带有计数的数组。
最佳答案
要计算x
中等于xmax
中相应值的值数,可以使用:
(x == xmax).sum(axis=0)
请注意,由于
x
具有形状(3,3,3),而xmax
具有形状(3,3),表达式x == xmax
导致NumPy变为broadcast xmax
直到形状(3,3,3 ),新轴添加到左侧。例如,
import numpy as np
np.random.seed(2015)
x = np.random.randint(1, 6, (3,3,3))
print(x)
# [[[3 5 5]
# [3 2 1]
# [3 4 1]]
# [[1 5 4]
# [1 4 1]
# [2 3 4]]
# [[2 3 3]
# [2 1 1]
# [5 1 2]]]
xmax = x.max(axis=0)
print(xmax)
# [[3 5 5]
# [3 4 1]
# [5 4 4]]
count = (x == xmax).sum(axis=0)
print(count)
# [[1 2 1]
# [1 1 3]
# [1 1 1]]