本文介绍了如何在Pytorch中简化用于自动编码器的DataLoader的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

有没有更简单的方法来设置数据加载器,因为在自动编码器的情况下输入数据和目标数据是相同的,并且在训练过程中可以加载数据? DataLoader 始终需要两个输入.

Is there any easier way to set up the dataloader, because input and target data is the same in case of an autoencoder and to load the data during training? The DataLoader always requires two inputs.

目前,我像这样定义数据加载器:

Currently I define my dataloader like this:

X_train     = rnd.random((300,100))
X_val       = rnd.random((75,100))
train       = data_utils.TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(X_train).float())
val         = data_utils.TensorDataset(torch.from_numpy(X_val).float(), torch.from_numpy(X_val).float())
train_loader= data_utils.DataLoader(train, batch_size=1)
val_loader  = data_utils.DataLoader(val, batch_size=1)

像这样训练:

for epoch in range(50):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data), Variable(target).detach()
        optimizer.zero_grad()
        output = model(data, x)
        loss = criterion(output, target)

推荐答案

为什么不继承TensorDataset以使其与未标记的数据兼容?

Why not subclassing TensorDataset to make it compatible with unlabeled data ?

class UnlabeledTensorDataset(TensorDataset):
    """Dataset wrapping unlabeled data tensors.

    Each sample will be retrieved by indexing tensors along the first
    dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
    """
    def __init__(self, data_tensor):
        self.data_tensor = data_tensor

    def __getitem__(self, index):
        return self.data_tensor[index]

还有一些类似的方法可以训练您的自动编码器

And something along these lines for training your autoencoder

X_train     = rnd.random((300,100))
train       = UnlabeledTensorDataset(torch.from_numpy(X_train).float())
train_loader= data_utils.DataLoader(train, batch_size=1)

for epoch in range(50):
    for batch in train_loader:
        data = Variable(batch)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data)

这篇关于如何在Pytorch中简化用于自动编码器的DataLoader的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-13 08:54