如何在PyTorch中为不同的Subset使用不同的数据增强(转换)?

例如:

train, test = torch.utils.data.random_split(dataset, [80000, 2000])
traintest将具有与dataset相同的转换。如何对这些子集使用自定义转换?

最佳答案

我当前的解决方案不是很好,但是可以工作:

from copy import copy

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
train_dataset.dataset = copy(full_dataset)

test_dataset.dataset.transform = transforms.Compose([
    transforms.Resize(img_resolution),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(img_resolution[0]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

基本上,我为一个拆分定义了一个新的数据集(它是原始数据集的副本),然后为每个拆分定义了一个自定义转换。

注意:train_dataset.dataset.transform有效,因为我使用的是ImageFolder数据集,该数据集使用.tranform属性执行转换。

如果有人知道更好的解决方案,请与我们分享!

关于python - 如何在PyTorch中为子集使用不同的数据增强,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/51782021/

10-12 19:10
查看更多