我正在学习一门使用不推荐使用的PyTorch版本的类(class),该版本不会根据需要将torch.int64更改为torch.LongTensor。引发错误的当前代码部分是:
loss = loss_fn(Ypred, Ytrain_) # calc loss on the prediction
我相信dtype应该在此部分中进行更改:
Ytrain_ = torch.from_numpy(y_train.values).view(1, -1)[0]

当使用Ytrain_.dtype测试数据类型时,它将返回torch.int64。我试图通过应用long()函数将其转换为:Ytrain_ = Ytrain_.long()无济于事。

我也尝试过在documentation中寻找它,但似乎它说torch.int64torch.long,我认为这意味着torch.int64应该可以工作。

RuntimeError                              Traceback (most recent call last)
----> 9     loss = loss_fn(Ypred, Ytrain_) # calc loss on the prediction
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'

最佳答案

user8426627所述,您想更改张量类型,而不是数据类型。因此,解决方案是添加.type(torch.LongTensor)以将其转换为LongTensor

最终代码:
Ytrain_ = torch.from_numpy(Y_train.values).view(1, -1)[0].type(torch.LongTensor)
测试张量类型:
Ytrain_.type()'torch.LongTensor'

关于python - 怎么把tort int64转换成LongTensor?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/56510189/

10-13 05:44