哈喽宝子们!
平安夜快乐!!!
今天给大家带来了超级福利
超全超细的 MMCV Hook 方法
快乐过节前再学习下吧~
1. Hook 是什么
Hook 介绍
维基百科:钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)。
在训练过程中,通常有十个关键位点,如下图所示,从训练开始到结束,所有关键位点已用红色标出,共有 10 个。我们可以在这十个位点插入各种逻辑,例如加载模型权重、保存模型权重。而我们将同一类型的逻辑组织成一个 Hook。因此,MMCV 中 Hook 的作用就是训练和验证模型时,在不改变其他代码的前提下,灵活地在不同位点插入定制化的逻辑。
而控制整个训练过程的抽象在 MMCV 中被设计为 Runner,它的主要行为就是执行上图蓝色的工作流,MMCV 提供了两种类型的 Runner,一种是以 epoch 为单位迭代的 EpochBasedRunner,另一种是以 iteration 为单位迭代的 IterBasedRunner。下面给出 EpochBasedRunner 和 IterBasedRunner 在十个位点调用 Hook 对应方法的代码。
class EpochBasedRunner(BaseRunner): def run(self, data_loaders, workflow, max_epochs=None, **kwargs): # 开始运行时调用 self.call_hook('before_run') while self.epoch < self._max_epochs: # 开始 epoch 迭代前调用 self.call_hook('before_train_epoch') for i, data_batch in enumerate(self.train_dataloader): # 开始 iter 迭代前调用 self.call_hook('before_train_iter') # model forward # 经过一次迭代后调用 self.call_hook('after_train_iter') # 经过一个 epoch 迭代后调用 self.call_hook('after_train_epoch') # 开始验证 epoch 迭代前调用 self.call_hook('before_val_epoch') for i, data_batch in enumerate(self.val_dataloader): # 开始 iter 迭代前调用 self.call_hook('before_val_iter') # model forward # 经过一次迭代后调用 self.call_hook('after_val_iter') # 经过一个 epoch 迭代后调用 self.call_hook('after_val_epoch') # 运行完成前调用 self.call_hook('after_run') class IterbasedRunner(BaseRunner): def run(self, data_loaders, workflow, max_iters=None, **kwargs): # 开始运行时调用 self.call_hook('before_run') iter_loaders = [IterLoader(x) for x in data_loaders] # 开始 epoch 迭代前调用 # 注意:IterBaseRunner 只会调用一次 before_epoch 的位点 self.call_hook('before_epoch') while self.iter < self._max_iters: # 开始训练 iter 迭代前调用 self.call_hook('before_train_iter') # model forward # 经过一次训练迭代后调用 self.call_hook('after_train_iter') # 开始验证 iter 迭代前调用 self.call_hook('before_val_iter') # model forward # 经过一次验证迭代后调用 self.call_hook('after_val_iter') # 经过一个 epoch 迭代后调用 self.call_hook('after_epoch') # 运行完成前调用 self.call_hook('after_run')
我们以 CheckpointHook 为例简单介绍一下位点对应的方法。注意:并不是每个位点都需要实现对应的方法。
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py class CheckpointHook(Hook): """保存 checkpoint""" def __init__(self, interval=-1, by_epoch=True, save_optimizer=True, out_dir=None, max_keep_ckpts=-1, save_last=True, sync_buffer=False, file_client_args=None, **kwargs): # 参数初始化 def before_run(self, runner): # 设置 out_dir 和创建 FileClient 对象, # 其中 out_dir 是保存 checkpoint 的目录, # FileClient 对象作为统一接口调用不同的文件后端操作 checkpoint, # 在 CheckpointHook 中主要涉及保存 checkpoint 和删除 checkpoint # 的操作 def after_train_epoch(self, runner): # 处理 by_epoch 为 True 的情况 # 判断是否需要同步 buffer 参数以及 # 调用 _save_checkpoint 保存 checkpoint。 @master_only def _save_checkpoint(self, runner): # 保存 checkpoint 并且删除不想要的 checkpoint, # 不想要的 checkpoint 是指假设我们只想保存最近的 5 个 checkpoint, # 那么我们需要在第 6 个 checkpoint 生成的时候 # 删除第 1 个 checkpoint,可以通过设置 max_keep_ckpts # 实现该功能 def after_train_iter(self, runner): # 处理 by_epoch 为 Fasle 的情况 # 判断是否需要同步 buffer 参数以及 # 调用 _save_checkpoint 保存 checkpoint
Hook 列表
MMCV 提供了很多 Hook,每个 Hook 都有对应的优先级,在 Runner 训练过程中,同一位点,不同 Hook 的调用顺序是按它们的优先级所定义的,优先级越高,越早被调用。如果优先级一样,被调用的顺序和 Hook 注册的顺序一致。
我们将 MMCV 提供的 Hook 分为两类,一类是默认 Hook,另一类是定制 Hook。前者表示当我们调用 Runner 的 register_training_hooks 方法时被默认注册(注意,我们同样需要提供配置),后者表示需要手动注册,这里的手动有两种方式,一种是调用 Runner 的 register_hook 注册,另一种在调用 register_training_hooks 时传入 custom_hooks_config 参数。
!注意:不建议修改 MMCV 默认 Hook 的优先级,除非你有特殊需求。另外,定制 Hook 的优先级默认为 Normal(50)
默认 Hook
定制 Hook
2. Hook 用法介绍
Evalhook
介绍
EvalHook 按照一定的间隔对模型进行验证,在 EvalHook 出现之前,MMCV 对验证的支持是通过设置 workflow,形如 workflow=[('train', 2), ('val', 1)],表示每训练 2 个 epoch(假设使用的 Runner 是 EpochBasedRunner)验证一次。但这种方式灵活度不够,例如不能保存最优的模型。
于是,我们设计了 EvalHook。EvalHook 除了能很好地解决不能保存最优模型的问题,还提供了其他功能,例如支持从指定 epoch 才开始验证模型(因为前面的 epoch 模型效果较差,可以不验证从而节省时间)、支持恢复训练的时候先验证再训练(例如加载模型后想查看 checkpoint 的性能)。
MMCV 除了提供 EvalHook,还提供了 DistEvalHook,其继承自 EvalHook,用于分布式环境下的验证。除了初始化参数有些不同,DistEvalHook 还有一个不同点是重载了 EvalHook 中的 _do_evaluate 方法。EvalHook 中的 _do_evaluate 方法主要执行测试并保存最优模型(如果该模型是当前最优)。而 DistEvalHook 中的 _do_evaluate 作用也是类似的,首先在进行测试前同步 BN 中的 buffer(为了保证各个进程的模型是一致的),然后进行分布式测试(即每个进程单独测试),最后 master 进程收集其他进程的测试结果。
!推荐使用 EvalHook 代替 workflow 中的 val
用法
使用 EvalHook 只需两行代码,一行实例化 EvalHook ,另一个行将实例化的对象注册到 Runner 。
- 最简用法
from mmcv.runner.hooks import EvalHook val_dataloader = ... runner = EpochBasedRunner(...) runner.register_hook(EvalHook(val_dataloader))
- 间隔 5 个 epoch 验证一次
from mmcv.runner.hooks import EvalHook val_dataloader = ... runner = EpochBasedRunner(...) runner.register_hook(EvalHook(val_dataloader, interval=5))
- 恢复训练时先验证再训练
假设从第 5 个 epoch 恢复训练,将 start 设置小于等于 5 即可
from mmcv.runner.hooks import EvalHook val_dataloader = ... runner = EpochBasedRunner(...) runner.register_hook(EvalHook(val_dataloader, start=5))
- 保存最优的模型
通过设置 save_best='acc',EvalHook 会根据 'acc' 来选择最优的模型。
from mmcv.runner.hooks import EvalHook val_dataloader = ... runner = EpochBasedRunner(...) runner.register_hook(EvalHook(val_dataloader, save_best='acc'))
当然,也可以设置为 'auto',那么会自动根据返回的验证结果中的第一个 key 作为选择最优模型的依据。
CheckPointHook
介绍
CheckpointHook 主要是对模型参数进行保存,如果是分布式多卡训练,则仅仅会在 master 进程保存。另外,我们可以通过 max_keep_ckpts 参数设置最多保存多少个权重文件,权重文件数超过 `max_keep_ckpts` 时,前面的权重会被删除。
如果以 epoch 为单位进行保存,则该 Hook 实现 after_train_epoch 方法即可,否则仅需实现 after_train_iter 方法。
用法
- 最简用法
checkpoint_config = { 'interval': 5, # 每训练 5 个 epoch 保存一次 checkpoint } runner = EpochBasedRunner(...) runner.register_checkpoint_hook(checkpoint_config)
- 保存最新的 n 个 checkpoint
checkpoint_config = { 'interval': 5, # 每训练 5 个 epoch 保存一次 checkpoint 'max_keep_ckpts': 5, # 只保留最新的 5 个 checkpoint } runner = EpochBasedRunner(...) runner.register_checkpoint_hook(checkpoint_config)
- 将 checkpoint 保存至指定的路径
checkpoint_config = { 'interval': 5, # 每训练 5 个 epoch 保存一次 checkpoint 'out_dir': '/path/of/expected_directory', # 保存至 /path/of/expected_directory } runner = EpochBasedRunner(...) runner.register_checkpoint_hook(checkpoint_config)
- 同步 buffer
考虑到分布式训练过程,如果有必要(例如分布式训练中没有使用同步 BN,而是普通 BN),则可以通过设置参数 sync_buffer 为 True,在保存权重前,会对模型 buffers(典型的例如 BN 的全局均值和方差参数)进行跨卡同步,让每张卡的 buffers 参数都相同,此时在 master 进程保存权重和 buffer,才是合理的。
checkpoint_config = { 'interval': 5, # 每训练 5 个 epoch 保存一次 checkpoint 'sync_buffer': True, # 同步 buffer } runner = EpochBasedRunner(...) runner.register_checkpoint_hook(checkpoint_config)
文章来源:公众号【OpenMMLab】
2021-12-24 13:59