我们在搭建网络时,通常要继承nn.Module
这个模块,并且实现其forward
方法,那么这个基类中到底有何属性呢?
def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forward_hoods = OrderedDict() self.training = True
这个基类有以下属性:
_parameters
:有序字典,保存用户直接设置的Parameter。例如,对于self.param1 = nn.Parameter(torch.randn(3, 3)),构造函数会在字典中加入一个key为param1、value为对应Parameter的item。self.submodule = nn.Linear(3, 4)中的Parameter不会被保存在该字典中。_modules
:子module。例如,通过self.submodel = nn.Linear(3, 4)指定的子module会被保存于此。_buffers
:缓存。例如,BatchNorm使用动量机制,每次前向传播时都需要用到上一次前向传播的结果。_backward_hooks
:钩子技术,用来提取中间变量。_forward_hoods
:钩子技术,用来提取中间变量。
+training
:BatchNorm层与Dropout层在训练阶段和测试阶段采取的策略不同,通过training属性决定前向传播策略。