本文介绍了为什么我的简单 pytorch 网络在 GPU 设备上不起作用?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我根据教程构建了一个简单的网络,但出现此错误:

I built a simple network from a tutorial and I got this error:

RuntimeError: 类型为 torch.cuda.FloatTensor 的预期对象,但已找到输入 torch.FloatTensor 作为参数 #4 'mat1'

有什么帮助吗?谢谢!

import torch
import torchvision

device = torch.device("cuda:0")
root = '.data/'

dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.out = torch.nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

net = Net()
net.to(device)

for i, (inputs, labels) in enumerate(dataloader):
    inputs.to(device)
    out = net(inputs)

推荐答案

TL;DR
这是修复

inputs = inputs.to(device)

为什么?!
torch.nn.Module 之间略有不同.to()torch.Tensor.to():Module.to() 是一个就地操作符,Tensor.to() 不是.因此

Why?!
There is a slight difference between torch.nn.Module.to() and torch.Tensor.to(): while Module.to() is an in-place operator, Tensor.to() is not. Therefore

net.to(device)

更改net 本身并将其移动到device.另一方面

Changes net itself and moves it to device. On the other hand

inputs.to(device)

不会更改inputs,而是返回驻留在device 上的inputs副本.要使用该在设备上"的副本,您需要将其分配给一个变量,因此

does not change inputs, but rather returns a copy of inputs that resides on device. To use that "on device" copy, you need to assign it into a variable, hence

inputs = inputs.to(device)

这篇关于为什么我的简单 pytorch 网络在 GPU 设备上不起作用?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-05 17:47