1.前言
在PyTorch中,Dataset
和DataLoader
是两个重要的工具,用于构建输入数据的管道。
(1)Dataset
是一个抽象类,表示数据集,需要实现__len__
和__getitem__
方法。
(2)DataLoader
是一个可迭代的数据加载器,它封装了数据集的加载、批处理、打乱和并行加载等功能。
2.分类任务创建Dataset
和DataLoader
(1)对于分类任务,Dataset
需要返回图像和对应的标签
from torch.utils.data import Dataset
from PIL import Image
import os
import torch
class ClassificationDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.labels = [...] # 这里应该是与图像对应的标签列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
image = Image.open(img_path).convert('RGB')
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, label
(2)DataLoader
加载数据
from torch.utils.data import DataLoader
transform = ... # 这里定义你的数据预处理流程
dataset = ClassificationDataset(root_dir='path_to_your_data', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
3.检测任务创建Dataset
和DataLoader
(1)Dataset
需要返回图像和对应的边界框信息
class DetectionDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.annotations = [...] # 这里应该是与图像对应的边界框信息列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
image = Image.open(img_path).convert('RGB')
boxes = self.annotations[idx] # 这些是边界框信息
if self.transform:
image, boxes = self.transform(image, boxes)
return image, boxes
(2)DataLoader
加载数据
dataloader = DataLoader(DetectionDataset(root_dir='path_to_your_data', transform=transform), batch_size=2, shuffle=True)
4.分割任务创建Dataset
和DataLoader
(1)Dataset
需要返回图像和对应的分割掩码
class SegmentationDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.masks = [...] # 这里应该是与图像对应的分割掩码列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
mask_path = self.masks[idx]
image = Image.open(img_path).convert('RGB')
mask = Image.open(mask_path).convert('L') # 假设掩码是灰度图
if self.transform:
image, mask = self.transform(image, mask)
return image, mask
(2)DataLoader
加载数据
dataloader = DataLoader(SegmentationDataset(root_dir='path_to_your_data', transform=transform), batch_size=4, shuffle=True)