PyTorch的F.dropout为什么要加self.training?

简介: 以下介绍Module的training属性,F(torch.nn.functional).dropout 和 nn(torch.nn).Dropout 中相应操作的实现方式,以及Module的training属性受train()和eval()方法影响而改变的机制。

1. Module的training属性


见torch.nn.Module官方文档

是Module的属性,布尔值,返回Module是否处于训练状态。也就是说在训练时training就是True。

默认为True,也就是Module初始化时默认为训练状态。


2. torch.nn.functional.dropout的入参training


torch.nn.functional.dropout官方文档


torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False)

入参training默认为True,置True时应用Dropout,置False时不用。

因此在调用F.dropout()时,直接将self.training传入函数,就可以在训练时应用dropout,评估时关闭dropout。


示例代码:

x=F.dropout(x,p,self.training)


3. torch.nn.Dropout不需要手动开关


torch.nn.Dropout官方文档


torch.nn.Dropout(p=0.5, inplace=False)


其源代码为(Dropout源码):

class Dropout(_DropoutNd):
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout(input, self.p, self.training, self.inplace)


就这个类相当于将 F.dropout() 进行了包装,内置传入了self.training,就不用像在 F.dropout() 里需要手动传参,也能实现在训练时应用dropout,评估时关闭dropout。


示例代码:

m = nn.Dropout(p=0.2)
input = torch.randn(20, 16)
output = m(input)


4. Module的train()和eval()方法改变self.training


torch.nn.Module.train官方文档

train(mode=True)

如果入参为True,则将Module设置为training mode,training随之变为True;反之则设置为evaluation mode,training为False。


torch.nn.Module.eval官方文档

eval()

将Module设置为evaluation mode,相当于 self.train(False)




相关文章
|
6月前
|
机器学习/深度学习 资源调度 监控
PyTorch使用Tricks:Dropout,R-Dropout和Multi-Sample Dropout等 !!
PyTorch使用Tricks:Dropout,R-Dropout和Multi-Sample Dropout等 !!
98 0
|
1月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(三):nn.BatchNorm2d()函数详解
本文介绍了PyTorch中的BatchNorm2d模块,它用于卷积层后的数据归一化处理,以稳定网络性能,并讨论了其参数如num_features、eps和momentum,以及affine参数对权重和偏置的影响。
141 0
Pytorch学习笔记(三):nn.BatchNorm2d()函数详解
|
1月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(四):nn.MaxPool2d()函数详解
这篇博客文章详细介绍了PyTorch中的nn.MaxPool2d()函数,包括其语法格式、参数解释和具体代码示例,旨在指导读者理解和使用这个二维最大池化函数。
114 0
Pytorch学习笔记(四):nn.MaxPool2d()函数详解
|
6月前
|
机器学习/深度学习 人工智能 PyTorch
基于torch.nn.Dropout通过实例说明Dropout丢弃法(附代码)
基于torch.nn.Dropout通过实例说明Dropout丢弃法(附代码)
136 0
|
PyTorch 算法框架/工具
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
505 2
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch torch.nn库以及nn与nn.functional有什么区别?
Pytorch torch.nn库以及nn与nn.functional有什么区别?
101 0
|
机器学习/深度学习 并行计算 PyTorch
【PyTorch】Training Model
【PyTorch】Training Model
89 0
|
PyTorch 算法框架/工具
PyTorch的nn.Linear()详解
从输入输出的张量的shape角度来理解,相当于一个输入为[batch_size, in_features]的张量变换成了[batch_size, out_features]的输出张量。
511 0
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch学习笔记-06 Normalization layers
Pytorch学习笔记-06 Normalization layers
121 0
Pytorch学习笔记-06 Normalization layers