关注B站查看更多手把手教学:
基本用法
torch.utils.data.Dataset
是 PyTorch 中一个非常重要的抽象类,它用于表示数据集,方便数据加载和预处理。通过实现这个类的两个方法 __len__
和 __getitem__
,你可以自定义自己的数据集类。__len__
方法应返回数据集的大小(即样本数),而 __getitem__
方法则根据给定的索引返回一个样本。
以下是一个简单的示例,说明如何使用 torch.utils.data.Dataset
创建一个自定义的数据集类:
import torch
from torch.utils.data import Dataset
class MyCustomDataset(Dataset):
def __init__(self, data, targets):
"""
参数:
data: 样本数据, 形状为 [num_samples, ...] (例如 [num_samples, num_channels, height, width])
targets: 样本标签, 形状为 [num_samples, ...] (例如 [num_samples])
"""
self.data = data
self.targets = targets
def __len__(self):
# 返回数据集的样本数
return len(self.data)
def __getitem__(self, idx):
# 根据索引 idx 返回一个样本 (数据和标签)
return self.data[idx], self.targets[idx]
# 示例数据和标签
X = torch.randn(100, 3, 32, 32) # 假设有 100 个 3x32x32 的样本
y = torch.randint(0, 10, (100,)) # 假设有 100 个对应的标签 (0-9)
# 创建数据集实例
dataset = MyCustomDataset(X, y)
# 可以使用 len() 获取数据集大小
print(len(dataset)) # 输出: 100
# 可以使用索引获取样本
sample, label = dataset[0] # 获取第一个样本和标签
print(sample.shape) # 输出: torch.Size([3, 32, 32])
print(label) # 输出: 一个整数 (0-9)
在上面的示例中,我们创建了一个名为 MyCustomDataset
的自定义数据集类,该类继承自 torch.utils.data.Dataset
。在类的构造函数中,我们接收样本数据和标签,并将它们存储在类的实例变量中。我们还实现了 __len__
和 __getitem__
方法,分别用于返回数据集的大小和根据索引获取样本。最后,我们创建了一个数据集实例,并展示了如何使用它来获取数据集的大小和样本。
标准数据集
在PyTorch的torchvision.datasets
模块中,包含了多个标准的数据集,这些数据集在计算机视觉领域非常流行。以下是一些常用的标准数据集:
- MNIST:手写数字识别数据集,包含了大量的手写数字图片和对应的标签。
- CIFAR:包含CIFAR-10和CIFAR-100两个数据集,分别用于10类和100类的小图片分类任务。
- ImageNet:一个大规模的图片分类数据集,包含了上千万张标注过的图片,通常用于训练深度神经网络。在
torchvision.datasets
中,可以通过ImageFolder
类来加载按文件夹组织的ImageNet风格的数据集。虽然完整的ImageNet数据集很大并不直接包含在torchvision.datasets
中,但PyTorch提供了处理这种数据集的工具。 - COCO (Common Objects in Context):用于图像标注、目标检测和语义分割的大型数据集。它包含了图片、物体的标注框、分割掩码以及关键点等信息。
- LSUN (Large-scale Scene UNderstanding):场景理解的大型数据集,包含了不同类别的场景图片。
- FashionMNIST:类似于MNIST,但是用于时尚服装和配饰的图片分类。
- SVHN (Street View House Numbers):从谷歌街景图片中提取的门牌号识别数据集。
- PhotoTour:用于图像匹配的数据集,包含了从不同角度拍摄的同一景点的图片对。
- STL10:一个用于无监督学习和半监督学习的图像数据集,包含了少量的标注数据和大量的无标注数据。
- Kinetics:用于视频动作识别的大型数据集。
- CelebA (CelebFaces Attributes):用于人脸检测和属性识别的大型人脸数据集。
这些标准数据集可以通过简单地调用torchvision.datasets
中的相应类来加载和预处理。例如,加载MNIST数据集可以通过以下代码实现:
import torchvision.datasets as dsets
# 加载MNIST训练集
train_dataset = dsets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True)
# 加载MNIST测试集
test_dataset = dsets.MNIST(root='./data',
train=False,
transform=transforms.ToTensor())
注意,上面的代码中使用了transforms.ToTensor()
来对图片进行预处理,将其转换为PyTorch的Tensor
格式。在实际使用中,你可能还需要根据具体任务添加其他的预处理步骤,比如裁剪、归一化等。这些都可以通过组合torchvision.transforms
中的不同变换来实现。