PyTorch datasets: converting an entire dataset to NumPy

Problem description

I'm trying to convert the Torchvision MNIST train and test datasets into NumPy arrays but can't find documentation to actually perform the conversion.

My goal would be to take an entire dataset and convert it into a single NumPy array, preferably without iterating through the entire dataset.

I've looked at "How do I turn a Pytorch Dataloader into a numpy array to display image data with matplotlib?", but it doesn't address my issue.

So my question is, utilizing torch.utils.data.DataLoader, how would I go about converting the datasets (train/test) into two NumPy arrays such that all of the examples are present?

Note: I've left the batch size as the default of 1 for now; I could set it to 60,000 for train and 10,000 for test, but I'd prefer to not use magic numbers of that sort.

Thank you.

Solution

If I understand you correctly, you want the whole MNIST training set (60,000 images, each a 1x28x28 array, where the leading 1 is the color channel) as a single NumPy array of shape (60000, 1, 28, 28)?

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transform to normalized Tensors
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST('./MNIST/', train=True, transform=transform, download=True)
# test_dataset = datasets.MNIST('./MNIST/', train=False, transform=transform, download=True)


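# Use one batch as large as the dataset, so the first batch holds every example
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
# test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

# Each batch is (images, labels); index 0 selects the image tensor
train_dataset_array = next(iter(train_loader))[0].numpy()
# test_dataset_array = next(iter(test_loader))[0].numpy()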
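
The snippet above keeps only the images. If the labels are needed too, the same single-batch trick can return both arrays. The sketch below is only an illustration: to run without downloading anything, it uses a small random TensorDataset as a hypothetical stand-in for MNIST, and the names images, labels, x and y are my own, not from the answer.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 100 random "images" of shape (1, 28, 28)
# with integer labels in 0..9. Swap in datasets.MNIST(...) for the real data.
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

# batch_size=len(dataset) avoids hard-coding 60,000 or 10,000
loader = DataLoader(dataset, batch_size=len(dataset))

# The first (and only) batch contains every example and every label
batch_images, batch_labels = next(iter(loader))
x = batch_images.numpy()
y = batch_labels.numpy()

print(x.shape, y.shape)
```

Note that the loader still visits every item internally to build the batch, and the whole dataset has to fit in memory at once. As far as I know, torchvision's MNIST object also exposes the raw data and targets tensors directly, but those bypass the transform, so the values would not be normalized.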

This is the result:

>>> train_dataset_array

array([[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       ...,


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]]], dtype=float32)
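
The repeated value -0.42421296 is what a black background pixel becomes after the transform: ToTensor scales pixel values to [0, 1], and Normalize then computes (x - 0.1307) / 0.3081. A quick check of that arithmetic, using NumPy only (the variable names are illustrative):

```python
import numpy as np

# Mean and std passed to transforms.Normalize in the answer
mean, std = np.float32(0.1307), np.float32(0.3081)

# After ToTensor, a black pixel is 0.0 and a white pixel is 1.0
raw = np.array([0.0, 1.0], dtype=np.float32)

# Same formula that Normalize applies per channel
normalized = (raw - mean) / std
print(normalized)
```

The first element comes out at about -0.4242, matching the corners of the printed array, which are all background.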

07-29 11:51