我有一个形状为m
的蒙版[bz]
和一个字典d
包含许多不同的ndarrays,例如d['s'].shape=(bz, 84, 84, 4)
,d['r'].shape=(bz, 1)
等。所有尺寸都具有相同的第一尺寸bz
,但其他尺寸可能有所不同。我想适当地扩展m
的尺寸,以便可以将m
乘以d
中的值。例如,我可能希望m
的形状为[bz, 1, 1, 1]
,以便将d['s']
和[bz, 1]
乘以d['r']
。我可以想到一个讨厌的while循环解决方案,如下所示。
for k, v in d.items():
if m.shape != v.shape:
reshaped_m = m.copy()
while len(reshaped_m.shape) < len(v.shape):
reshaped_m = reshaped_m[..., None]
d[k] = v * reshaped_m
我正在寻找更好的解决方案,谢谢。
最佳答案
您可以简单地-
for k, v in d.items():
d[k] = (v.T*m).T
因此,我们基本上是将第一个轴推到末端,以便
v
可以针对m
进行广播。然后,乘以m
。最后,将最后一个轴推回到前面。可能有不同的方法来排列这些轴,但是
transposing
是最简单的方法。如果遮罩的尺寸大于
1
,则我们也需要转置m
。因此。在这种情况下,它将为d[k] = (v.T*m.T).T
。另一种方法是重塑
m
,然后与v
相乘-for k, v in d.items():
d[k] = v*m.reshape(m.shape + (1,)*(v.ndim-m.ndim))
另一个是
np.einsum
-for k, v in d.items():
d[k] = np.einsum('i...,i->i...',v,m)
关于python - 扩展numpy数组的维数以满足不同要求的最佳方法,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/58687266/