我正在学习一门使用不推荐使用的PyTorch版本的类(class),该版本不会根据需要将torch.int64
更改为torch.LongTensor
。引发错误的当前代码部分是:loss = loss_fn(Ypred, Ytrain_) # calc loss on the prediction
我相信dtype应该在此部分中进行更改:Ytrain_ = torch.from_numpy(y_train.values).view(1, -1)[0]
。
当使用Ytrain_.dtype
测试数据类型时,它将返回torch.int64
。我试图通过应用long()
函数将其转换为:Ytrain_ = Ytrain_.long()
无济于事。
我也尝试过在documentation中寻找它,但似乎它说torch.int64
或torch.long
,我认为这意味着torch.int64
应该可以工作。
RuntimeError Traceback (most recent call last)
----> 9 loss = loss_fn(Ypred, Ytrain_) # calc loss on the prediction
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'
最佳答案
如user8426627
所述,您想更改张量类型,而不是数据类型。因此,解决方案是添加.type(torch.LongTensor)
以将其转换为LongTensor
。
最终代码:Ytrain_ = torch.from_numpy(Y_train.values).view(1, -1)[0].type(torch.LongTensor)
测试张量类型:Ytrain_.type()
'torch.LongTensor'
关于python - 怎么把tort int64转换成LongTensor?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/56510189/