我正在尝试使用pytorch的Dataset和DataLoader类加载数据。我使用torch.from_numpy将每个数组转换为割炬数据集中的张量,并且通过查看数据,每个X和y确实是张量

# At this point dataset is {'X': numpy array of arrays, 'y': numpy array of arrays  }

class TorchDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.X_train = torch.from_numpy(dataset['X'])
        self.y_train = torch.from_numpy(dataset['y'])

    def __len__(self):
        return len(self.X_train)

    def __getitem__(self, index):
        return {'X': self.X_train[index], 'y': self.y_train[index]}

torch_dataset = TorchDataset(dataset)
dataloader = DataLoader(torch_dataset, batch_size=4, shuffle=True, num_workers=4)


for epoch in range(num_epochs):
    for X, y in enumerate(dataloader):
        features = Variable(X)
        labels = Variable(y)
        ....


但是在features = Variable(X)上我得到:

RuntimeError: Variable data has to be a tensor, but got int


数据集中的X和y的示例是:

In [1]: torch_dataset[1]
Out[1]:
{'X':
 -2.5908 -3.1123 -2.9460  ...  -3.9898 -4.0000 -3.9975
 -3.0867 -2.9992 -2.5254  ...  -4.0000 -4.0000 -4.0000
 -2.7665 -2.5318 -2.7035  ...  -4.0000 -4.0000 -4.0000
       ...             ⋱             ...
 -2.4784 -2.6061 -1.6280  ...  -4.0000 -4.0000 -4.0000
 -2.2046 -2.1778 -1.5626  ...  -3.9597 -3.9366 -3.9497
 -1.9623 -1.9468 -1.5352  ...  -3.8485 -3.8474 -3.8474
 [torch.DoubleTensor of size 1024x1024], 'y':
  107
 [torch.LongTensor of size 1]}


这就是为什么我觉得Torch认为X是一个整数很困惑。任何帮助将不胜感激-谢谢!

最佳答案

您使用enumerate时出错,导致错误,因为enumerate的第一个返回值是批处理索引,而不是实际数据。有两种方法可以使脚本正常工作。

第一种方式

由于您的Xy是不需要特殊的过程。您可以只返回Xy的样本。将您的__getitem__方法更改为

 def __getitem__(self, index):
        return self.X_train[index], self.y_train[index]


另外,稍微改变一下训练循环:

for epoch in range(num_epochs):
    for batch_id, (x, y) in enumerate(dataloader):
           x = Variable(x)
           y = Variable(y)
           # then do whatever you want to do


第二种方式

您可以在__getitem__方法中返回字典,并在训练循环中提取实际数据。在这种情况下,您无需更改__getitem__方法。只需更改您的训练循环即可:

for epoch in range(num_epochs):
    for batch_id, data in enumerate(dataloader):
        # data will be dict
        x = Variable(data['X'])
        y = Variable(data['y'])
        # then do whatever you want to do

关于python - PyTorch:变量数据必须是张量—数据已经是张量,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49583041/

10-09 07:06
查看更多