作为一个2年多的不资深keraser和tfer,被boss要求全员换成pytorch。不得不说,pytorch还是真香的。之前用keras,总会发现多GPU使用的情况下不太好,对计算资源的利用率不太高。把模型改成pytorch以后,发现资源利用率非常可观。
非常看好pytorch的前途,到时候能制衡一下tf就好了。闲话不多扯,我来讲讲初入pytorch最重要的东西:dataset
网上有很多介绍pytorch dataset类的文章,不过大多数都是讲解某一类任务的数据集模型建立。不太具有泛化性,本文将提出一个通用的数据集接口解决技巧,供大家参考。
实验环境:
python==3.7.3
ubuntu==16.04
pytorch==1.1.0
dataset类
为什么木盏会说dataset是初入pytorch最重要的东西?因为我们复现项目的时候,最需要改的就是数据集。其他调调参改改模型问题都不大。
如果弄明白了pytorch中dataset类,你可以创建适应任意模型的数据集接口。
所谓数据集,无非就是一组{x:y}的集合吗,你只需要在这个类里说明“有一组{x:y}的集合”就可以了。
对于图像分类任务,图像+分类
对于目标检测任务,图像+bbox、分类
对于超分辨率任务,低分辨率图像+超分辨率图像
对于文本分类任务,文本+分类
...
你只需定义好这个项目的x和y是什么。好了,上面都是扯闲篇,我们直接看dataset代码:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
上面的代码是pytorch给出的官方代码,其中__getitem__和__len__是子类必须继承的。
很好解释,pytorch给出的官方代码限制了标准,你要按照它的标准进行数据集建立。首先,__getitem__就是获取样本对,模型直接通过这一函数获得一对样本对{x:y}。__len__是指数据集长度。
自己建立一个dataset试试:
class MyDataSet(Dataset):
def __init__(self):
self.sample_list = ...
def __getitem__(self, index):
x= ...
y= ...
return x, y
def __len__(self):
return len(self.sample_list)
上面这个模板是本人定义好的,史称“木盏模板”。咱只需按照需求把模板填完就Ok了,那么为什么说这个模板使用于各种任务的数据集建造呢?还得依靠一个trick:通过txt文件映射。
举个实例,假设我要给一个分类器训练喂数据,我的数据是images+number的组合,比如{img:3},这代表这个图像应该分在“3”类。我怎么写代码呢?
from torch.utils.data import Dataset
class MyDataSet(Dataset):
def __init__(self, dataset_type, transform=None, update_dataset=False):
"""
dataset_type: ['train', 'test']
"""
dataset_path = '/home/muzhan/projects/dataset/'
if update_dataset:
make_txt_file(dataset_path) # update datalist
self.transform = transform
self.sample_list = list()
self.dataset_type = dataset_type
f = open(dataset_path + self.dataset_type + '/datalist.txt')
lines = f.readlines()
for line in lines:
self.sample_list.append(line.strip())
f.close()
def __getitem__(self, index):
item = self.sample_list[index]
# img = cv2.imread(item.split(' _')[0])
img = Image.open(item.split(' _')[0])
if self.transform is not None:
img = self.transform(img)
label = int(item.split(' _')[-1])
return img, label
def __len__(self):
return len(self.sample_list)
上面有个transform参数,用于对数据集进行预处理的,可以根据项目选择使用。
上面有一个make_txt_file的函数需要说明一下,这个函数可以在数据集目录下创建一个txt文件,代表x和y的映射关系。这个函数大家可以自己写,一个简单脚本而已,我就不共享代码了 。(如有需要,留言告知)
我给大家看一下我的datalist.txt中的几行:
/home/muzhan/projects/dataset/test/250_04.png _0
/home/muzhan/projects/dataset/test/250_05.png _7
/home/muzhan/projects/dataset/test/250_06.png _3
/home/muzhan/projects/dataset/test/250_07.png _2
/home/muzhan/projects/dataset/test/250_08.png _2
/home/muzhan/projects/dataset/test/250_09.png _3
/home/muzhan/projects/dataset/test/250_10.png _4
/home/muzhan/projects/dataset/test/250_11.png _0
/home/muzhan/projects/dataset/test/250_12.png _9
这样就可以理解我在__getitem__函数中解析x和y的方法吧,在文本中用字符串' _'隔开,当然你可以用其他字符,能够保证剪切字符串不出错即可。
我们需要测试这个dataset类是否成功:
if __name__ == '__main__':
ds = MyDataSet()
print(ds.__len__())
img, gt = ds.__getitem__(34) # get the 34th sample
print(type(img))
print(gt)
上面有输出,并且和你数据集一致,那证明这个dataset类是成功的。
有了这个,用DataLoader函数就可以加载我们的数据集了。