0 前言
按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容不乏不准确的地方,希望批评指正,共同进步。
在使用Pytorch框架定义神经元网络模型的类的时候,首先都会在模型的类__init__()方法下加一行super(__class__, self).__init__()。例如:
class ClassName(torch.nn.Module): def __init__(self): super(ClassName, self).__init__()
对于所有的教程,这行代码几乎成为一个“潜规则”,虽然对于其作用并不太理解,久而久之也就默认了必须要加上这一行。
因此单独写一篇文章说明其作用,也深入自己的理解。
1 super()方法的说明
所有的Python初级教程,在介绍面向对象编程——类的时候都会提及super()方法,说明其作用是用于类的继承,但缺乏更深入的说明&理解。为了深入理解super()方法的运作原理,首先看下以下代码:
class A(): def __init__(self): self.ten = 10 def hello(self): return 'hello world' class B(A): def __init__(self,x): # super(B, self).__init__() self.x = x def multi_ten(self): return self.x * self.ten b = B(8) print(b.hello()) print(b.multi_ten()) ------------------------------------------------- C:\Users\Lenovo\Desktop\DL\Pytest\Scripts\python.exe C:/Users/Lenovo/Desktop/DL/Pytest/test_main.py hello world Traceback (most recent call last): File "C:\Users\Lenovo\Desktop\DL\Pytest\test_main.py", line 23, in <module> print(b.multi_ten()) File "C:\Users\Lenovo\Desktop\DL\Pytest\test_main.py", line 18, in multi_ten return self.x * self.ten AttributeError: 'B' object has no attribute 'ten' Process finished with exit code 1
如果去掉super(B, self).__init__()可以发现hello()方法还是可以运行的,也就是说:在类的继承时,super()方法并不是必须的。
那什么时候必须用super()方法呢?在涉及自动运行的魔术方法时。例如上面的multi_ten()方法,其想要引用父类A方法__init__()中的self.ten,这时就必须在B类中使用super()方法,注明B类要继承A类中的__init__()方法。否则就会像上段代码一样报错并提示:B类中没有ten这个属性!(没有继承到)
魔术方法:Python内部定义,在类的实例化时自动运行的方法。这些方法的命名规则为 __xxxx__(),例如:__init__()。
另外,还有一个细节是super()方法中,括号内的内容是可以不用写的,这点可以用F4查看super()方法的定义,里面有段注释:
"super() -> same as super(__class__, <first argument>)"
__class__为当前的类名,<first argument>为self。
我个人使用的Python interpreter是Python 3.9,或许在更早版本的Python中,super()方法中是必须要填参数的,所以早期的教程都会写成super(__class__, self).__init__(),但是以后我们都不需要了。
2 从torch.nn.Module继承了什么?
再从一段最简单的线性神经元网络模型代码入手:
import torch a = torch.tensor([1,2,3,4,5], dtype = torch.float32) class test(torch.nn.Module): def __init__(self): # super().__init__() self.lin = torch.nn.Linear(5,2) def forward(self,x): return self.lin(x) TEST = test() print(TEST(a))
如果这里仍去掉super()方法,则会报错:
AttributeError: cannot assign module before Module.__init__() call
不出所料,是父类torch.nn.Module中的魔术方法__init__()没有继承(调用)到。
那它究竟定义了什么?
也可以通过F4,找到torch.nn.Module.__init__()的源码:
class Module: ... def __init__(self) -> None: """ Initializes internal Module state, shared by both nn.Module and ScriptModule. """ torch._C._log_api_usage_once("python.nn_module") """ Calls super().__setattr__('a', a) instead of the typical self.a = a to avoid Module.__setattr__ overhead. Module's __setattr__ has special handling for parameters, submodules, and buffers but simply calls into super().__setattr__ for all other attributes. """ super().__setattr__('training', True) super().__setattr__('_parameters', OrderedDict()) super().__setattr__('_buffers', OrderedDict()) super().__setattr__('_non_persistent_buffers_set', set()) super().__setattr__('_backward_hooks', OrderedDict()) super().__setattr__('_is_full_backward_hook', None) super().__setattr__('_forward_hooks', OrderedDict()) super().__setattr__('_forward_pre_hooks', OrderedDict()) super().__setattr__('_state_dict_hooks', OrderedDict()) super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) super().__setattr__('_modules', OrderedDict()) forward: Callable[..., Any] = _forward_unimplemented
这里已经说明,torch.nn.Module.__init__()的作用是Initializes internal Module state(初始化内部模型状态)。具体地,就是初始化training,parameters..._modules这些在Pytorch中内部使用的属性。
其中,super().__setattr__()为调用torch.nn.Module的父类Object的__setattr__()方法,其作用就类似于“赋值”,例如:super().__setattr__('_parameters', OrderedDict()) 的作用就类似 self._parameters = OrderedDict()。那为什么不直接用赋值?这里也解释了: Calls super().__setattr__('a', a) instead of the typical self.a = a to avoid Module.__setattr__ overhead. Module's __setattr__ has special handling for parameters, submodules, and buffers but simply calls into super().__setattr__ for all other attributes. 可以理解为__setattr__相比于简单赋值有着更多的作用。
所以,在Pytorch框架下,所有的神经元网络模型子类,都必须要继承这些内部属性的初始化过程。