Course link
While running experiments recently I realized my grasp of the basic framework is still weak, so I decided to relearn PyTorch from the ground up. These are my notes for this course; the lecture itself is very detailed, but the notes are fairly rough.
1. DataLoader
1.1 Implementation of the DataLoader class
1.1.1 The __init__ constructor
The constructor takes the following parameters:
- dataset: the dataset to load from, an instance of a Dataset subclass you define yourself.
- batch_size: defaults to 1; the number of samples in each training batch.
- shuffle: boolean; True shuffles the dataset, False keeps the original order.
- sampler: decides how individual samples are drawn. Instead of shuffling randomly, you can supply your own sampler to control how samples are picked, e.g. to build mini-batches in an ordered way, such as putting samples of similar length into the same mini-batch; a random shuffle would destroy that ordering, so shuffle cannot be used there. If a sampler is passed, shuffle is meaningless.
- batch_sampler: lets you supply your own logic that draws samples a whole batch at a time. If it is passed, shuffle is meaningless.
- num_workers: defaults to 0; the number of subprocesses used for data loading, to speed up loading and improve training efficiency. A sensible value depends on the number of CPU cores; beyond a certain point, more workers no longer make loading faster.
- collate_fn: a collation function that post-processes each batch. For example, after shuffling we obtain a batch whose samples we want to pad, but the pad length can only be computed from the batch itself, not ahead of time. In that case collate_fn takes the sampled mini-batch, pads its samples to the same length, and returns a new batch (see the sketch after this list).
- pin_memory: boolean, defaults to False; whether to load the data into pinned (page-locked) memory. Pinned memory is locked by the operating system so it cannot be paged out, which speeds up data transfers. With pin_memory=True, PyTorch tries to store the loaded data in pinned memory, which can improve transfer efficiency when training on a GPU, because the GPU can read from pinned memory directly without an extra copy. Note that this only matters for GPU training; for CPU training the effect is usually negligible, and its real impact on training speed is worth measuring rather than assuming.
- drop_last: boolean, defaults to False. If the total number of samples is not an integer multiple of the batch size, setting drop_last=True discards the final, smaller batch (the one with fewer than batch_size samples).
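To make these parameters concrete, here is a minimal sketch of my own (the ToyTextDataset and pad_collate names are invented for illustration, not from the course): a map-style dataset of variable-length sequences, plus a custom collate_fn that pads each mini-batch to the length of its longest sample.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

class ToyTextDataset(Dataset):
    # toy map-style dataset: five variable-length "token id" sequences
    def __init__(self):
        self.samples = [torch.arange(n) for n in (3, 5, 2, 7, 4)]
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]

def pad_collate(batch):
    # batch is the list of samples picked for one mini-batch;
    # pad them all to the length of the longest sample in this batch
    lengths = torch.tensor([len(x) for x in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths

loader = DataLoader(
    ToyTextDataset(),
    batch_size=2,
    shuffle=True,            # mutually exclusive with a custom sampler
    num_workers=0,           # > 0 would load data in worker subprocesses
    collate_fn=pad_collate,  # post-process each sampled mini-batch
    drop_last=False,         # keep the final, smaller batch
)

for padded, lengths in loader:
    print(padded.shape, lengths)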
The constructor code, with comments, is as follows:
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,
batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
torch._C._log_api_usage_once("python.data_loader")
if num_workers < 0:
raise ValueError('num_workers option should be non-negative; '
'use num_workers=0 to disable multiprocessing.')
if timeout < 0:
raise ValueError('timeout option should be non-negative')
if num_workers == 0 and prefetch_factor != 2:
raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
'let num_workers > 0 to enable multiprocessing.')
assert prefetch_factor > 0
if persistent_workers and num_workers == 0:
raise ValueError('persistent_workers option needs num_workers > 0')
# store the constructor arguments as instance attributes
self.dataset = dataset
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.pin_memory = pin_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context
# You can skip this branch: we normally use a map-style Dataset rather than an IterableDataset, so jump straight to the else branch that matches this if.
if isinstance(dataset, IterableDataset):
self._dataset_kind = _DatasetKind.Iterable
if isinstance(dataset, IterDataPipe):
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
elif shuffle is not False:
raise ValueError(
"DataLoader with IterableDataset: expected unspecified "
"shuffle option, but got shuffle={}".format(shuffle))
if sampler is not None:
# See NOTE [ Custom Samplers and IterableDataset ]
raise ValueError(
"DataLoader with IterableDataset: expected unspecified "
"sampler option, but got sampler={}".format(sampler))
elif batch_sampler is not None:
# See NOTE [ Custom Samplers and IterableDataset ]
raise ValueError(
"DataLoader with IterableDataset: expected unspecified "
"batch_sampler option, but got batch_sampler={}".format(batch_sampler))
# Jump straight to this else branch (the map-style case)
else:
# mark the dataset kind as map-style (_DatasetKind.Map)
self._dataset_kind = _DatasetKind.Map
# sampler defaults to None. Passing a custom sampler together with shuffle=True makes no sense: shuffle merely installs the built-in random sampler, and once you provide your own sampler you no longer need shuffle to randomize the order. So shuffle and sampler are mutually exclusive and must not be set at the same time.
if sampler is not None and shuffle:
raise ValueError('sampler option is mutually exclusive with '
'shuffle')
# batch_sampler samples at the batch level, sampler at the individual-sample level
if batch_sampler is not None:
# Setting batch_size to something other than 1, or setting shuffle, sampler, or drop_last, is mutually exclusive with batch_sampler. In short: once you pass a batch_sampler there is no need to set batch_size, because the batch_sampler already tells PyTorch how large each batch is and how the mini-batches are formed.
if batch_size != 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler option is mutually exclusive '
'with batch_size, shuffle, sampler, and '
'drop_last')
batch_size = None
drop_last = False
# If batch_size is None and drop_last is set, an error is raised
elif batch_size is None:
# no auto_collation
if drop_last:
raise ValueError('batch_size=None option disables auto-batching '
'and is mutually exclusive with drop_last')
# If you did not pass a sampler yourself
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else:  # map-style (the common case): if shuffle is set, the built-in RandomSampler class is used to randomly shuffle the Dataset; its implementation is covered in a later section
if shuffle:
sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
# If shuffle is not True, SequentialSampler is used instead, i.e. samples are drawn in their original order
else:
sampler = SequentialSampler(dataset) # type: ignore[arg-type]
# If batch_size is not None and no batch_sampler was passed,
# a default BatchSampler is constructed for you
# (the BatchSampler implementation is covered in a later section)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler
self.generator = generator
# If collate_fn is None, default_collate is used when auto-collation is enabled
if collate_fn is None:
# _auto_collation is derived from whether batch_sampler is None: if batch_sampler is not None, _auto_collation is True and _utils.collate.default_collate is used; if batch_sampler is None, _utils.collate.default_convert is used instead.
# default_collate takes the batch (the list of sampled elements) as input, collates it into batched tensors, and returns the batch. If you implement collate_fn yourself, it must likewise take the batch as input, process it, and return the result.
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else:
collate_fn = _utils.collate.default_convert
self.collate_fn = collate_fn
self.persistent_workers = persistent_workers
self.__initialized = True
self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ]
self._iterator = None
self.check_worker_number_rationality()
torch.set_vital('Dataloader', 'enabled', 'True') # type: ignore[attr-defined]
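As a sanity check on what the constructor just assembled (my own sketch, not from the course, and assuming default_collate is importable from torch.utils.data.dataloader), the defaults that DataLoader(ds, batch_size=4, shuffle=True) builds internally can be wired up by hand: a RandomSampler produces sample indices, a BatchSampler groups them, and default_collate stacks the fetched samples.

import torch
from torch.utils.data import BatchSampler, RandomSampler, TensorDataset
from torch.utils.data.dataloader import default_collate

ds = TensorDataset(torch.arange(10, dtype=torch.float32))

sampler = RandomSampler(ds)                                            # sample-level index stream
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)   # groups indices into batches
for indices in batch_sampler:
    batch = default_collate([ds[i] for i in indices])                  # fetch samples, stack into tensors
    print(indices, batch[0].shape)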
1.1.2 The _get_iterator function
def _get_iterator(self) -> '_BaseDataLoaderIter':
# If num_workers is 0, use the single-process loading path
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
# If num_workers is not 0, samples are read by multiple worker processes
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
It is normally called from the __iter__ method, which is what turns a DataLoader into an iterable object you can loop over.
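In other words, a plain for-loop over a DataLoader is just driving the iterator returned by __iter__ (and therefore by _get_iterator) by hand; a minimal sketch:

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.arange(6)), batch_size=2)

it = iter(loader)          # num_workers == 0, so this is a _SingleProcessDataLoaderIter
while True:
    try:
        batch = next(it)   # each call fetches and collates one mini-batch
    except StopIteration:
        break
    print(batch)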
1.2 Implementation of the RandomSampler class
Focus mainly on the inline comments.
class RandomSampler(Sampler[int]):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify :attr:`num_samples` to draw.
Args:
data_source (Dataset): dataset to sample from
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
num_samples (int): number of samples to draw, default=`len(dataset)`.
generator (Generator): Generator used in sampling.
"""
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
if not isinstance(self.replacement, bool):
raise TypeError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples))
@property
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
# Start with the __iter__ method
def __iter__(self) -> Iterator[int]:
# size of the dataset
n = len(self.data_source)
# If no generator was passed in, draw a random seed and build a generator from it
if self.generator is None:
# seed the random number generator
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
# yields random indices in [0, n) for the remaining samples, where n is the dataset size
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
for _ in range(self.num_samples // n):
yield from torch.randperm(n, generator=generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
def __len__(self) -> int:
return self.num_samples
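A quick check of the behavior above (a tiny TensorDataset used only for illustration): iterating a RandomSampler yields a random permutation of the dataset indices, and passing a seeded generator makes that permutation reproducible.

import torch
from torch.utils.data import RandomSampler, TensorDataset

ds = TensorDataset(torch.arange(5))
g = torch.Generator().manual_seed(0)       # fixed seed -> reproducible permutation
sampler = RandomSampler(ds, generator=g)
print(list(sampler))                       # a random permutation of [0, 1, 2, 3, 4]
print(len(sampler))                        # 5 (num_samples defaults to len(dataset))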
1.3 Implementation of the SequentialSampler class
class SequentialSampler(Sampler[int]):
r"""Samples elements sequentially, always in the same order.
Args:
data_source (Dataset): dataset to sample from
"""
data_source: Sized
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
# Iterating it yields indices in their original, sequential order
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
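A one-line check of the sequential behavior (using a plain range as the Sized data_source):

from torch.utils.data import SequentialSampler

sampler = SequentialSampler(range(5))   # any object with __len__ works as data_source
print(list(sampler))                    # [0, 1, 2, 3, 4] -- indices in the original order
print(len(sampler))                     # 5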
1.4 Implementation of the BatchSampler class
Again, go straight to the __iter__ method.
class BatchSampler(Sampler[List[int]]):
def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
# Look at __iter__ first
def __iter__(self) -> Iterator[List[int]]:
# start with an empty list for the current batch
batch = []
# iterate over the sampler to get sample indices
for idx in self.sampler:
# append the index to the list
batch.append(idx)
# once the list holds batch_size indices, yield it as one batch, then reset it to empty
if len(batch) == self.batch_size:
yield batch
batch = []
# if drop_last (whether to discard the final batch that has fewer than batch_size elements) is False, yield that last incomplete batch as well
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self) -> int:
# Can only be called if self.sampler has __len__ implemented
# We cannot enforce this condition, so we turn off typechecking for the
# implementation below.
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
if self.drop_last:
return len(self.sampler) // self.batch_size # type: ignore[arg-type]
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type]
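A small demonstration of how __iter__ groups indices into batches and how drop_last changes the final batch (here driven by a SequentialSampler over ten indices):

from torch.utils.data import BatchSampler, SequentialSampler

indices = SequentialSampler(range(10))
print(list(BatchSampler(indices, batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]  -- the last, incomplete batch is kept
print(list(BatchSampler(indices, batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]       -- the last, incomplete batch is dropped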
1.5 Miscellaneous
The lecturer goes into great detail; I did not write everything down, so check the video for some of the finer points.