文章目录
torch.utils.data.sampler
内置的Sampler
基类 Sampler
sampler 采样器,是一个迭代器。PyTorch提供了多种采样器,用户也可以自定义采样器。所有sampler都是承 torch.utils.data.sampler.Sampler
这个抽象类。
class Sampler(object):
r"""Base class for all Samplers.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
顺序采样 SequentialSampler
- 功能
- 顺序地对元素进行采样,总是以相同的顺序。
- 参数
data_source(Dataset)
: 采样的数据集
初始化方法仅仅需要一个Dataset类对象作为参数。对于__len__()
只负责返回数据源包含的数据个数;iter()
方法负责返回一个可迭代对象,这个可迭代对象是由range产生的顺序数值序列,也就是说迭代是按照顺序进行的。
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
- 例子
# 定义数据和对应的采样器
data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.SequentialSampler(data_source=data)
# 迭代获取采样器生成的索引
for index in seq_sampler:
print("index: {}, data: {}".format(str(index), str(data[index])))
得到下面的输出,说明Sequential Sampler产生的索引是顺序索引:
index: 0, data: 17
index: 1, data: 22
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8
随机采样 RandomSampler
- 功能
- 随机抽取元素。如果没有替换,则从打乱的数据集中采样。 如果有替换,则用户可以指定
:attr:num_samples
- 随机抽取元素。如果没有替换,则从打乱的数据集中采样。 如果有替换,则用户可以指定
- 参数
data_source (Dataset)
: 采样的数据集replacement (bool)
: 如果为True
抽取的样本是有放回的。默认是False
num_samples (int)
: 抽取样本的数量,默认是len(dataset)
。当replacement
是True
的时应该被被实例化
class RandomSampler(Sampler):
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
# 这个参数控制的应该为是否重复采样
self.replacement = replacement
self._num_samples = num_samples
def num_samples(self):
# dataset size might change at runtime
# 初始化时不传入num_samples的时候使用数据源的长度
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
# 返回数据集长度
def __len__(self):
return self.num_samples
# 索引生成
def __iter__(self):
n = len(self.data_source)
if self.replacement:
# 生成的随机数是可能重复的
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
# 生成的随机数是不重复的
return iter(torch.randperm(n).tolist())
randint()
函数生成的随机数学列是可能包含重复数值的,而randperm()
函数生成的随机数序列是绝对不包含重复数值的
- 例子
'''不使用replacement,生成的随机索引不重复'''
ran_sampler = sampler.RandomSampler(data_source=data)
# 得到下面输出
index: 0, data: 17
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8
index: 1, data: 22
'''使用replacement,生成的随机索引有重复'''
ran_sampler = sampler.RandomSampler(data_source=data, replacement=True)
# 得到下面的输出
index: 0, data: 17
index: 4, data: 8
index: 3, data: 41
index: 4, data: 8
index: 2, data: 3
子集随机采样 SubsetRandomSampler
- 功能
- 从给定的索引列表中随机抽取元素,不进行替换。
- 参数
indices (sequence)
: 索引列表
class SubsetRandomSampler(Sampler):
def __init__(self, indices):
# 数据集的切片,比如划分训练集和测试集
self.indices = indices
def __iter__(self):
# 以元组形式返回不重复打乱后的“数据”
return (self.indices[i] for i in torch.randperm(len(self.indices)))
def __len__(self):
return len(self.indices)
_iter__()
返回的并不是随机数序列,而是通过随机数序列作为indices的索引,进而返回打乱的数据本身。需要注意的仍然是采样是不重复的,也是通过randperm()
函数实现的。
- 例子
下面将data划分为train和val两个部分
sub_sampler_train = sampler.SubsetRandomSampler(indices=data[0:2])
sub_sampler_val = sampler.SubsetRandomSampler(indices=data[2:])
# 下面是train输出
index: 17
index: 22
*************
# 下面是val输出
index: 8
index: 41
index: 3
加权随机采样 WeightedRandomSampler
- 功能
- 按照给定的概率权重
weights
, 对元素进行采样
- 按照给定的概率权重
- 参数
weights
权重序列num_samples
采样数replacement
抽取的样本是否有放回
class WeightedRandomSampler(Sampler):
def __init__(self, weights, num_samples, replacement=True):
# ...省略类型检查
# weights用于确定生成索引的权重
self.weights = torch.as_tensor(weights, dtype=torch.double)
self.num_samples = num_samples
# 用于控制是否对数据进行有放回采样
self.replacement = replacement
def __iter__(self):
# 按照加权返回随机索引值
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
__iter__()
方法返回的数值为随机数序列,只不过生成的随机数序列是按照weights指定的权重确定的
- 例子
# 位置[0]的权重为0,位置[1]的权重为10,其余位置权重均为1.1
weights = torch.Tensor([0, 10, 1.1, 1.1, 1.1, 1.1, 1.1])
wei_sampler = sampler.WeightedRandomSampler(weights=weights, num_samples=6, replacement=True)
# 下面是输出:
index: 1
index: 2
index: 3
index: 4
index: 1
index: 1
从输出可以看出,位置[1]由于权重较大,被采样的次数较多,位置[0]由于权重为0所以没有被采样到,其余位置权重低所以都仅仅被采样一次。
批采样 BatchSampler
- 功能
- 包装另一个采样器以生成一个小批量索引。
- 参数
sampler
对应前面介绍的XxxSampler类实例batch_size
批量大小drop_last
为“True”时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据
class BatchSampler(Sampler):
def __init__(self, sampler, batch_size, drop_last):、
# ...省略类型检查
# 定义使用何种采样器Sampler
self.sampler = sampler
self.batch_size = batch_size
# 是否在采样个数小于batch_size时剔除本次采样
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
# 如果采样个数和batch_size相等则本次采样完成
if len(batch) == self.batch_size:
yield batch
batch = []
# for结束后在不需要剔除不足batch_size的采样个数时返回当前batch
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
# 在不进行剔除时,数据的长度就是采样器索引的长度
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
- 例子
下面的例子中batch sampler采用的采样器为顺序采样器:
seq_sampler = sampler.SequentialSampler(data_source=data)
batch_sampler = sampler.BatchSampler(seq_sampler, 3, False)
# 下面是输出
batch: [0, 1, 2]
batch: [3, 4]