我想知道如何用dtype=object数组在numpy中支持矩阵乘法。我有封装在类Ciphertext中的同态加密数字,对此我已覆盖了__add____mul__等基本数学运算符。

我创建了numpy数组,其中每个条目都是我的类Ciphertext的实例,并且numpy了解如何广播加法和乘法运算。

    encryptedInput = builder.encrypt_as_array(np.array([6,7])) # type(encryptedInput) is <class 'numpy.ndarray'>
    encryptedOutput = encryptedInput + encryptedInput
    builder.decrypt(encryptedOutput)                           # Result: np.array([12,14])

但是,numpy不允许我做矩阵乘法
out = encryptedInput @ encryptedInput # TypeError: Object arrays are not currently supported

考虑到加法和乘法的工作原理,我不太明白为什么会发生这种情况。我猜想这与numpy无法知道对象的形状有关,因为它可能是列表或某种形式。

天真的解决方案:我可以编写自己的类来扩展ndarray并覆盖__matmul__操作,但是我可能会失去性能,并且这种方法需要实现广播等,因此,我基本上会重新设计轮子,以解决问题现在是这样。

问题:如何在对象行为完全像数字的dtype=objects数组上使用numpy提供的标准矩阵乘法?

先感谢您!

最佳答案

无论出于什么原因,matmul都不起作用,但是tensordot函数可以按预期工作。

encryptedInput = builder.encrypt_as_array(np.array([6,7]))
out = np.tensordot(encryptedInput, encryptedInput, axes=([1,0]))
    # Correct Result: [[ 92. 105.]
    #                  [120. 137.]]

现在调整轴只是一个麻烦。我仍然想知道这是否真的比使用for循环的幼稚实现更快。

关于python - Python中对象数组的矩阵乘法,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49386827/

10-11 20:32
查看更多