Pytorch transforms.RandomRotation() does not work on Google Colab

Problem description

I normally work on letter and digit recognition on my own computer. I wanted to move the project to Colab, but it fails there with an error (shown below). After some debugging, I found that this line causes it:

transforms.RandomRotation(degrees=(90, -90))

Below is a minimal example that reproduces the error. It fails on Colab but works fine on my own computer. The problem might be related to different versions of the PyTorch library: I have version 1.3.1 on my computer, and Colab uses version 1.4.0.

import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

transformOpt = transforms.Compose([
    transforms.RandomRotation(degrees=(90, -90)),
    transforms.ToTensor()
])

train_set = datasets.MNIST(
    root='', train=True, transform=transformOpt, download=True)
test_set = datasets.MNIST(
    root='', train=False, transform=transformOpt, download=True)


train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=100,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=100,
    shuffle=False)

images, labels = next(iter(train_loader))
plt.imshow(images[0].view(28, 28), cmap="gray")
plt.show()

TypeError                                 Traceback (most recent call last)

<ipython-input-1-8409db422154> in <module>()
     24     shuffle=False)
     25
---> 26 images, labels = next(iter(train_loader))
     27 plt.imshow(images[0].view(28, 28), cmap="gray")
     28 plt.show()

10 frames

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    343
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py in __getitem__(self, index)
     95
     96         if self.transform is not None:
---> 97             img = self.transform(img)
     98
     99         if self.target_transform is not None:

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
   1001         angle = self.get_params(self.degrees)
   1002
-> 1003         return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
   1004
   1005     def __repr__(self):

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, resample, expand, center, fill)
    727         fill = tuple([fill] * 3)
    728
--> 729     return img.rotate(angle, resample, expand, center, fillcolor=fill)
    730
    731

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in rotate(self, angle, resample, expand, center, translate, fillcolor)
   2003         w, h = nw, nh
   2004
-> 2005         return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor)
   2006
   2007     def save(self, fp, format=None, **params):

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in transform(self, size, method, data, resample, fill, fillcolor)
   2297             raise ValueError("missing method data")
   2298
-> 2299         im = new(self.mode, size, fillcolor)
   2300         if method == MESH:
   2301             # list of quads

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color)
   2503         im.palette = ImagePalette.ImagePalette()
   2504         color = im.palette.getcolor(color)
-> 2505     return im._new(core.fill(mode, size, color))
   2506
   2507

TypeError: function takes exactly 1 argument (3 given)

Recommended answer

You're absolutely correct. torchvision 0.5 has a bug in how RandomRotation() handles the fill argument, probably due to an incompatible Pillow version. The issue has been fixed (PR#1760), and the fix will ship in the next release.
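The traceback hints at the mechanism: `functional.py` line 727 turns the fill into a 3-tuple (`fill = tuple([fill] * 3)`), while MNIST images are single-band (mode "L") Pillow images. The pure-Python sketch below mimics that mismatch. It is an illustration, not torchvision's code: the helper names and the `BANDS` table are hypothetical, and it assumes the expansion only applies when fill is a plain int (the guarding condition is not visible in the traceback).

```python
# Band counts for a few common Pillow image modes (hand-written subset).
BANDS = {"L": 1, "RGB": 3, "RGBA": 4}


def expand_fill_like_traceback(fill):
    """Mimic the `fill = tuple([fill] * 3)` line shown in the traceback.

    Assumption: the expansion happens only for a plain int.
    """
    if isinstance(fill, int):
        return tuple([fill] * 3)
    return fill


def fill_matches_mode(fill, mode):
    """Return True if `fill` carries exactly one value per band of `mode`."""
    values = fill if isinstance(fill, tuple) else (fill,)
    return len(values) == BANDS[mode]


# An int fill becomes a 3-tuple, which does not fit a 1-band image.
print(fill_matches_mode(expand_fill_like_traceback(0), "L"))     # False
# A 1-tuple is left alone and fits mode "L".
print(fill_matches_mode(expand_fill_like_traceback((0,)), "L"))  # True
# The same 3-tuple would fit an RGB image.
print(fill_matches_mode(expand_fill_like_traceback(0), "RGB"))   # True
```

Under these assumptions, a three-value color handed to a one-band image is consistent with the final message, "function takes exactly 1 argument (3 given)".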

As a temporary workaround, add fill=(0,) to the RandomRotation transform:

transforms.RandomRotation(degrees=(90, -90), fill=(0,))

