Dataset used: hymenoptera_data.zip

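The top-level directory of the data contains two folders, train and val, which hold the training set and the validation set respectively. Each of them contains two subfolders, ants and bees.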
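To check that an extracted copy matches this layout, a small stdlib-only sketch like the one below can count the files in each split/class folder. The function name `count_images` and the `./data/hymenoptera_data` path are assumptions for illustration. The function counts every regular file and does not filter by image extension the way ImageFolder does.

```python
from pathlib import Path

def count_images(root):
    """Return {(split, class_name): number_of_files} for root/<split>/<class>/*."""
    counts = {}
    root = Path(root)
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            continue  # skip a missing split instead of failing
        # Every subdirectory of a split is one class
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            n = sum(1 for f in class_dir.iterdir() if f.is_file())
            counts[(split, class_dir.name)] = n
    return counts

if __name__ == "__main__":
    # Assumed location of the extracted dataset
    for (split, name), n in count_images("./data/hymenoptera_data").items():
        print(f"{split}/{name}: {n} files")
```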

The folder names ants and bees are used as the class labels.

torchvision.datasets.ImageFolder(path, transform) loads the images, applies the given transform to each one, and returns a dataset object that is a subclass of torch.utils.data.Dataset.
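A map-style PyTorch dataset such as ImageFolder exposes `__len__` and `__getitem__`, and indexing it returns an `(image, label)` pair with the transform already applied. The torch-free sketch below imitates only that contract with a hypothetical `PairDataset` class over a precomputed list of pairs. The real ImageFolder subclasses torch.utils.data.Dataset and decodes image files, so treat this purely as an illustration.

```python
class PairDataset:
    """Hypothetical minimal map-style dataset: indexing returns (sample, label)."""

    def __init__(self, samples, transform=None):
        self.samples = list(samples)  # list of (raw_sample, label) pairs
        self.transform = transform    # optional callable applied on access

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample, label = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

# Toy usage: file names stand in for images, str.upper stands in for a transform
ds = PairDataset([("ant_001.jpg", 0), ("bee_001.jpg", 1)], transform=str.upper)
print(len(ds))  # 2
print(ds[1])    # ('BEE_001.JPG', 1)
```

The transform runs lazily inside `__getitem__`, which is why random augmentations such as RandomResizedCrop produce a different result each time the same image is fetched.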

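The path directory must contain one subfolder per class, named after that class (like ants and bees above). ImageFolder generates the labels from these folder names automatically, assigning integer indices in sorted order of the names, so here ants = 0 and bees = 1.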
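The labelling rule can be sketched in plain Python: sort the subfolder names, number them from 0, and pair every file with the index of its folder. The function name `find_classes_and_samples` is hypothetical. Unlike the real ImageFolder, this sketch only looks one level deep inside each class folder and does not filter by image extension.

```python
import os

def find_classes_and_samples(root):
    """Sketch of ImageFolder's labelling rule: one class per subfolder, indices by sorted name."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []  # list of (file_path, class_index) pairs
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            path = os.path.join(class_dir, fname)
            if os.path.isfile(path):
                samples.append((path, class_to_idx[name]))
    return classes, class_to_idx, samples

if __name__ == "__main__":
    root = "./data/hymenoptera_data/train"  # assumed dataset location
    if os.path.isdir(root):
        classes, class_to_idx, samples = find_classes_and_samples(root)
        print(class_to_idx)  # ants sorts before bees, so ants gets 0 and bees gets 1
        print(len(samples))
```

With the real dataset object, the same mapping is available as `image_datasets["train"].class_to_idx`.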

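import os

import torch
import torchvision
from torchvision import transforms

input_size = 224  # side length of the square images fed to the network
batch_size = 64

data_dir = "./data/hymenoptera_data"

# ImageNet per-channel mean and standard deviation, commonly used with pretrained models
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

data_transforms = {
    # Training: random crop/rescale and horizontal flip for data augmentation
    "train": transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
    # Validation: deterministic resize and center crop, no augmentation
    "val": transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
}

# One ImageFolder per split; labels come from the ants/bees subfolder names
image_datasets = {x: torchvision.datasets.ImageFolder(os.path.join(data_dir, x), transform=data_transforms[x])
                  for x in ["train", "val"]}

# Shuffle only the training data; the validation order does not need to change
dataloader_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=(x == "train"))
                   for x in ["train", "val"]}

# Draw ONE batch per split so that the printed shape and labels belong to the same batch.
# (Calling next(iter(...)) twice creates two iterators and, with shuffling, two different batches.)
for x in ["train", "val"]:
    images, labels = next(iter(dataloader_dict[x]))
    print(images.shape)  # (batch_size, 3, input_size, input_size), i.e. [64, 3, 224, 224] for a full batch
    print(labels)        # tensor of class indices: 0 = ants, 1 = bees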
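transforms.Normalize(mean, std) maps every pixel of channel c to (value - mean[c]) / std[c]. Because ToTensor has already scaled pixel values to [0, 1], the ImageNet statistics above put normalized values roughly between -2.1 and 2.6. The plain-Python sketch below applies that formula to single RGB pixels; the function name `normalize_pixel` is hypothetical and only mirrors the per-channel arithmetic, not the tensor implementation.

```python
MEAN = [0.485, 0.456, 0.406]  # ImageNet per-channel means (R, G, B)
STD = [0.229, 0.224, 0.225]   # ImageNet per-channel standard deviations

def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (value - mean) / std channel by channel, as Normalize does for each pixel."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]

# After ToTensor, a white pixel is (1.0, 1.0, 1.0) and a black pixel is (0.0, 0.0, 0.0)
white = normalize_pixel([1.0, 1.0, 1.0])
black = normalize_pixel([0.0, 0.0, 0.0])
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```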

09-09 07:22