pytorch的数据加载和处理教程非常具体地针对一个示例,有人可以帮我了解更通用的简单图像加载功能应该是什么样的吗?
教程:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
我的资料:
我在以下文件夹结构中将MINST数据集作为jpg包含在其中。 (我知道我可以只使用数据集类,但这纯粹是为了了解如何在不使用csv或复杂功能的情况下将简单图像加载到pytorch中)。
文件夹名称是标签,图像是灰度的28x28 png,无需任何转换。
data
train
0
3.png
5.png
13.png
23.png
...
1
3.png
10.png
11.png
...
2
4.png
13.png
...
3
8.png
...
4
...
5
...
6
...
7
...
8
...
9
...
最佳答案
这是我为pytorch 0.4.1做的(仍应在1.3中使用)
def load_dataset():
data_path = 'data/train/'
train_dataset = torchvision.datasets.ImageFolder(
root=data_path,
transform=torchvision.transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=64,
num_workers=0,
shuffle=True
)
return train_loader
for batch_idx, (data, target) in enumerate(load_dataset()):
#train network
关于python - 如何将MNIST图像加载到Pytorch DataLoader中?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/50052295/