Pytorch构建网络模型时super(__class__, self).__init__()的作用

简介: Pytorch构建网络模型时super(__class__, self).__init__()的作用

0 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容不乏不准确的地方,希望批评指正,共同进步。

在使用Pytorch框架定义神经元网络模型的类的时候,首先都会在模型的类__init__()方法下加一行super(__class__, self).__init__()。例如:

class ClassName(torch.nn.Module):
    def __init__(self):
        super(ClassName, self).__init__()

对于所有的教程,这行代码几乎成为一个“潜规则”,虽然对于其作用并不太理解,久而久之也就默认了必须要加上这一行。


因此单独写一篇文章说明其作用,也深入自己的理解。


1 super()方法的说明

所有的Python初级教程,在介绍面向对象编程——类的时候都会提及super()方法,说明其作用是用于类的继承,但缺乏更深入的说明&理解。为了深入理解super()方法的运作原理,首先看下以下代码:


class A():
 
    def __init__(self):
        self.ten = 10
 
    def hello(self):
        return 'hello world'
 
 
class B(A):
 
    def __init__(self,x):
        # super(B, self).__init__()
        self.x = x
 
    def multi_ten(self):
        return self.x * self.ten
 
b = B(8)
 
print(b.hello())
print(b.multi_ten())
-------------------------------------------------
C:\Users\Lenovo\Desktop\DL\Pytest\Scripts\python.exe C:/Users/Lenovo/Desktop/DL/Pytest/test_main.py
hello world
Traceback (most recent call last):
  File "C:\Users\Lenovo\Desktop\DL\Pytest\test_main.py", line 23, in <module>
    print(b.multi_ten())
  File "C:\Users\Lenovo\Desktop\DL\Pytest\test_main.py", line 18, in multi_ten
    return self.x * self.ten
AttributeError: 'B' object has no attribute 'ten'
 
Process finished with exit code 1


如果去掉super(B, self).__init__()可以发现hello()方法还是可以运行的,也就是说:在类的继承时,super()方法并不是必须的


那什么时候必须用super()方法呢?在涉及自动运行的魔术方法时。例如上面的multi_ten()方法,其想要引用父类A方法__init__()中的self.ten,这时就必须在B类中使用super()方法,注明B类要继承A类中的__init__()方法。否则就会像上段代码一样报错并提示:B类中没有ten这个属性!(没有继承到)


魔术方法:Python内部定义,在类的实例化时自动运行的方法。这些方法的命名规则为 __xxxx__(),例如:__init__()。

另外,还有一个细节是super()方法中,括号内的内容是可以不用写的,这点可以用F4查看super()方法的定义,里面有段注释:


"super() -> same as super(__class__, <first argument>)"


__class__为当前的类名,<first argument>为self。


我个人使用的Python interpreter是Python 3.9,或许在更早版本的Python中,super()方法中是必须要填参数的,所以早期的教程都会写成super(__class__, self).__init__(),但是以后我们都不需要了。

2 从torch.nn.Module继承了什么?

再从一段最简单的线性神经元网络模型代码入手:

import torch
 
a = torch.tensor([1,2,3,4,5], dtype = torch.float32)
 
class test(torch.nn.Module):
    def __init__(self):
        # super().__init__()
        self.lin = torch.nn.Linear(5,2)
 
    def forward(self,x):
        return self.lin(x)
 
TEST = test()
 
print(TEST(a))


如果这里仍去掉super()方法,则会报错:


AttributeError: cannot assign module before Module.__init__() call


不出所料,是父类torch.nn.Module中的魔术方法__init__()没有继承(调用)到。


那它究竟定义了什么?

也可以通过F4,找到torch.nn.Module.__init__()的源码:


class Module:
 
...
 
    def __init__(self) -> None:
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")
 
        """
        Calls super().__setattr__('a', a) instead of the typical self.a = a
        to avoid Module.__setattr__ overhead. Module's __setattr__ has special
        handling for parameters, submodules, and buffers but simply calls into
        super().__setattr__ for all other attributes.
        """
        super().__setattr__('training', True)
        super().__setattr__('_parameters', OrderedDict())
        super().__setattr__('_buffers', OrderedDict())
        super().__setattr__('_non_persistent_buffers_set', set())
        super().__setattr__('_backward_hooks', OrderedDict())
        super().__setattr__('_is_full_backward_hook', None)
        super().__setattr__('_forward_hooks', OrderedDict())
        super().__setattr__('_forward_pre_hooks', OrderedDict())
        super().__setattr__('_state_dict_hooks', OrderedDict())
        super().__setattr__('_load_state_dict_pre_hooks', OrderedDict())
        super().__setattr__('_load_state_dict_post_hooks', OrderedDict())
        super().__setattr__('_modules', OrderedDict())
 
    forward: Callable[..., Any] = _forward_unimplemented

这里已经说明,torch.nn.Module.__init__()的作用是Initializes internal Module state(初始化内部模型状态)。具体地,就是初始化training,parameters..._modules这些在Pytorch中内部使用的属性。


其中,super().__setattr__()为调用torch.nn.Module的父类Object的__setattr__()方法,其作用就类似于“赋值”,例如:super().__setattr__('_parameters', OrderedDict()) 的作用就类似 self._parameters = OrderedDict()。那为什么不直接用赋值?这里也解释了: Calls super().__setattr__('a', a) instead of the typical self.a = a to avoid Module.__setattr__ overhead. Module's __setattr__ has special handling for parameters, submodules, and buffers but simply calls into super().__setattr__ for all other attributes. 可以理解为__setattr__相比于简单赋值有着更多的作用。


所以,在Pytorch框架下,所有的神经元网络模型子类,都必须要继承这些内部属性的初始化过程。


相关文章
|
12天前
|
存储 监控 安全
单位网络监控软件:Java 技术驱动的高效网络监管体系构建
在数字化办公时代,构建基于Java技术的单位网络监控软件至关重要。该软件能精准监管单位网络活动,保障信息安全,提升工作效率。通过网络流量监测、访问控制及连接状态监控等模块,实现高效网络监管,确保网络稳定、安全、高效运行。
42 11
|
6天前
|
数据采集 机器学习/深度学习 人工智能
基于AI的网络流量分析:构建智能化运维体系
基于AI的网络流量分析:构建智能化运维体系
53 13
|
25天前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
43 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
10天前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本文详细介绍了如何在昇腾平台上使用PyTorch实现GraphSage算法,在CiteSeer数据集上进行图神经网络的分类训练。内容涵盖GraphSage的创新点、算法原理、网络架构及实战代码分析,通过采样和聚合方法高效处理大规模图数据。实验结果显示,模型在CiteSeer数据集上的分类准确率达到66.5%。
|
5天前
|
网络协议 安全 网络安全
探索网络模型与协议:从OSI到HTTPs的原理解析
OSI七层网络模型和TCP/IP四层模型是理解和设计计算机网络的框架。OSI模型包括物理层、数据链路层、网络层、传输层、会话层、表示层和应用层,而TCP/IP模型则简化为链路层、网络层、传输层和 HTTPS协议基于HTTP并通过TLS/SSL加密数据,确保安全传输。其连接过程涉及TCP三次握手、SSL证书验证、对称密钥交换等步骤,以保障通信的安全性和完整性。数字信封技术使用非对称加密和数字证书确保数据的机密性和身份认证。 浏览器通过Https访问网站的过程包括输入网址、DNS解析、建立TCP连接、发送HTTPS请求、接收响应、验证证书和解析网页内容等步骤,确保用户与服务器之间的安全通信。
40 1
|
10天前
|
监控 安全 BI
什么是零信任模型?如何实施以保证网络安全?
随着数字化转型,网络边界不断变化,组织需采用新的安全方法。零信任基于“永不信任,永远验证”原则,强调无论内外部,任何用户、设备或网络都不可信任。该模型包括微分段、多因素身份验证、单点登录、最小特权原则、持续监控和审核用户活动、监控设备等核心准则,以实现强大的网络安全态势。
|
19天前
|
云安全 人工智能 安全
|
23天前
|
机器学习/深度学习 人工智能 算法
深度学习入门:用Python构建你的第一个神经网络
在人工智能的海洋中,深度学习是那艘能够带你远航的船。本文将作为你的航标,引导你搭建第一个神经网络模型,让你领略深度学习的魅力。通过简单直观的语言和实例,我们将一起探索隐藏在数据背后的模式,体验从零开始创造智能系统的快感。准备好了吗?让我们启航吧!
61 3
|
1月前
|
数据采集 XML 存储
构建高效的Python网络爬虫:从入门到实践
本文旨在通过深入浅出的方式,引导读者从零开始构建一个高效的Python网络爬虫。我们将探索爬虫的基本原理、核心组件以及如何利用Python的强大库进行数据抓取和处理。文章不仅提供理论指导,还结合实战案例,让读者能够快速掌握爬虫技术,并应用于实际项目中。无论你是编程新手还是有一定基础的开发者,都能在这篇文章中找到有价值的内容。
|
1月前
|
并行计算 监控 搜索推荐
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
54 5