我正在使用31个类(Office数据集)进行图像分类。每个类都有一个文件夹。我有一个使用PyTorch编写的python脚本,该脚本使用datasets.ImageFolder
加载数据集,并为每个图像分配标签,然后进行训练。这是我用于加载数据的代码段:
from torchvision import datasets, transforms
import torch
def load_training(root_path, dir, batch_size, kwargs):
transform = transforms.Compose(
[transforms.Resize([256, 256]),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
data = datasets.ImageFolder(root=root_path + dir, transform=transform)
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
return train_loader
该代码将占用每个文件夹,并为该文件夹中的所有图像分配相同的标签。有什么方法可以找到将哪个标签分配给哪个图像/图像文件夹?
最佳答案
ImageFolder类具有属性class_to_idx
,该属性是将类的名称映射到索引(标签)的字典。因此,您可以使用data.classes
访问类,并为每个类使用data.class_to_idx
获得标签。
供参考:https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py