Pytorch构建网络模型时super(__class__, self).__init__()的作用

简介: Pytorch构建网络模型时super(__class__, self).__init__()的作用

0 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容不乏不准确的地方,希望批评指正,共同进步。

在使用Pytorch框架定义神经元网络模型的类的时候,首先都会在模型的类__init__()方法下加一行super(__class__, self).__init__()。例如:

class ClassName(torch.nn.Module):
    def __init__(self):
        super(ClassName, self).__init__()

对于所有的教程,这行代码几乎成为一个“潜规则”,虽然对于其作用并不太理解,久而久之也就默认了必须要加上这一行。


因此单独写一篇文章说明其作用,也深入自己的理解。


1 super()方法的说明

所有的Python初级教程,在介绍面向对象编程——类的时候都会提及super()方法,说明其作用是用于类的继承,但缺乏更深入的说明&理解。为了深入理解super()方法的运作原理,首先看下以下代码:


class A():
 
    def __init__(self):
        self.ten = 10
 
    def hello(self):
        return 'hello world'
 
 
class B(A):
 
    def __init__(self,x):
        # super(B, self).__init__()
        self.x = x
 
    def multi_ten(self):
        return self.x * self.ten
 
b = B(8)
 
print(b.hello())
print(b.multi_ten())
-------------------------------------------------
C:\Users\Lenovo\Desktop\DL\Pytest\Scripts\python.exe C:/Users/Lenovo/Desktop/DL/Pytest/test_main.py
hello world
Traceback (most recent call last):
  File "C:\Users\Lenovo\Desktop\DL\Pytest\test_main.py", line 23, in <module>
    print(b.multi_ten())
  File "C:\Users\Lenovo\Desktop\DL\Pytest\test_main.py", line 18, in multi_ten
    return self.x * self.ten
AttributeError: 'B' object has no attribute 'ten'
 
Process finished with exit code 1


如果去掉super(B, self).__init__()可以发现hello()方法还是可以运行的,也就是说:在类的继承时,super()方法并不是必须的


那什么时候必须用super()方法呢?在涉及自动运行的魔术方法时。例如上面的multi_ten()方法,其想要引用父类A方法__init__()中的self.ten,这时就必须在B类中使用super()方法,注明B类要继承A类中的__init__()方法。否则就会像上段代码一样报错并提示:B类中没有ten这个属性!(没有继承到)


魔术方法:Python内部定义,在类的实例化时自动运行的方法。这些方法的命名规则为 __xxxx__(),例如:__init__()。

另外,还有一个细节是super()方法中,括号内的内容是可以不用写的,这点可以用F4查看super()方法的定义,里面有段注释:


"super() -> same as super(__class__, <first argument>)"


__class__为当前的类名,<first argument>为self。


我个人使用的Python interpreter是Python 3.9,或许在更早版本的Python中,super()方法中是必须要填参数的,所以早期的教程都会写成super(__class__, self).__init__(),但是以后我们都不需要了。

2 从torch.nn.Module继承了什么?

再从一段最简单的线性神经元网络模型代码入手:

import torch
 
a = torch.tensor([1,2,3,4,5], dtype = torch.float32)
 
class test(torch.nn.Module):
    def __init__(self):
        # super().__init__()
        self.lin = torch.nn.Linear(5,2)
 
    def forward(self,x):
        return self.lin(x)
 
TEST = test()
 
print(TEST(a))


如果这里仍去掉super()方法,则会报错:


AttributeError: cannot assign module before Module.__init__() call


不出所料,是父类torch.nn.Module中的魔术方法__init__()没有继承(调用)到。


那它究竟定义了什么?

也可以通过F4,找到torch.nn.Module.__init__()的源码:


class Module:
 
...
 
    def __init__(self) -> None:
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")
 
        """
        Calls super().__setattr__('a', a) instead of the typical self.a = a
        to avoid Module.__setattr__ overhead. Module's __setattr__ has special
        handling for parameters, submodules, and buffers but simply calls into
        super().__setattr__ for all other attributes.
        """
        super().__setattr__('training', True)
        super().__setattr__('_parameters', OrderedDict())
        super().__setattr__('_buffers', OrderedDict())
        super().__setattr__('_non_persistent_buffers_set', set())
        super().__setattr__('_backward_hooks', OrderedDict())
        super().__setattr__('_is_full_backward_hook', None)
        super().__setattr__('_forward_hooks', OrderedDict())
        super().__setattr__('_forward_pre_hooks', OrderedDict())
        super().__setattr__('_state_dict_hooks', OrderedDict())
        super().__setattr__('_load_state_dict_pre_hooks', OrderedDict())
        super().__setattr__('_load_state_dict_post_hooks', OrderedDict())
        super().__setattr__('_modules', OrderedDict())
 
    forward: Callable[..., Any] = _forward_unimplemented

这里已经说明,torch.nn.Module.__init__()的作用是Initializes internal Module state(初始化内部模型状态)。具体地,就是初始化training,parameters..._modules这些在Pytorch中内部使用的属性。


其中,super().__setattr__()为调用torch.nn.Module的父类Object的__setattr__()方法,其作用就类似于“赋值”,例如:super().__setattr__('_parameters', OrderedDict()) 的作用就类似 self._parameters = OrderedDict()。那为什么不直接用赋值?这里也解释了: Calls super().__setattr__('a', a) instead of the typical self.a = a to avoid Module.__setattr__ overhead. Module's __setattr__ has special handling for parameters, submodules, and buffers but simply calls into super().__setattr__ for all other attributes. 可以理解为__setattr__相比于简单赋值有着更多的作用。


所以,在Pytorch框架下,所有的神经元网络模型子类,都必须要继承这些内部属性的初始化过程。


相关文章
|
3月前
|
前端开发 JavaScript 开发者
JavaScript:构建动态网络的引擎
JavaScript:构建动态网络的引擎
|
3月前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
199 1
|
3月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
163 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
2月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
2月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
3月前
|
机器学习/深度学习 传感器 算法
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
209 2
|
3月前
|
机器学习/深度学习 并行计算 算法
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
|
3月前
|
人工智能 监控 数据可视化
如何破解AI推理延迟难题:构建敏捷多云算力网络
本文探讨了AI企业在突破算力瓶颈后,如何构建高效、稳定的网络架构以支撑AI产品化落地。文章分析了典型AI IT架构的四个层次——流量接入层、调度决策层、推理服务层和训练算力层,并深入解析了AI架构对网络提出的三大核心挑战:跨云互联、逻辑隔离与业务识别、网络可视化与QoS控制。最终提出了一站式网络解决方案,助力AI企业实现多云调度、业务融合承载与精细化流量管理,推动AI服务高效、稳定交付。
|
2月前
|
机器学习/深度学习 分布式计算 Java
Java与图神经网络:构建企业级知识图谱与智能推理系统
图神经网络(GNN)作为处理非欧几里得数据的前沿技术,正成为企业知识管理和智能推理的核心引擎。本文深入探讨如何在Java生态中构建基于GNN的知识图谱系统,涵盖从图数据建模、GNN模型集成、分布式图计算到实时推理的全流程。通过具体的代码实现和架构设计,展示如何将先进的图神经网络技术融入传统Java企业应用,为构建下一代智能决策系统提供完整解决方案。
320 0
|
3月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
147 0

推荐镜像

更多