我想知道如何用dtype=object
数组在numpy中支持矩阵乘法。我有封装在类Ciphertext
中的同态加密数字,对此我已覆盖了__add__
,__mul__
等基本数学运算符。
我创建了numpy数组,其中每个条目都是我的类Ciphertext
的实例,并且numpy了解如何广播加法和乘法运算。
encryptedInput = builder.encrypt_as_array(np.array([6,7])) # type(encryptedInput) is <class 'numpy.ndarray'>
encryptedOutput = encryptedInput + encryptedInput
builder.decrypt(encryptedOutput) # Result: np.array([12,14])
但是,numpy不允许我做矩阵乘法
out = encryptedInput @ encryptedInput # TypeError: Object arrays are not currently supported
考虑到加法和乘法的工作原理,我不太明白为什么会发生这种情况。我猜想这与numpy无法知道对象的形状有关,因为它可能是列表或某种形式。
天真的解决方案:我可以编写自己的类来扩展
ndarray
并覆盖__matmul__
操作,但是我可能会失去性能,并且这种方法需要实现广播等,因此,我基本上会重新设计轮子,以解决问题现在是这样。问题:如何在对象行为完全像数字的
dtype=objects
数组上使用numpy提供的标准矩阵乘法?先感谢您!
最佳答案
无论出于什么原因,matmul都不起作用,但是tensordot函数可以按预期工作。
encryptedInput = builder.encrypt_as_array(np.array([6,7]))
out = np.tensordot(encryptedInput, encryptedInput, axes=([1,0]))
# Correct Result: [[ 92. 105.]
# [120. 137.]]
现在调整轴只是一个麻烦。我仍然想知道这是否真的比使用for循环的幼稚实现更快。
关于python - Python中对象数组的矩阵乘法,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49386827/