PyTorch中的forward的理解

简介: PyTorch中的forward的理解

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

1. 关于forward的两个小问题

1.1 为什么都用def forward,而不改个名字?

在Pytorch建立神经元网络模型的时候,经常用到forward方法,表示在建立模型后,进行神经元网络的前向传播。说的直白点,forward就是专门用来计算给定输入,得到神经元网络输出的方法。


在代码实现中,也是用def forward来写forward前向传播的方法,我原来以为这是一种约定熟成的名字,也可以换成任意一个自己喜欢的名字。


但是看的多了之后发现并非如此:Pytorch对于forward方法赋予了一些特殊“功能”


(这里不禁再吐槽,一些看起来挺厉害的Pytorch“大神”,居然不知道这个。。。只能草草解释一下:“就是这样的。。。”)


1.2 forward有什么特殊功能?
第一条:.forward()可以不写

我最开始发现forward()的与众不同之处就是在此,首先举个例子:

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input
    def forward(self,x):
        return self.input * x
T = test(8)
print(T(6))
# print(T.forward(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
Process finished with exit code 0


可以发现,T(6)是可以输出的!而且不用指定,默认了调用forward方法。当然如果非要写上.forward()这也是可以正常运行的,和不写是一样的。


如果不调用Pytorch(正常的Python语法规则),这样肯定会报错的


# import torch.nn as nn  #不再调用torch
class test():
    def __init__(self, input):
        self.input = input
    def forward(self,x):
        return self.input * x
T = test(8)
print(T.forward(6))
print("************************")
print(T(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
************************
Traceback (most recent call last):
  File "C:\Users\Lenovo\Desktop\DL\pythonProject\tt.py", line 77, in <module>
    print(T(6))
TypeError: 'test' object is not callable
Process finished with exit code 1


这里会报:‘test’ object is not callable

因为class不能被直接调用,不知道你想调用哪个方法。


第二条:优先运行forward方法

如果在class中再增加一个方法:

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input
    def byten(self):
        return self.input * 10
    def forward(self,x):
        return self.input * x
T = test(8)
print(T(6))
print(T.byten())
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
80
Process finished with exit code 0


可以见到,在class中有多个method的时候,如果不指定method,forward是会被优先执行的。


2. 总结


在Pytorch中,forward方法是一个特殊的方法,被专门用来进行前向传播。


20230605 更新


应评论要求,增加forward的官方定义,这块我就不搬运PyTorch官网的内容了,直接传送门走你:nn.Module.forward


20230919 大更新

首先非常感谢大家喜欢本文!这篇文章本来是我自己的“随手记”没想到有这么多C友浏览过!


其实在写完本文后我是有些遗憾的,因为本文仅是用了实验的方法探索出了.forward()的表象,而它的运作机理却没有说明白,知其然不知其所以然!


在此感谢下面 Mr·小鱼 的评论给了我启迪,因为魔术方法__call__()的特性确实很符合.forward()的表象,但是我对着nn.Module的源码一脸茫然,因为源码中压根没有__call__()方法的定义!!


于是我抱着试试的心态,在PyTorch官网上查了下PyTorch的历史版本,这一查确实查到了线索:



下面是从PyTorch的上古版本v0.1.12中截取forward()__call__()方法的源码:


class Module(object):
#...中间不相关代码省略...
    def forward(self, *input):
        """Defines the computation performed at every call.
        Should be overriden by all subclasses.
        """
        raise NotImplementedError
#...中间不相关代码省略...
    def __call__(self, *input, **kwargs):
        result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}'"
                    "didn't return None".format(hook))
        var = result
        while not isinstance(var, Variable):
            var = var[0]
        creator = var.creator
        if creator is not None and len(self._backward_hooks) > 0:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                creator.register_hook(wrapper)
        return result


我们可以看到在__call__()方法中直接把方法self.forward()作为函数的返回值,由于魔术方法__call__()可以被自动调用,这也就解释了为什么forward()可以自动运行。


至于该方法中的其他内容,都是与hook钩子函数的操作相关,这部分暂不做探索。。。


那我们回到现在的版本(我现在使用的是1.8.1):


通过源码可以看到经历了多个版本的更迭,forward()__call__()居然改名字了!!


forward: Callable[..., Any] = _forward_unimplemented
    ...
    __call__ : Callable[..., Any] = _call_impl


这也就是为什么我之前在源码中没找到这两个方法定义的原因。。。准确来说这里也不能说是改名字了,而是多了一个名字,至于PyTorch为什么会有这样的更改,我确实也没想到原因。。。


其中_forward_unimplemented()倒是没变:


def _forward_unimplemented(self, *input: Any) -> None:
    r"""Defines the computation performed at every call.
    Should be overridden by all subclasses.
    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError


_call_impl()相比于上古版本,已经复杂到了令人发指的地步!


def _call_impl(self, *input, **kwargs):
        # Do not call functions when jit is used
        full_backward_hooks, non_full_backward_hooks = [], []
        if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0:
            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        bw_hook = None
        if len(full_backward_hooks) > 0:
            bw_hook = hooks.BackwardHook(self, full_backward_hooks)
            input = bw_hook.setup_input_hook(input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if bw_hook:
            result = bw_hook.setup_output_hook(result)
        # Handle the non-full backward hooks
        if len(non_full_backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in non_full_backward_hooks:
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
                self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
        return result


其变复杂的原因是各种钩子函数_hook的调用,有兴趣的童鞋可以参考这篇文章:pytorch 中_call_impl()函数。这部分绝对是超纲了!


最后我想再做几个实验加深理解:

实验①


import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input
    def forward(self,x):
        return self.input * x
T = test(8)
print(T.__call__(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py 
48
Process finished with exit code 0


这里T.__call__(6)写法等价于T(6)

实验②


import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input
    def forward(self,x):
        return self.input * x
T = test(8)
print(T.forward(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py 
48
Process finished with exit code 0


这里T.forward(6)的写法虽然也能正确地计算出结果,但是不推荐这么写,因为这会导致__call__()调用一遍forward(),然后手动又调用了一遍forward(),造成forward()的重复计算,浪费计算资源。


实验③

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input
    # def forward(self,x):
    #     return self.input * x
T = test(8)
print(T())
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py 
Traceback (most recent call last):
  File "C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py", line 11, in <module>
    print(T())
  File "D:\Users\Lenovo\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\Users\Lenovo\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 201, in _forward_unimplemented
    raise NotImplementedError
NotImplementedError


forward()是必须要写的,因为__call__()要自动调用forward()。如果压根不写forward()__call__()将无方法可以调用。按照forward()的源码,这里会raise NotImplementedError


至此,我觉得PyTorch中的forward应该算是全说明白了。。。


相关文章
|
3月前
|
机器学习/深度学习 监控 PyTorch
以pytorch的forward hook为例探究hook机制
【10月更文挑战第9天】PyTorch中的Hook机制允许在不修改模型代码的情况下,获取和修改模型中间层的信息,如输入和输出等,适用于模型可视化、特征提取及梯度计算。Forward Hook在前向传播后触发,可用于特征提取和模型监控。实现上,需定义接收模块、输入和输出参数的Hook函数,并将其注册到目标层。与Backward Hook相比,前者关注前向传播,后者侧重反向传播和梯度处理,两者共同增强了对模型内部运行情况的理解和控制。
|
3月前
|
机器学习/深度学习 存储 数据可视化
以pytorch的forward hook为例探究hook机制
【10月更文挑战第10天】PyTorch 的 Hook 机制允许用户在不修改模型代码的情况下介入前向和反向传播过程,适用于模型可视化、特征提取及梯度分析等任务。通过注册 `forward hook`,可以在模型前向传播过程中插入自定义操作,如记录中间层输出。使用时需注意输入输出格式及计算资源占用。
|
3月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
429 2
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
50 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
3月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
89 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
3月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
164 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
4月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
247 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
4月前
|
机器学习/深度学习 监控 PyTorch
PyTorch 模型调试与故障排除指南
在深度学习领域,PyTorch 成为开发和训练神经网络的主要框架之一。本文为 PyTorch 开发者提供全面的调试指南,涵盖从基础概念到高级技术的内容。目标读者包括初学者、中级开发者和高级工程师。本文探讨常见问题及解决方案,帮助读者理解 PyTorch 的核心概念、掌握调试策略、识别性能瓶颈,并通过实际案例获得实践经验。无论是在构建简单神经网络还是复杂模型,本文都将提供宝贵的洞察和实用技巧,帮助开发者更高效地开发和优化 PyTorch 模型。
59 3
PyTorch 模型调试与故障排除指南
|
3月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
5月前
|
机器学习/深度学习 PyTorch 编译器
PyTorch 与 TorchScript:模型的序列化与加速
【8月更文第27天】PyTorch 是一个非常流行的深度学习框架,它以其灵活性和易用性而著称。然而,当涉及到模型的部署和性能优化时,PyTorch 的动态计算图可能会带来一些挑战。为了解决这些问题,PyTorch 引入了 TorchScript,这是一个用于序列化和优化 PyTorch 模型的工具。本文将详细介绍如何使用 TorchScript 来序列化 PyTorch 模型以及如何加速模型的执行。
198 4