为什么python float 乘以torch.long 会得到一个torch.float 而用torch.long 为float 供电会得到一个torch.long?

>>> a = 0.9
>>> b = torch.tensor(2, dtype=torch.long)

>>> foo = a * b
>>> print(foo, foo.dtype)
tensor(1.8000) torch.float32

>>> bar = a ** b
>>> print(bar, bar.dtype)
tensor(0) torch.int64

最佳答案

这看起来像一个错误,可能是 pytorch 将 ** 绑定(bind)到 __rpow____pow__ 的方式。

例如。如果您尝试过 0.9 - torch.tensor(2) ,因为 0.9 不是张量,这将被解释为 torch.tensor(2).__rsub__(0.9) ,它可以正常工作。 ** 的行为方式相同,但 torch.tensor(2).__rpow__(0.9) 错误地返回了 dtype int64 的 tensor(0)

同时,您可以使用 torch.tensor(0.9) ** torch.tensor(2)

提交了一个错误:
https://github.com/pytorch/pytorch/issues/32436

关于python - python float 和 pytorch integer 的乘法和供电,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/59827509/

10-12 22:07