我有一个形状为m的蒙版[bz]和一个字典d包含许多不同的ndarrays,例如d['s'].shape=(bz, 84, 84, 4)d['r'].shape=(bz, 1)等。所有尺寸都具有相同的第一尺寸bz,但其他尺寸可能有所不同。我想适当地扩展m的尺寸,以便可以将m乘以d中的值。例如,我可能希望m的形状为[bz, 1, 1, 1],以便将d['s'][bz, 1]乘以d['r']。我可以想到一个讨厌的while循环解决方案,如下所示。

for k, v in d.items():
    if m.shape != v.shape:
        reshaped_m = m.copy()
        while len(reshaped_m.shape) < len(v.shape):
            reshaped_m = reshaped_m[..., None]
        d[k] = v * reshaped_m


我正在寻找更好的解决方案,谢谢。

最佳答案

您可以简单地-

for k, v in d.items():
    d[k] = (v.T*m).T


因此,我们基本上是将第一个轴推到末端,以便v可以针对m进行广播。然后,乘以m。最后,将最后一个轴推回到前面。

可能有不同的方法来排列这些轴,但是transposing是最简单的方法。

如果遮罩的尺寸大于1,则我们也需要转置m。因此。在这种情况下,它将为d[k] = (v.T*m.T).T



另一种方法是重塑m,然后与v相乘-

for k, v in d.items():
    d[k] =  v*m.reshape(m.shape + (1,)*(v.ndim-m.ndim))




另一个是np.einsum-

for k, v in d.items():
    d[k] = np.einsum('i...,i->i...',v,m)

关于python - 扩展numpy数组的维数以满足不同要求的最佳方法,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/58687266/

10-11 19:38