我正在尝试使用pytorch的Dataset和DataLoader类加载数据。我使用torch.from_numpy
将每个数组转换为割炬数据集中的张量,并且通过查看数据,每个X和y确实是张量
# At this point dataset is {'X': numpy array of arrays, 'y': numpy array of arrays }
class TorchDataset(torch.utils.data.Dataset):
def __init__(self, dataset):
self.X_train = torch.from_numpy(dataset['X'])
self.y_train = torch.from_numpy(dataset['y'])
def __len__(self):
return len(self.X_train)
def __getitem__(self, index):
return {'X': self.X_train[index], 'y': self.y_train[index]}
torch_dataset = TorchDataset(dataset)
dataloader = DataLoader(torch_dataset, batch_size=4, shuffle=True, num_workers=4)
for epoch in range(num_epochs):
for X, y in enumerate(dataloader):
features = Variable(X)
labels = Variable(y)
....
但是在
features = Variable(X)
上我得到:RuntimeError: Variable data has to be a tensor, but got int
数据集中的X和y的示例是:
In [1]: torch_dataset[1]
Out[1]:
{'X':
-2.5908 -3.1123 -2.9460 ... -3.9898 -4.0000 -3.9975
-3.0867 -2.9992 -2.5254 ... -4.0000 -4.0000 -4.0000
-2.7665 -2.5318 -2.7035 ... -4.0000 -4.0000 -4.0000
... ⋱ ...
-2.4784 -2.6061 -1.6280 ... -4.0000 -4.0000 -4.0000
-2.2046 -2.1778 -1.5626 ... -3.9597 -3.9366 -3.9497
-1.9623 -1.9468 -1.5352 ... -3.8485 -3.8474 -3.8474
[torch.DoubleTensor of size 1024x1024], 'y':
107
[torch.LongTensor of size 1]}
这就是为什么我觉得Torch认为X是一个整数很困惑。任何帮助将不胜感激-谢谢!
最佳答案
您使用enumerate
时出错,导致错误,因为enumerate
的第一个返回值是批处理索引,而不是实际数据。有两种方法可以使脚本正常工作。
第一种方式
由于您的X
和y
是不需要特殊的过程。您可以只返回X
和y
的样本。将您的__getitem__
方法更改为
def __getitem__(self, index):
return self.X_train[index], self.y_train[index]
另外,稍微改变一下训练循环:
for epoch in range(num_epochs):
for batch_id, (x, y) in enumerate(dataloader):
x = Variable(x)
y = Variable(y)
# then do whatever you want to do
第二种方式
您可以在
__getitem__
方法中返回字典,并在训练循环中提取实际数据。在这种情况下,您无需更改__getitem__
方法。只需更改您的训练循环即可:for epoch in range(num_epochs):
for batch_id, data in enumerate(dataloader):
# data will be dict
x = Variable(data['X'])
y = Variable(data['y'])
# then do whatever you want to do
关于python - PyTorch:变量数据必须是张量—数据已经是张量,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49583041/