参考:MMCV 核心组件分析(六): Hook - 知乎、MMCV 核心组件分析(七): Runner - 知乎
1.Runner(执行器)
MMDetection(3D)中,最常用的Runner是EpochBasedRunner。下面以EpochBasedRunner为例介绍Runner。run()函数是Runner的关键函数,其代码如下(其中的call_hook()函数可暂时忽略):
def run(self, ...):
...
self.call_hook('before_run')
while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow): # 例如workflow=[('train',1),('val',1)]
mode, epochs = flow
epoch_runner = getattr(self, mode) # epoch_runner为train或val
for _ in range(epochs):
...
epoch_runner(data_loaders[i], **kwargs) # 调用train或val函数
...
self.call_hook('after_run')
上述代码提到的workflow在配置文件中默认为[('train',1)],可自由设置为如[('train',1),('val',1)],[('train',2),('val',1)],[('val',1),('train',1)]等,其中[('train',1),('val',1)]表示训练一个epoch后验证一个epoch(验证的epoch不计入训练总epoch数中)。
上述代码中被调用的train()和val()函数如下:
def train(self, data_loader, **kwargs):
...
self.call_hook('before_train_epoch')
...
for i, data_batch in enumerate(self.data_loader):
...
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
...
self.call_hook('after_train_epoch')
...
@torch.no_grad()
def val(self, data_loader, **kwargs):
...
self.call_hook('before_val_epoch')
...
for i, data_batch in enumerate(self.data_loader):
...
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
...
self.call_hook('after_val_epoch')
我们假设run()函数调用train()函数,则将代码写完整如下:
def run(self, ...)
...
call_hook('before_run')
while self.epoch < self._max_epochs:
...
call_hook('before_train_epoch')
...
for i, data_batch in enumerate(self.data_loader):
...
call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs) # 训练
call_hook('after_train_iter')
...
call_hook('after_train_epoch')
...
call_hook('after_run')
可以看到,Runner定义了完整的训练过程,这和我们使用Pytorch编程实现的的训练过程是一致的。
2.Hook(挂钩)
2.1 引例——什么是Hook
假设有一个气象台,每天早上8点自动在网站上发布天气预报信息(伪代码如下):
class 气象台:
def 发布消息():
天气预报信息=获取天气信息()
print(天气预报信息)
if __name__='__main__':
气象台实例=气象台() # 实例化
每天早上8点:
气象台实例.发布消息()
用户需要在每天早上8点查看网站获取天气预报信息。后来,为了更方便为用户服务,气象台开放了用户订阅功能,发布消息的时候会同时给所有订阅用户发送包含天气预报信息的短信。具体来说,该气象台使用了订阅用户列表来管理订阅用户:
class 气象台:
def __init__():
self.订阅天气预报服务的用户=[]
def 订阅天气预报服务(用户)
self.订阅天气预报服务的用户.append(用户)
def 取消订阅天气预报服务(用户):
self.订阅天气预报服务的用户.remove(用户)
def 发布消息(): # 每天早上8点,气象台自动调用
天气预报信息=获取天气信息()
for 用户 in self.订阅天气预报服务的用户:
给用户发送短信(用户, 天气预报信息)
print(天气预报信息)
if __name__='__main__':
气象台实例=气象台() # 实例化
每天早上8点:
气象台实例.发布消息()
有用户订阅服务时:
气象台实例.订阅天气预报服务(用户)
有用户取消订阅服务时:
气象台实例.取消订阅天气预报服务(用户)
再后来,气象台希望进一步改进服务质量,以满足用户个性化的需求(例如,用户A只想在预报气温低于5度时收到通知,用户B只想在预报下雨时收到通知,用户C只想在预报气温在15~25度之间、大雨、且风力大于7级的情况下收到通知)。这时,一个简单的用户列表就不能满足要求了;为每个要求建立一个用户列表也是不现实的。
如果气象台管理包含用户需求的函数列表(这些函数称为hook函数),就能方便地满足用户的多样化需求:
class 气象台:
def __init__():
self.天气预报hook=[]
def 订阅天气预报服务(用户个性化hook函数)
self.天气预报hook.append(用户个性化hook函数)
def 取消订阅天气预报服务(用户个性化hook函数):
self.天气预报hook.remove(用户个性化hook函数)
def 发布消息(): # 每天早上8点,气象台自动调用
天气预报信息=获取天气信息()
for hook in self.天气预报hook:
hook(天气预报信息)
print(天气预报信息)
if __name__='__main__':
气象台实例=气象台() # 实例化
每天早上8点:
气象台实例.发布消息()
有用户订阅个性化服务时:
气象台实例.订阅天气预报服务(用户个性化hook函数)
有用户取消订阅个性化服务时:
气象台实例.取消订阅天气预报服务(用户个性化hook函数)
其中“用户个性化hook函数”需要固定输入为“天气预报信息”的格式,但内容可以自由设置。这里我们假设“天气预报信息”为字符串格式。假设用户D只想在预报有大雨的情况下让自己收到通知,且在预报气温低于5度的时候为家人发送“今日气温较低,出门记得穿上羽绒服”的短信,则该用户的“用户个性化hook函数”的可以为:
用户D的个性化hook函数(str): # str为天气预报信息
天气 = 获取天气(str) # 从字符串中提取天气信息
气温 = 获取气温(str) # 从字符串中提取温度信息
if 天气=='大雨':
发送信息到用户D的手机(str)
if 气温<=5:
发送信息到用户D家人的手机("今日气温较低,出门记得穿上羽绒服")
可见,用户无需知道气象台类的内部操作,就可自己设置功能多样的Hook函数。
2.2 MMDetection(3D)中的Hook
MMDetection3D中的Hook会被注册到Runner类中。我们在设计Hook函数时无需了解Runner内部的具体过程,而只需要知道所需输入的含义(均为Runner本身)。
本文第一章最后一个代码段就是EpochBasedRunner的训练流程,其中call_hook()函数就是在调用被注册的Hook。可见,EpochBasedRunner会在训练开始时、每个epoch开始时、每轮迭代开始时、每轮迭代结束时、每个epoch结束时、训练结束时调用Hook。
具体来看,上述call_hook()的具体代码如下:
def call_hook(self, fn_name: str):
"""Call all hooks.
Args:
fn_name (str): The function name in each hook to be called, such as "before_train_epoch".
"""
for hook in self._hooks:
getattr(hook, fn_name)(self)
可见,同一个Hook可以定义在训练过程中的不同位置被调用,只需要我们写好相应的函数即可。例如,我们可以自定义一个Hook:
class MyHook(Hook):
def before_run(self, runner):
print("开始训练")
def before_train_epoch(self, runner):
print("开始一个epoch训练")
def after_train_epoch(self, runner):
print("结束一个epoch训练")
若将该Hook注册到Runner中,则在训练开始时以及每个epoch开始和结束的时候,MyHook类的相应函数就会被调用。
2.3 Hook的注册
2.4 例子:EvalHook
EvalHook的配置文件写法与2.3节中提到的Hook不太相同,且有其它需要强调的地方,因此单独拿出来介绍。
官方的EvalHook类(位于mmdet/core/evalation/eval_hooks.py)的部分代码如下:
class EvalHook(Hook):
def __init__(self, ...):
...
def after_train_epoch(self, runner): # 每个epoch结束时调用
if self.by_epoch and self._should_evaluate(runner):
self._do_evaluate(runner)
def _do_evaluate(self, runner):
...
results = single_gpu_test(runner.model, self.dataloader, show=False) # 在测试集上进行预测
...
key_score = self.evaluate(runner, results) # 评估测试结果,计算指标
if self.save_best and key_score:
self._save_ckpt(runner, key_score) # 保存最优模型
def evaluate(self, runner, results):
eval_res = self.dataloader.dataset.evaluate(results, logger=runner.logger, **self.eval_kwargs)
# 调用数据集类的evaluated函数进行评估(最后的eval_kwargs表明可在配置文件中的evaluation项内设置该评估函数的输入)
...
我们重点关注after_train_epoch()函数,可以看到,该函数就是在进行模型的评估。也就是说,如果我们在EpochBasedRunner下注册了EvalHook,那么在每个epoch结束后,程序会根据用户的设置来判断是否需要进行评估,如果需要评估,就在验证集上进行预测并计算指标(do_evaluate函数的工作)。
关于EvalHook的注册,需要看mmdet3d/apis/train.py中train_detector()中的如下部分:
if validate:
...
eval_cfg = cfg.get('evaluation', {}) # 获取配置信息
...
eval_hook = MMDET_DistEvalHook if distributed else MMDET_EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW') # 注册hook
可见,需要我们在配置文件中配置evaluation项,具体需要看EvalHook的__init__函数,这里不再介绍。