PyTorch 小课堂!一篇看懂核心网络模块接口(上)

简介: 小伙伴们大家好呀~前面的文章中(PyTorch 小课堂开课啦!带你解析数据处理全流程(一)、PyTorch 小课堂!带你解析数据处理全流程(二)),我们介绍了数据处理模块。而当我们解决了数据处理部分,接下来就需要构建自己的网络结构,从而才能将我们使用数据预处理模块得到的 batch data 送进网络结构当中。接下来,我们就带领大家一起再认识一下 PyTorch 中的神经网络模块,即 torch.nn。

小伙伴们大家好呀~前面的文章中(PyTorch 小课堂开课啦!带你解析数据处理全流程(一)PyTorch 小课堂!带你解析数据处理全流程(二)),我们介绍了数据处理模块。而当我们解决了数据处理部分,接下来就需要构建自己的网络结构,从而才能将我们使用数据预处理模块得到的 batch data 送进网络结构当中。接下来,我们就带领大家一起再认识一下 PyTorch 中的神经网络模块,即 torch.nn。本文主要对nn.Module 进行剖析。感兴趣的小伙伴快点往下看吧!


核心网络模块接口设计



首先需要了解 nn.Module 其实是 PyTorch 体系下所有神经网络模块的基类,我们可以简单梳理一下 torch.nn 中的各个组件,可知他们的关系概览如下图:

640.png

当我们再展开各模块之后,各模块之间的继承关系与层次结构如下图:

640.png


从各模块的继承关系来看,模块的组织和实现有几个常见的特点,可供我们使用 PyTorch 开发时参考借鉴:


1)一般有一个基类来定义接口,可通过继承来处理不同维度的 input,如:


· Conv1d,Conv2d,Conv3d,ConvTransposeNd 继承自 _ConvNd


· MaxPool1d,MaxPool2d,MaxPool3d 继承自 _MaxPoolNd 等


2)每一个类都有一个对应的 nn.functional 函数,类定义了所需要的 arguments 和模块的 parameters,在 forward 函数中将 arguments 和 parameters 传给 nn.functional 的对应函数来实现 forward 功能。比如:


· 所有的非线性激活函数,都是在 forward 中直接调用对应的 nn.functional 函数


· Normalization 层都是调用的如 F.layer_norm, F.group_norm 等函数


3)继承 nn.Module 的模块主要是重载 init、 forward、 和 extra_repr 函数,而含有 parameters 的模块还会实现 reset_parameters 函数来初始化参数。

1. 常用接口



1.1 init 函数


在 nn.Module 的 init 函数中,会首先调用 torch._C._log_api_usage_once("python.nn_module"), 这一行代码是 PyTorch 1.7 的新功能,可用于监测并记录 API 的调用,在此之后,nn.Module 初始化了一系列重要的成员变量。这些变量初始化了在网络模块进行 forward、 backward 和权重加载等行为时候会被调用到的 hooks,同时也定义了 parameters 和 buffers,如下面的代码所示:

def __init__(self) -> None:
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")
    self.training = True  # 控制 training/testing 状态
    self._parameters = OrderedDict()  # 在训练过程中会随着 BP 而更新的参数
    self._buffers = OrderedDict()  # 在训练过程中不会随着 BP 而更新的参数
    self._non_persistent_buffers_set = set()
    self._backward_hooks = OrderedDict()  # Backward 完成后会被调用的 hook
    self._forward_hooks = OrderedDict()  # Forward 完成后会被调用的 hook
    self._forward_pre_hooks = OrderedDict()  # Forward 前会被调用的 hook
    self._state_dict_hooks = OrderedDict()  # 得到 state_dict 以后会被调用的 hook
    self._load_state_dict_pre_hooks = OrderedDict()  # load state_dict 前会被调用的 hook
    self._modules = OrderedDict()  # 子神经网络模块

各个成员变量的功能在后面还会继续提到,这里先在注释中简单解释。由源码的实现可见,继承 nn.Module 的神经网络模块在实现自己的 __init__ 函数时,一定要先调用 super().__init__()。只有这样才能正确地初始化自定义的神经网络模块,否则会缺少上面代码中的成员变量而导致模块被调用时出错。实际上,如果没有提前调用 super().__init__(),在增加模块的 parameter 或者 buffer 的时候,被调用的 __setattr__函数也会检查出父类 nn.Module 没被正确地初始化并报错。(敲重点!在面试的过程中,我们经常发现大家在写自定义神经网络模块的时候容易忽略掉这一点,看了这篇文章以后可要千万记得哦~)


1.2 状态的转换


训练与测试


nn.Module 通过 self.training 来区分训练和测试两种状态,使得模块可以在训练和测试时有不同的 forward 行为(如 Batch Normalization)。nn.Module 通过 self.train() 和 self.eval() 来修改训练和测试状态,其中 self.eval 直接调用了 self.train(False),而self.train() 会修改 self.training 并通过 self.children() 来调整所有子模块的状态。关于 self.children() 会在下文 2.3 常见的属性访问中再进行更多的介绍。

def train(self: T, mode: bool = True) -> T:
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self


Example: freeze 部分模型参数


在目标检测等任务中,常见的 training practice 会将 backbone 中的所有 BN 层保留为 eval 状态,即 freeze BN 层中的 running_mean 和 running_var,并且将浅层的模块 freeze。此时我们需要重载 detector 类的 train 函数,比如 MMDetection 中 ResNet 的 train 函数实现:


def train(self, mode=True):
    super(ResNet, self).train(mode)
    self._freeze_stages()
    if mode and self.norm_eval:
        for m in self.modules():
            # trick: eval have effect on BatchNorm only
            if isinstance(m, _BatchNorm):
                m.eval()


梯度的处理


对于梯度的处理 nn.Module 中有两个相关的函数实现,分别是 requires_grad_ 和 zero_grad 函数,他们都调用了 self.parameters() 来访问所有的参数,并修改参数的 requires_grad 状态或者清理参数的梯度。

def requires_grad_(self: T, requires_grad: bool = True) -> T:
    for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self
def zero_grad(self, set_to_none: bool = False) -> None:
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")
    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()


1.3 参数的转换或转移


nn.Module 实现了如下 8 个常用函数将参数转变成 float16 等类型、转移到 CPU/ GPU 上以及移动模块或/和改变模块的类型。


· cpu():将所有 parameters 和 buffer 转移到 CPU 上

· type():将所有 parameters 和 buffer 转变成另一个类型

· cuda():将所有 parameters 和 buffer 转移到 GPU 上

· float():将所有浮点类型的 parameters 和 buffer 转变成 float32 类型

· double():将所有浮点类型的 parameters 和 buffer 转变成 double 类型

· half():将所有浮点类型的 parameters 和 buffer 转变成 float16 类型

· bfloat16():将所有浮点类型的 parameters 和 buffer 转变成 bfloat16 类型

· to():移动模块或/和改变模块的类型


这些函数的功能最终都是通过 self._apply(function) 来实现的, function 一般是 lambda 表达式或其他自定义函数。因此,我们其实也可以通过 self._apply(function) 来实现一些特殊的转换。self._apply() 函数实际上做了如下 3 件事情,最终将 function 完整地应用于整个网络模块。


1)通过 self.children() 进行递归的调用。


2)对 self._parameters 中的参数及其 gradient 通过 function 进行处理。


3)对 self._buffers 中的 buffer 逐个通过 function 来进行处理。

def _apply(self, fn):
    # 对子模块进行递归调用
    for module in self.children():
        module._apply(fn)
    # 为了 BC-breaking 而新增了一个 tensor 类型判断
    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False
    # 处理参数及其gradint
    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)
            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)
    # 处理 buffers
    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)
    return self


1.4 Apply 函数


nn.Module 还实现了一个 apply 函数,与 _apply() 函数不同的是,apply 函数只是简单地递归调用了 self.children() 去处理自己以及子模块,如下面的代码所示。

def apply(self: T, fn: Callable[['Module'], None]) -> T:
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

apply 函数和 _apply 函数的区别在于,_apply() 是专门针对 parameter 和 buffer 而实现的一个“仅供内部使用”的接口,但是 apply 函数是“公有”接口 (Python 对类的“公有”和“私有”区别并不是很严格,一般通过单前导下划线来区分)。apply 实际上可以通过修改 fn 来实现 _apply 能实现的功能,同时还可以实现其他功能,如下面的重新初始化参数的例子。


Example: 参数重新初始化


可以自定义一个 init_weights 函数,通过 net.apply(init_weights) 来初始化模型权重。

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)


文章来源:【OpenMMLab

 2022-04-20 18:20


目录
相关文章
|
2天前
|
安全
AC/DC电源模块在通信与网络设备中的应用的研究
AC/DC电源模块在通信与网络设备中的应用的研究
AC/DC电源模块在通信与网络设备中的应用的研究
|
2天前
BOSHIDA AC/DC电源模块在通信与网络设备中的应用研究
BOSHIDA AC/DC电源模块在通信与网络设备中的应用研究
BOSHIDA AC/DC电源模块在通信与网络设备中的应用研究
|
2天前
|
存储 算法 网络协议
【探索Linux】P.26(网络编程套接字基本概念—— socket编程接口 | socket编程接口相关函数详细介绍 )
【探索Linux】P.26(网络编程套接字基本概念—— socket编程接口 | socket编程接口相关函数详细介绍 )
13 0
|
2天前
|
存储 安全 光互联
|
2天前
|
网络安全 数据安全/隐私保护 Linux
【题目】2023年全国职业院校技能大赛 GZ073 网络系统管理赛项赛题第3套B模块-2
【题目】2023年全国职业院校技能大赛 GZ073 网络系统管理赛项赛题第3套B模块
|
2天前
|
网络虚拟化 网络协议 Windows
【题目】2023年全国职业院校技能大赛 GZ073 网络系统管理赛项赛题第3套B模块-1
【题目】2023年全国职业院校技能大赛 GZ073 网络系统管理赛项赛题第3套B模块
【题目】2023年全国职业院校技能大赛 GZ073 网络系统管理赛项赛题第3套B模块-1
|
2天前
|
数据安全/隐私保护 网络协议 网络虚拟化
【题目】2023年全国职业院校技能大赛 GZ073 网络系统管理赛项赛题第3套A模块
【题目】2023年全国职业院校技能大赛 GZ073 网络系统管理赛项赛题第3套A模块
【题目】2023年全国职业院校技能大赛 GZ073 网络系统管理赛项赛题第3套A模块
|
2天前
|
网络安全 数据安全/隐私保护 Linux
2023年全国职业院校技能大赛=GZ073 网络系统管理赛项赛题第2套B模块-2
2023年全国职业院校技能大赛=GZ073 网络系统管理赛项赛题第2套B模块
2023年全国职业院校技能大赛=GZ073 网络系统管理赛项赛题第2套B模块-2
|
2天前
|
网络虚拟化 Windows 网络协议
2023年全国职业院校技能大赛=GZ073 网络系统管理赛项赛题第2套B模块-1
2023年全国职业院校技能大赛=GZ073 网络系统管理赛项赛题第2套B模块
2023年全国职业院校技能大赛=GZ073 网络系统管理赛项赛题第2套B模块-1
|
2天前
|
网络虚拟化 SDN 数据安全/隐私保护
2023年全国职业院校技能大赛GZ073 网络系统管理赛项赛题第2套A模块
2023年全国职业院校技能大赛GZ073 网络系统管理赛项赛题第2套A模块
2023年全国职业院校技能大赛GZ073 网络系统管理赛项赛题第2套A模块

热门文章

最新文章