1. Module的training属性
见torch.nn.Module官方文档
是Module的属性,布尔值,返回Module是否处于训练状态。也就是说在训练时training就是True。
默认为True,也就是Module初始化时默认为训练状态。
2. torch.nn.functional.dropout的入参training
torch.nn.functional.dropout官方文档
torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False)
入参training默认为True,置True时应用Dropout,置False时不用。
因此在调用F.dropout()时,直接将self.training传入函数,就可以在训练时应用dropout,评估时关闭dropout。
示例代码:
x=F.dropout(x,p,self.training)
3. torch.nn.Dropout不需要手动开关
torch.nn.Dropout官方文档
torch.nn.Dropout(p=0.5, inplace=False)
其源代码为(Dropout源码):
class Dropout(_DropoutNd): def forward(self, input: Tensor) -> Tensor: return F.dropout(input, self.p, self.training, self.inplace)
就这个类相当于将 F.dropout() 进行了包装,内置传入了self.training,就不用像在 F.dropout() 里需要手动传参,也能实现在训练时应用dropout,评估时关闭dropout。
示例代码:
m = nn.Dropout(p=0.2) input = torch.randn(20, 16) output = m(input)
4. Module的train()和eval()方法改变self.training
torch.nn.Module.train官方文档
train(mode=True)
如果入参为True,则将Module设置为training mode,training随之变为True;反之则设置为evaluation mode,training为False。
torch.nn.Module.eval官方文档
eval()
将Module设置为evaluation mode,相当于 self.train(False)