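1. Function syntax and purpose
Purpose: BatchNorm2d is commonly added after a convolutional layer to normalize the data per channel, so that the values entering the ReLU are not so large that they make the network unstable.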
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
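As a minimal sketch of the usage described above, the following places BatchNorm2d between a convolution and a ReLU. The channel counts (3 to 16), the kernel size, and the input scale are arbitrary values chosen for illustration, not something prescribed by the API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical block for illustration: channel sizes and kernel size are arbitrary.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),  # a conv bias would be cancelled by the mean subtraction
    nn.BatchNorm2d(16),  # num_features must equal the conv's out_channels
    nn.ReLU(),
)

x = torch.randn(4, 3, 8, 8) * 50.0  # deliberately large-scale input

y_bn = block[1](block[0](x))  # activations just before the ReLU
channel_mean = y_bn.mean(dim=(0, 2, 3))                # per-channel mean, close to 0
channel_var = y_bn.var(dim=(0, 2, 3), unbiased=False)  # per-channel variance, close to 1
print(channel_mean)
print(channel_var)

out = block(x)
print(out.shape)
```

Whatever the scale of the input, each channel reaching the ReLU has roughly zero mean and unit variance while the module is in training mode.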
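2. Parameter explanation
num_features: the expected input has size batch_size × num_features × height × width; num_features is the number of channels (feature maps) in that input.
eps: a value added to the denominator for numerical stability. Default: 1e-5.
momentum: the coefficient used to update the running estimates of the mean and variance (running_mean and running_var). I think of it as a smoothing coefficient, but note that it differs from the momentum in SGD: here it is the weight given to the current batch statistic, i.e. running = (1 - momentum) × running + momentum × batch_statistic. Default: 0.1.
affine: when set to True, the layer has learnable per-channel parameters gamma (weight) and beta (bias).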
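To make the role of momentum concrete, here is a small sketch that checks the running-statistics update after one forward pass in training mode. It assumes the default initial values (running_mean = 0, running_var = 1); the seed and input shape are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.1
bn = nn.BatchNorm2d(3, momentum=momentum)  # running_mean starts at 0, running_var at 1
x = torch.randn(2, 3, 2, 3)

bn.train()
bn(x)  # a forward pass in training mode updates the running statistics

# Per-channel batch statistics over the batch, height and width dimensions
batch_mean = x.mean(dim=(0, 2, 3))
batch_var = x.var(dim=(0, 2, 3), unbiased=True)  # running_var is updated with the unbiased estimate

expected_mean = (1 - momentum) * torch.zeros(3) + momentum * batch_mean
expected_var = (1 - momentum) * torch.ones(3) + momentum * batch_var

mean_ok = torch.allclose(bn.running_mean, expected_mean, atol=1e-6)
var_ok = torch.allclose(bn.running_var, expected_var, atol=1e-6)
print(mean_ok, var_ok)
```

With momentum=0.1, each batch contributes only 10% to the running estimates, which are the statistics the layer uses once it is switched to eval() mode.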
3. Code
# encoding:utf-8
import torch
import torch.nn as nn
# num_features: number of channels of an expected input of size batch_size*num_features*height*width
# eps: default 1e-5 (value added to the denominator in the formula for numerical stability)
# momentum: coefficient used to update running_mean and running_var, default 0.1
m = nn.BatchNorm2d(3)  # affine=True (the default): learnable weight and bias are used
m1 = nn.BatchNorm2d(3, affine=False)  # affine=False: no learnable weight or bias
input = torch.randn(2, 3, 2, 3)
output = m(input)
output1 = m1(input)
print('"""affine=True"""')
print(input)
print(m.weight)
print(m.bias)
print(output)
print(output.size())
print('"""affine=False"""')
print(output1)
print(output1.size())
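The results are as follows. The two outputs are identical because weight is initialized to 1 and bias to 0, so the affine transform has no effect yet: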
"""affine=True"""
tensor([[[[ 0.5408, 0.2707, -0.4395],
[ 0.7942, -1.3403, 0.9146]],
[[ 0.0082, 0.3639, -0.1986],
[ 1.6522, -0.3494, -0.8619]],
[[ 0.1021, 0.2455, 0.9168],
[-0.2652, 0.0869, -1.3121]]],
[[[-0.5038, -1.0989, 1.3820],
[ 1.5612, -0.0384, -1.5507]],
[[-0.4546, 2.5124, -1.1012],
[ 1.0045, -0.7018, 1.3485]],
[[-2.7837, -0.6371, -0.7099],
[-0.0732, 1.1424, 0.6456]]]])
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0.], requires_grad=True)
tensor([[[[ 0.4995, 0.2295, -0.4802],
[ 0.7527, -1.3803, 0.8730]],
[[-0.2414, 0.0885, -0.4332],
[ 1.2832, -0.5730, -1.0483]],
[[ 0.3156, 0.4560, 1.1133],
[-0.0441, 0.3006, -1.0692]]],
[[[-0.5444, -1.1390, 1.3400],
[ 1.5191, -0.0794, -1.5906]],
[[-0.6706, 2.0809, -1.2702],
[ 0.6825, -0.8999, 1.0016]],
[[-2.5102, -0.4082, -0.4795],
[ 0.1439, 1.3342, 0.8478]]]], grad_fn=<NativeBatchNormBackward>)
torch.Size([2, 3, 2, 3])
"""affine=False"""
tensor([[[[ 0.4995, 0.2295, -0.4802],
[ 0.7527, -1.3803, 0.8730]],
[[-0.2414, 0.0885, -0.4332],
[ 1.2832, -0.5730, -1.0483]],
[[ 0.3156, 0.4560, 1.1133],
[-0.0441, 0.3006, -1.0692]]],
[[[-0.5444, -1.1390, 1.3400],
[ 1.5191, -0.0794, -1.5906]],
[[-0.6706, 2.0809, -1.2702],
[ 0.6825, -0.8999, 1.0016]],
[[-2.5102, -0.4082, -0.4795],
[ 0.1439, 1.3342, 0.8478]]]])
torch.Size([2, 3, 2, 3])
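The output above can be reproduced by hand. The following sketch, with an arbitrary seed and the same input shape, normalizes each channel with the batch mean and the biased batch variance and compares the result against both modules. The variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 2, 3)
eps = 1e-5

# Per-channel statistics over the batch, height and width dimensions
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased variance is used for normalization
manual = (x - mean) / torch.sqrt(var + eps)

bn_affine = nn.BatchNorm2d(3)               # weight = 1, bias = 0 at initialization
bn_plain = nn.BatchNorm2d(3, affine=False)  # no weight or bias at all

same_as_affine = torch.allclose(manual, bn_affine(x), atol=1e-5)
same_as_plain = torch.allclose(manual, bn_plain(x), atol=1e-5)
print(same_as_affine, same_as_plain)
```

The two modules only start to differ once training has moved weight and bias away from their initial values.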