本文介绍了PyTorch:RuntimeError:函数MulBackward0在索引0处返回了无效的渐变-预期类型为torch.cuda.FloatTensor,但得到了torch.FloatTensor的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我不明白此错误告诉我什么.在不同的帖子中,同样的问题也得到解决,但没有有用的解决方案为此.
I don't understand what this error is telling me. In a different post the same problem was also addressed but there was no useful solution for this.
Traceback (most recent call last):
File "train.py", line 252, in <module>
main()
File "train.py", line 231, in main
train(net, training_dataset, targets, device, criterion, optimizer, epoch, args.epochs)
File "train.py", line 103, in train
loss.backward()
File "/home/hb119056/.local/lib/python3.6/site-packages/torch/tensor.py", line 107, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/hb119056/.local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Function MulBackward0 returned an invalid gradient at index 0 - expected type torch.cuda.FloatTensor but got torch.FloatTensor
这是我代码中的对应段.
This is the corresponding segment from my code.
outputs = net(x, indices)
outputs = outputs.transpose(0, 1)
prob = F.normalize(outputs, p=1, dim=1).detach()
target = torch.from_numpy(np.load(file_dir + '/points/points{:03}.npy'.format(i))).to(device)
rv = torch.zeros(12 * outputs.shape[0])
for j in [x for x in range(10) if x != i]:
source = torch.from_numpy(np.load(file_dir + '/points/points{:03}.npy'.format(j))).to(device)
rv = factor.ransac(source, target, prob, n_iter, tol, device) # self-written
predicted = factor.predict(source, rv, outputs, device) # self-written
loss = criterion(predicted, target.type(torch.FloatTensor).to(device))
loss.backward() # error occurs here
optimizer.step()
非常感谢您的帮助,在此先感谢您!
Any help is greatly appreciated, thank you in advance!
推荐答案
更改此行:
loss = criterion(predicted, target.type(torch.FloatTensor).to(device))
到
predicted = predicted.to(device)
target=target.type(predicted.type()).to(predicted.device)
loss = criterion(predicted, target)
这篇关于PyTorch:RuntimeError:函数MulBackward0在索引0处返回了无效的渐变-预期类型为torch.cuda.FloatTensor,但得到了torch.FloatTensor的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!