我目前有以下情况,我想使用DataLoader批处理一个numpy数组:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])
print(x.shape)
# >> (100,10)

# Create DataLoader
input_as_tensor = torch.from_numpy(x).float()
dataset = data_utils.TensorDataset(input_as_tensor)
dataloader = data_utils.DataLoader(dataset,
                                   batch_size=100,
                                  )
batch = next(iter(dataloader))

print(type(batch))
# >> <class 'list'>

print(len(batch))
# >> 1

print(type(batch[0]))
# >> class 'torch.Tensor'>


我希望batch已经是torch.Tensor。到目前为止,我已经像这样用batch[0]为该批处理编制了索引以获取张量,但是我觉得这并不是很漂亮,并且使代码更难阅读。

我发现DataLoader具有称为collate_fn的批处理功能。但是,设置data_utils.DataLoader(..., collage_fn=lambda batch: batch[0])仅将列表更改为元组(tensor([ 0.8454, ..., -0.5863]),),其中唯一的条目是作为张量的批处理。

您将通过帮助我找到如何将批处理优雅地转换为张量的方式,对我有很大帮助(即使这包括告诉我批量索引单个条目也是可以的)。

最佳答案

不便之处,敬请原谅。

实际上,您不必从张量创建Dataset,您可以在实现torch.Tensor__getitem__的情况下直接传递__len__,因此这就足够了:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])

# Create DataLoader
dataset = torch.from_numpy(x).float()
dataloader = data_utils.DataLoader(dataset, batch_size=100)
batch = next(iter(dataloader))

关于python - PyTorch DataLoader以列表的形式返回该批次,并且该批次是唯一条目。如何从我的DataLoader获取张量的最佳方法,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/58612401/

10-10 17:15