过节福利 | MMCV Hook 超全使用方法(上)

简介: 在训练过程中,通常有十个关键位点,如下图所示,从训练开始到结束,所有关键位点已用红色标出,共有 10 个。我们可以在这十个位点插入各种逻辑,例如加载模型权重、保存模型权重。而我们将同一类型的逻辑组织成一个 Hook。因此,MMCV 中 Hook 的作用就是训练和验证模型时,在不改变其他代码的前提下,灵活地在不同位点插入定制化的逻辑。

哈喽宝子们!

平安夜快乐!!!

640.png

今天给大家带来了超级福利

超全超细的 MMCV Hook 方法

快乐过节前再学习下吧~



1. Hook 是什么



Hook 介绍

维基百科:钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)。                        


在训练过程中,通常有十个关键位点,如下图所示,从训练开始到结束,所有关键位点已用红色标出,共有 10 个。我们可以在这十个位点插入各种逻辑,例如加载模型权重、保存模型权重。而我们将同一类型的逻辑组织成一个 Hook。因此,MMCV 中 Hook 的作用就是训练和验证模型时,在不改变其他代码的前提下,灵活地在不同位点插入定制化的逻辑。

640.png



而控制整个训练过程的抽象在 MMCV 中被设计为 Runner,它的主要行为就是执行上图蓝色的工作流,MMCV 提供了两种类型的 Runner,一种是以 epoch 为单位迭代的 EpochBasedRunner,另一种是以 iteration 为单位迭代的 IterBasedRunner。下面给出 EpochBasedRunner 和 IterBasedRunner 在十个位点调用 Hook 对应方法的代码。

class EpochBasedRunner(BaseRunner):
    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        # 开始运行时调用
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            # 开始 epoch 迭代前调用
            self.call_hook('before_train_epoch')
            for i, data_batch in enumerate(self.train_dataloader):
                # 开始 iter 迭代前调用
                self.call_hook('before_train_iter')
                # model forward
                # 经过一次迭代后调用
                self.call_hook('after_train_iter')
            # 经过一个 epoch 迭代后调用
            self.call_hook('after_train_epoch')
            # 开始验证 epoch 迭代前调用
            self.call_hook('before_val_epoch')
            for i, data_batch in enumerate(self.val_dataloader):
                # 开始 iter 迭代前调用
                self.call_hook('before_val_iter')
                # model forward
                # 经过一次迭代后调用
                self.call_hook('after_val_iter')
            # 经过一个 epoch 迭代后调用
            self.call_hook('after_val_epoch')
        # 运行完成前调用
        self.call_hook('after_run')
class IterbasedRunner(BaseRunner):
    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        # 开始运行时调用
        self.call_hook('before_run')
        iter_loaders = [IterLoader(x) for x in data_loaders]
        # 开始 epoch 迭代前调用
        # 注意:IterBaseRunner 只会调用一次 before_epoch 的位点
        self.call_hook('before_epoch')
        while self.iter < self._max_iters:
            # 开始训练 iter 迭代前调用
            self.call_hook('before_train_iter')
            # model forward
            # 经过一次训练迭代后调用
            self.call_hook('after_train_iter')
            # 开始验证 iter 迭代前调用
            self.call_hook('before_val_iter')
            # model forward
            # 经过一次验证迭代后调用
            self.call_hook('after_val_iter')
        # 经过一个 epoch 迭代后调用
        self.call_hook('after_epoch')
        # 运行完成前调用
        self.call_hook('after_run')

我们以 CheckpointHook 为例简单介绍一下位点对应的方法。注意:并不是每个位点都需要实现对应的方法。

# https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py
class CheckpointHook(Hook):
    """保存 checkpoint"""
    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 save_last=True,
                 sync_buffer=False,
                 file_client_args=None,
                 **kwargs):
        # 参数初始化
    def before_run(self, runner):
        # 设置 out_dir 和创建 FileClient 对象,
        # 其中 out_dir 是保存 checkpoint 的目录,
        # FileClient 对象作为统一接口调用不同的文件后端操作 checkpoint,
        # 在 CheckpointHook 中主要涉及保存 checkpoint 和删除 checkpoint
        # 的操作
    def after_train_epoch(self, runner):
        # 处理 by_epoch 为 True 的情况
        # 判断是否需要同步 buffer 参数以及
        # 调用 _save_checkpoint 保存 checkpoint。
    @master_only
    def _save_checkpoint(self, runner):
        # 保存 checkpoint 并且删除不想要的 checkpoint,
        # 不想要的 checkpoint 是指假设我们只想保存最近的 5 个 checkpoint,
        # 那么我们需要在第 6 个 checkpoint 生成的时候
        # 删除第 1 个 checkpoint,可以通过设置 max_keep_ckpts
        # 实现该功能
    def after_train_iter(self, runner):
        # 处理 by_epoch 为 Fasle 的情况
        # 判断是否需要同步 buffer 参数以及
        # 调用 _save_checkpoint 保存 checkpoint


Hook 列表


MMCV 提供了很多 Hook,每个 Hook 都有对应的优先级,在 Runner 训练过程中,同一位点,不同 Hook 的调用顺序是按它们的优先级所定义的,优先级越高,越早被调用。如果优先级一样,被调用的顺序和 Hook 注册的顺序一致。


我们将 MMCV 提供的 Hook 分为两类,一类是默认 Hook,另一类是定制 Hook。前者表示当我们调用 Runner 的 register_training_hooks 方法时被默认注册(注意,我们同样需要提供配置),后者表示需要手动注册,这里的手动有两种方式,一种是调用 Runner 的 register_hook 注册,另一种在调用 register_training_hooks 时传入 custom_hooks_config 参数。


注意:不建议修改 MMCV 默认 Hook 的优先级,除非你有特殊需求。另外,定制 Hook 的优先级默认为 Normal(50)


默认 Hook

640.png

定制 Hook

640.png


2. Hook 用法介绍



Evalhook


介绍


EvalHook 按照一定的间隔对模型进行验证,在 EvalHook 出现之前,MMCV 对验证的支持是通过设置 workflow,形如 workflow=[('train', 2), ('val', 1)],表示每训练 2 个 epoch(假设使用的 Runner 是 EpochBasedRunner)验证一次。但这种方式灵活度不够,例如不能保存最优的模型。


于是,我们设计了 EvalHook。EvalHook 除了能很好地解决不能保存最优模型的问题,还提供了其他功能,例如支持从指定 epoch 才开始验证模型(因为前面的 epoch 模型效果较差,可以不验证从而节省时间)、支持恢复训练的时候先验证再训练(例如加载模型后想查看 checkpoint 的性能)。


MMCV 除了提供 EvalHook,还提供了 DistEvalHook,其继承自 EvalHook,用于分布式环境下的验证。除了初始化参数有些不同,DistEvalHook 还有一个不同点是重载了 EvalHook 中的 _do_evaluate 方法。EvalHook 中的 _do_evaluate 方法主要执行测试并保存最优模型(如果该模型是当前最优)。而 DistEvalHook 中的 _do_evaluate 作用也是类似的,首先在进行测试前同步 BN 中的 buffer(为了保证各个进程的模型是一致的),然后进行分布式测试(即每个进程单独测试),最后 master 进程收集其他进程的测试结果。


推荐使用 EvalHook 代替 workflow 中的 val


用法


使用 EvalHook 只需两行代码,一行实例化 EvalHook ,另一个行将实例化的对象注册到 Runner 。


- 最简用法

from mmcv.runner.hooks import EvalHook
val_dataloader = ...
runner = EpochBasedRunner(...)
runner.register_hook(EvalHook(val_dataloader))

- 间隔 5 个 epoch 验证一次

from mmcv.runner.hooks import EvalHook
val_dataloader = ...
runner = EpochBasedRunner(...)
runner.register_hook(EvalHook(val_dataloader, interval=5))

- 恢复训练时先验证再训练


假设从第 5 个 epoch 恢复训练,将 start 设置小于等于 5 即可

from mmcv.runner.hooks import EvalHook
val_dataloader = ...
runner = EpochBasedRunner(...)
runner.register_hook(EvalHook(val_dataloader, start=5))

- 保存最优的模型


通过设置 save_best='acc',EvalHook 会根据 'acc' 来选择最优的模型。

from mmcv.runner.hooks import EvalHook
val_dataloader = ...
runner = EpochBasedRunner(...)
runner.register_hook(EvalHook(val_dataloader, save_best='acc'))

当然,也可以设置为 'auto',那么会自动根据返回的验证结果中的第一个 key 作为选择最优模型的依据。


CheckPointHook


介绍


CheckpointHook 主要是对模型参数进行保存,如果是分布式多卡训练,则仅仅会在 master 进程保存。另外,我们可以通过 max_keep_ckpts 参数设置最多保存多少个权重文件,权重文件数超过 `max_keep_ckpts` 时,前面的权重会被删除。


如果以 epoch 为单位进行保存,则该 Hook 实现 after_train_epoch 方法即可,否则仅需实现 after_train_iter 方法。


用法


- 最简用法

checkpoint_config = {
    'interval': 5,  # 每训练 5 个 epoch 保存一次 checkpoint
}
runner = EpochBasedRunner(...)
runner.register_checkpoint_hook(checkpoint_config)

- 保存最新的 n 个 checkpoint

checkpoint_config = {
    'interval': 5,  # 每训练 5 个 epoch 保存一次 checkpoint
    'max_keep_ckpts': 5,  # 只保留最新的 5 个 checkpoint
}
runner = EpochBasedRunner(...)
runner.register_checkpoint_hook(checkpoint_config)

- 将 checkpoint 保存至指定的路径

checkpoint_config = {
    'interval': 5,  # 每训练 5 个 epoch 保存一次 checkpoint
    'out_dir': '/path/of/expected_directory',  # 保存至 /path/of/expected_directory
}
runner = EpochBasedRunner(...)
runner.register_checkpoint_hook(checkpoint_config)

- 同步 buffer


考虑到分布式训练过程,如果有必要(例如分布式训练中没有使用同步 BN,而是普通 BN),则可以通过设置参数 sync_buffer 为 True,在保存权重前,会对模型 buffers(典型的例如 BN 的全局均值和方差参数)进行跨卡同步,让每张卡的 buffers 参数都相同,此时在 master 进程保存权重和 buffer,才是合理的。


checkpoint_config = {
    'interval': 5,  # 每训练 5 个 epoch 保存一次 checkpoint
    'sync_buffer': True,  # 同步 buffer
}
runner = EpochBasedRunner(...)
runner.register_checkpoint_hook(checkpoint_config)


文章来源:公众号【OpenMMLab】

 2021-12-24 13:59

目录
相关文章
跟我从0学Python——函数和模块
第三篇:函数和模块 —— 代码的模块化与重用
|
搜索推荐 数据挖掘 PyTorch
Py之albumentations:albumentations库函数的简介、安装、使用方法之详细攻略续篇
Py之albumentations:albumentations库函数的简介、安装、使用方法之详细攻略续篇
Py之albumentations:albumentations库函数的简介、安装、使用方法之详细攻略续篇
|
3月前
|
JSON 算法 数据安全/隐私保护
brida和frida练习hook逆向技术【中】
本文介绍了如何在未加壳、未混淆的 APK 中定位并破解加密算法,并使用 Burp 插件 autoDecoder 进行自动化加解密及口令爆破。文中详细描述了从反编译到配置插件的全过程,并提供了关键要素如 AES 算法、SECRET_KEY 和 Base64 编码的具体应用。此外,还展示了如何调整并发数以提高爆破成功率。
49 0
brida和frida练习hook逆向技术【中】
|
3月前
|
Java Android开发 数据安全/隐私保护
brida和frida练习hook逆向技术【上】
使用zangcc测试包.apk,练习 Brida 和 Frida 的 Hook 逆向技术。
29 0
brida和frida练习hook逆向技术【上】
|
7月前
|
JavaScript Unix API
Nodejs 第十四章(process)
Nodejs 第十四章(process)
63 0
|
机器学习/深度学习 Web App开发 数据可视化
过节福利 | MMCV Hook 超全使用方法(下)
在训练过程中,通常有十个关键位点,如下图所示,从训练开始到结束,所有关键位点已用红色标出,共有 10 个。我们可以在这十个位点插入各种逻辑,例如加载模型权重、保存模型权重。而我们将同一类型的逻辑组织成一个 Hook。因此,MMCV 中 Hook 的作用就是训练和验证模型时,在不改变其他代码的前提下,灵活地在不同位点插入定制化的逻辑。
1935 0
过节福利 | MMCV Hook 超全使用方法(下)
|
7月前
|
Shell Android开发 数据安全/隐私保护
安卓逆向 -- Frida环境搭建(HOOK实例)
安卓逆向 -- Frida环境搭建(HOOK实例)
166 0
|
7月前
|
缓存 前端开发 JavaScript
【源码共读】Vue2源码 shared 模块中的36个实用工具函数分析
【源码共读】Vue2源码 shared 模块中的36个实用工具函数分析
218 0
|
7月前
|
存储 JSON 缓存
【源码共读】axios的46个工具函数
【源码共读】axios的46个工具函数
152 0
|
JavaScript 前端开发 Linux
Hook神器—Frida安装
Hook神器—Frida安装