PyTorch的F.dropout为什么要加self.training?

简介: 以下介绍Module的training属性,F(torch.nn.functional).dropout 和 nn(torch.nn).Dropout 中相应操作的实现方式,以及Module的training属性受train()和eval()方法影响而改变的机制。

1. Module的training属性


见torch.nn.Module官方文档

是Module的属性,布尔值,返回Module是否处于训练状态。也就是说在训练时training就是True。

默认为True,也就是Module初始化时默认为训练状态。


2. torch.nn.functional.dropout的入参training


torch.nn.functional.dropout官方文档


torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False)

入参training默认为True,置True时应用Dropout,置False时不用。

因此在调用F.dropout()时,直接将self.training传入函数,就可以在训练时应用dropout,评估时关闭dropout。


示例代码:

x=F.dropout(x,p,self.training)


3. torch.nn.Dropout不需要手动开关


torch.nn.Dropout官方文档


torch.nn.Dropout(p=0.5, inplace=False)


其源代码为(Dropout源码):

class Dropout(_DropoutNd):
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout(input, self.p, self.training, self.inplace)


就这个类相当于将 F.dropout() 进行了包装,内置传入了self.training,就不用像在 F.dropout() 里需要手动传参,也能实现在训练时应用dropout,评估时关闭dropout。


示例代码:

m = nn.Dropout(p=0.2)
input = torch.randn(20, 16)
output = m(input)


4. Module的train()和eval()方法改变self.training


torch.nn.Module.train官方文档

train(mode=True)

如果入参为True,则将Module设置为training mode,training随之变为True;反之则设置为evaluation mode,training为False。


torch.nn.Module.eval官方文档

eval()

将Module设置为evaluation mode,相当于 self.train(False)




相关文章
|
10月前
|
Numpy学习笔记(五):np.concatenate函数和np.append函数用于数组拼接
NumPy库中的`np.concatenate`和`np.append`函数,它们分别用于沿指定轴拼接多个数组以及在指定轴上追加数组元素。
394 0
Numpy学习笔记(五):np.concatenate函数和np.append函数用于数组拼接
蓝桥杯嵌入式第十二届省赛
蓝桥杯嵌入式第十二届省赛
262 0
Pytorch Lightning使用:【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】
Pytorch Lightning使用:【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】
1009 0
深度学习之相机内参标定
相机内参标定(Camera Intrinsic Calibration)是计算机视觉中的关键步骤,用于确定相机的内部参数(如焦距、主点位置、畸变系数等)。传统的标定方法依赖于已知尺寸的标定板,通常需要手动操作,繁琐且耗时。基于深度学习的方法则通过自动化处理,提供了一种高效、准确的内参标定方式。
611 13
基于Java的电影评论系统的设计与实现(源码+lw+部署文档+讲解等)
基于Java的电影评论系统的设计与实现(源码+lw+部署文档+讲解等)
201 0
单细胞不同样本数据整合-解决AnnData合并时ValueError: cannot reindex from a duplicate axis问题
单细胞不同样本数据整合-解决AnnData合并时ValueError: cannot reindex from a duplicate axis问题
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等

登录插画

登录以查看您的控制台资源

管理云资源
状态一览
快捷访问