我有一个张量 xx.shape=(batch_size,10) ,现在我想要

 x[i][0] = x[i][0]*x[i][1]*...*x[i][9] for i in range(batch_size)

这是我的代码:

for i in range(batch_size):
    for k in range(1, 10):
        x[i][0] = x[i][0] * x[i][k]

但是当我在 forward() 中实现并调用 loss.backward() 时,反向传播的速度非常慢。为什么它很慢,有什么方法可以有效地实现它?

最佳答案

它很慢,因为您使用了两个 for 循环。

您可以使用 .prod 参见:https://pytorch.org/docs/stable/torch.html#torch.prod

在你的情况下,
x = torch.prod(x, dim=1)x = x.prod(dim=1)
应该管用

关于python - 如何有效地计算 Pytorch 中的张量?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53699675/

10-12 18:12
查看更多