我对整个领域有点陌生,因此决定使用MNIST数据集。我几乎修改了https://github.com/pytorch/examples/blob/master/mnist/main.py中的整个代码,只有一个重大更改:数据加载。我不想在Torchvision中使用预加载的数据集。所以我用MNIST in CSV。
我是通过从数据集继承并制作新的数据加载器来从CSV文件加载数据的。
以下是相关代码:
mean = 33.318421449829934
sd = 78.56749081851163
# mean = 0.1307
# sd = 0.3081
import numpy as np
from torch.utils.data import Dataset, DataLoader
class dataset(Dataset):
def __init__(self, csv, transform=None):
data = pd.read_csv(csv, header=None)
self.X = np.array(data.iloc[:, 1:]).reshape(-1, 28, 28, 1).astype('float32')
self.Y = np.array(data.iloc[:, 0])
del data
self.transform = transform
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
item = self.X[idx]
label = self.Y[idx]
if self.transform:
item = self.transform(item)
return (item, label)
import torchvision.transforms as transforms
trainData = dataset('mnist_train.csv', transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((mean,), (sd,))
]))
testData = dataset('mnist_test.csv', transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((mean,), (sd,))
]))
train_loader = DataLoader(dataset=trainData,
batch_size=10,
shuffle=True,
)
test_loader = DataLoader(dataset=testData,
batch_size=10,
shuffle=True,
)
但是,此代码为我提供了您在图片中看到的绝对奇怪的训练错误图,以及11%的最终验证错误,因为它将所有内容归类为“ 7”。
我设法将问题归结为如何规范化数据,以及是否使用示例代码中给出的值(0.1307和0.3081)进行转换.Normalize以及将数据读取为'uint8'类型都可以正常工作。
请注意,在这两种情况下提供的数据差异很小。对0到1的值进行0.1307和0.3081归一化,与对0到255的值进行33.31和78.56归一化具有相同的效果。该值甚至大体相同(黑色像素对应于-0.4241,而-0.4242在第二)。
如果您想在IPython Notebook中清楚地看到此问题,请查看https://colab.research.google.com/drive/1W1qx7IADpnn5e5w97IcxVvmZAaMK9vL3
我无法理解是什么原因导致这两种加载数据的方式略有不同。任何帮助将不胜感激。
最佳答案
长话短说:您需要将item = self.X[idx]
更改为item = self.X[idx].copy()
。
长话短说:T.ToTensor()
运行torch.from_numpy
,它返回一个张量,该张量将numpy数组dataset.X
的内存作为别名。还有T.Normalize()
works inplace,因此每次抽取样本时都会减去mean
并除以std
,从而导致数据集退化。
编辑:关于为什么它可以在原始MNIST加载程序中运行,因此兔子洞甚至更深。 MNIST
中的关键行是将映像transformed放入PIL.Image实例中。该操作声称仅在缓冲区不连续的情况下才复制(在我们的情况下),但是在hood下,它检查是否跨步(它是跨步的),从而进行复制。因此,幸运的是,默认的Torchvision管道涉及一个副本,因此T.Normalize()
的就地操作不会破坏我们self.data
实例的内存中的MNIST
。
关于python - MNIST Pytorch中的验证错误意外增加,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53652015/