BN classes in the torch.nn module
PyTorch's `torch.nn` module provides several BN classes: `nn.BatchNorm1d`, `nn.BatchNorm2d` and `nn.BatchNorm3d`.
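The main parameters are:
- `num_features`: the number of features (channels) C
- `eps=1e-05`: $\epsilon$, added to the denominator to avoid division by zero
- `momentum=0.1`: the momentum used for the running averages of the mean and variance
- `affine=True`: whether to apply a learnable affine transform (scale $\gamma$ and shift $\beta$)
- `track_running_stats=True`: whether to compute running averages of the mean and variance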
In most cases only `num_features` needs to be set; the other parameters can keep their defaults.
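To make `eps`, `momentum` and `affine` concrete, here is a minimal sketch that recomputes one training-mode `nn.BatchNorm1d` forward pass by hand. It assumes the behaviour described in the PyTorch documentation: the output is $y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \gamma + \beta$ with the biased batch variance, while the running statistics are updated as `running = (1 - momentum) * running + momentum * batch_stat` using the unbiased variance. The batch size, feature count and seed are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)  # defaults: eps=1e-5, momentum=0.1, affine=True
x = torch.randn(8, 4)   # a batch of N=8 samples with C=4 features

bn.train()
y = bn(x)

# Normalization uses the per-feature batch mean and the biased batch variance;
# gamma (bn.weight) starts at 1 and beta (bn.bias) starts at 0
mean = x.mean(dim=0)
var_biased = x.var(dim=0, unbiased=False)
y_manual = (x - mean) / torch.sqrt(var_biased + bn.eps) * bn.weight + bn.bias
print(torch.allclose(y, y_manual, atol=1e-5))

# Running stats start at mean=0, var=1 and are updated with the unbiased variance
var_unbiased = x.var(dim=0, unbiased=True)
expected_mean = (1 - bn.momentum) * torch.zeros(4) + bn.momentum * mean
expected_var = (1 - bn.momentum) * torch.ones(4) + bn.momentum * var_unbiased
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))
print(torch.allclose(bn.running_var, expected_var, atol=1e-6))
```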
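To understand `track_running_stats` in detail, consider two cases:

(1) `track_running_stats=True`
- Training (`model.train()`): BN normalizes with the mean and variance of the current training batch and also updates the running averages of the mean and variance.
- Testing (`model.eval()`): BN normalizes with the running averages of the mean and variance accumulated during training.

(2) `track_running_stats=False`
- Training (`model.train()`): BN normalizes with the mean and variance of the current training batch and does not compute any running averages.
- Testing (`model.eval()`): BN normalizes with the mean and variance of the current test batch.

Note: incorrect setups, such as calling `model.eval()` during training or `model.train()` during testing, are not considered here.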
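The two cases can be checked with a small sketch. The "training" and "test" batches below are synthetic random tensors chosen only so that their statistics differ; the manual check relies on the initial affine parameters $\gamma=1$, $\beta=0$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_train = torch.randn(16, 3) * 2 + 5  # a "training" batch with mean ~5, std ~2
x_test = torch.randn(16, 3)           # a "test" batch with different statistics

bn_track = nn.BatchNorm1d(3, track_running_stats=True)
bn_notrack = nn.BatchNorm1d(3, track_running_stats=False)

for bn in (bn_track, bn_notrack):
    bn.train()
    bn(x_train)  # one training-mode forward pass
    bn.eval()

print(bn_track.running_mean)    # moved 10% (momentum=0.1) toward the batch mean
print(bn_notrack.running_mean)  # None: no running statistics are kept

with torch.no_grad():
    out_track = bn_track(x_test)      # case (1): uses running_mean / running_var
    out_notrack = bn_notrack(x_test)  # case (2): uses x_test's own batch statistics

# Case (1) equals normalization with the stored running statistics (gamma=1, beta=0)
expected = (x_test - bn_track.running_mean) / torch.sqrt(bn_track.running_var + bn_track.eps)
print(torch.allclose(out_track, expected, atol=1e-5))
# Case (2): each feature of the test output has (near) zero mean over the batch
print(out_notrack.mean(dim=0))
```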
nn.BatchNorm1d
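Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputs with an optional additional channel dimension). It can be used after fully connected layers.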
Input: (N, C)
Output: (N, C)
```python
import torch
import torch.nn as nn

m = nn.BatchNorm1d(100)
input_1d = torch.randn(64, 100)
output = m(input_1d)
print(output.size())
```
Output:
torch.Size([64, 100])
Input: (N, C, L)
Output: (N, C, L)
```python
import torch
import torch.nn as nn

m = nn.BatchNorm1d(100)
input_1d = torch.randn(64, 100, 2)
output = m(input_1d)
print(output.size())
```
Output:
torch.Size([64, 100, 2])
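For a 3D input (N, C, L), the statistics of each channel are taken over both the batch dimension N and the length dimension L. A minimal sketch that checks this by hand, reusing the shapes from the example above and relying on the initial affine parameters $\gamma=1$, $\beta=0$:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m = nn.BatchNorm1d(100)
x = torch.randn(64, 100, 2)
y = m(x)  # training mode by default

# Per-channel mean and biased variance over dims N (0) and L (2)
mean = x.mean(dim=(0, 2), keepdim=True)
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + m.eps)  # gamma=1, beta=0 initially
print(torch.allclose(y, y_manual, atol=1e-5))
```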
nn.BatchNorm2d
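Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with an additional channel dimension). It can be used after convolutional layers.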
Input: (N, C, H, W)
Output: (N, C, H, W)
```python
import torch
import torch.nn as nn

m = nn.BatchNorm2d(3)
input_2d = torch.randn(32, 3, 64, 64)
output = m(input_2d)
print(output.size())
```
Output:
torch.Size([32, 3, 64, 64])
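`nn.BatchNorm2d` keeps one mean and one variance per channel, computed over N, H and W. One way to see this, as a sketch: flattening H and W into a single length dimension and applying `nn.BatchNorm1d` gives the same result. The sketch also lists what a BN layer stores per channel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 3, 64, 64)
bn2d = nn.BatchNorm2d(3)
bn1d = nn.BatchNorm1d(3)

y2d = bn2d(x)
y1d = bn1d(x.flatten(2)).view_as(x)  # (N, C, H*W) -> back to (N, C, H, W)
print(torch.allclose(y2d, y1d, atol=1e-5))

# Learnable parameters (gamma, beta) and non-learnable buffers, one entry per channel
print([name for name, _ in bn2d.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in bn2d.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```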
nn.BatchNorm3d
Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with an additional channel dimension).
Input: (N, C, D, H, W)
Output: (N, C, D, H, W)
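```python
import torch
import torch.nn as nn

m = nn.BatchNorm3d(3)
input_3d = torch.randn(64, 3, 64, 64, 100)  # ~78.6M elements, roughly 300 MB as float32
output = m(input_3d)
print(output.size())
```
Output:
torch.Size([64, 3, 64, 64, 100])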
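Each class checks the rank of its input, so the choice between them follows directly from the input shape. A small sketch with deliberately tiny, illustrative shapes:

```python
import torch
import torch.nn as nn

cases = [
    (nn.BatchNorm1d(3), torch.randn(4, 3)),           # (N, C)
    (nn.BatchNorm1d(3), torch.randn(4, 3, 5)),        # (N, C, L)
    (nn.BatchNorm2d(3), torch.randn(4, 3, 5, 5)),     # (N, C, H, W)
    (nn.BatchNorm3d(3), torch.randn(4, 3, 2, 5, 5)),  # (N, C, D, H, W)
]
for bn, x in cases:
    print(type(bn).__name__, tuple(bn(x).shape))  # output shape equals input shape

# An input of the wrong rank raises a ValueError
try:
    nn.BatchNorm2d(3)(torch.randn(4, 3, 5))
except ValueError as e:
    print("ValueError:", e)
```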
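The BN classes above require the number of features to be specified. In newer versions of PyTorch, `nn.LazyBatchNorm1d`, `nn.LazyBatchNorm2d` and `nn.LazyBatchNorm3d` infer the number of features from `input.size(1)`, so it does not need to be given.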
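A minimal sketch of the lazy variant, assuming a PyTorch version that ships `nn.LazyBatchNorm2d` (it may print a warning that lazy modules are still under development). The channel count 5 is an arbitrary example.

```python
import torch
import torch.nn as nn

m = nn.LazyBatchNorm2d()        # num_features is not given
x = torch.randn(8, 5, 16, 16)
y = m(x)                        # the first forward call infers num_features from x.size(1)

print(m.num_features)           # 5
print(m.weight.shape)           # torch.Size([5])
print(type(m).__name__)         # BatchNorm2d: the module turns into a regular BN layer
```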
LeNet-5 + BN
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms


class Flatten(nn.Module):
    '''Newer versions of PyTorch can use nn.Flatten directly'''
    def forward(self, x):
        return x.flatten(1)


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # LeNet-5 with BN after every conv layer and hidden linear layer (before ReLU)
        self.conv_bn_act = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.BatchNorm2d(6),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU(True),
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(True),
            nn.Linear(84, 10)
        )
        # The same network without BN, used for the comparison
        self.conv_act = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(True),
            nn.Linear(120, 84),
            nn.ReLU(True),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        # Version with BN; change to self.conv_act(x) to train the version without BN
        x = self.conv_bn_act(x)
        return x


def train_loop(dataloader, model, loss_fn, optimizer, device):
    for i, data in enumerate(dataloader, 0):
        # Get the inputs
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Compute predictions and loss
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        # Backpropagation and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            print('[Batch%4d] loss: %.3f' % (i + 1, loss.item()))


def test_loop(dataloader, model, device):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))


if __name__ == '__main__':
    # Device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    # Dataset
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # normalize the image data
    trainset = datasets.CIFAR10(root='./cifar10_data', train=True,
                                download=True, transform=transform)
    # Load data with num_workers subprocesses
    trainloader = DataLoader(trainset, batch_size=64,
                             shuffle=True, num_workers=2)
    testset = datasets.CIFAR10(root='./cifar10_data', train=False,
                               download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=64,
                            shuffle=False, num_workers=2)
    # Hyperparameters
    lr = 0.01  # a larger learning rate: 0.001 -> 0.01
    epochs = 10
    # Model instance
    model = LeNet5().to(device)
    # Loss function instance
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer instance
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        model.train()  # BN uses batch statistics and updates its running averages
        train_loop(trainloader, model, loss_fn, optimizer, device=device)
        model.eval()   # BN uses the running averages
        test_loop(testloader, model, device=device)
    print("Done!")
```
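The input size `16 * 5 * 5` of the first linear layer follows from the 32x32 CIFAR-10 images. A quick shape trace, sketched here with the convolutional part of the BN model rebuilt standalone (using `nn.Flatten` instead of the custom `Flatten`) and a random dummy batch instead of real data:

```python
import torch
import torch.nn as nn

# Convolutional part of the BN LeNet-5 above, rebuilt for a standalone shape check
features = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.BatchNorm2d(6), nn.ReLU(True), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5), nn.BatchNorm2d(16), nn.ReLU(True), nn.MaxPool2d(2, 2),
    nn.Flatten(),
)
x = torch.randn(4, 3, 32, 32)  # a dummy batch with the CIFAR-10 image size
for layer in features:
    x = layer(x)
    print(type(layer).__name__, tuple(x.shape))
# The last line prints: Flatten (4, 400), i.e. 16 * 5 * 5 features per sample
```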
Without BN:
Epoch 1
-------------------------------
[Batch 100] loss: 1.904
[Batch 200] loss: 2.112
[Batch 300] loss: 1.721
[Batch 400] loss: 1.797
[Batch 500] loss: 1.863
[Batch 600] loss: 1.757
[Batch 700] loss: 1.891
Accuracy of the network on the 10000 test images: 33 %
Epoch 2
-------------------------------
[Batch 100] loss: 1.673
[Batch 200] loss: 1.708
[Batch 300] loss: 1.686
[Batch 400] loss: 1.736
[Batch 500] loss: 1.548
[Batch 600] loss: 1.646
[Batch 700] loss: 1.800
Accuracy of the network on the 10000 test images: 36 %
Epoch 3
-------------------------------
[Batch 100] loss: 1.754
[Batch 200] loss: 1.568
[Batch 300] loss: 1.581
[Batch 400] loss: 1.609
[Batch 500] loss: 1.700
[Batch 600] loss: 1.845
[Batch 700] loss: 1.626
Accuracy of the network on the 10000 test images: 39 %
Epoch 4
-------------------------------
[Batch 100] loss: 1.699
[Batch 200] loss: 1.585
[Batch 300] loss: 1.840
[Batch 400] loss: 1.688
[Batch 500] loss: 1.412
[Batch 600] loss: 1.569
[Batch 700] loss: 1.587
Accuracy of the network on the 10000 test images: 42 %
Epoch 5
-------------------------------
[Batch 100] loss: 1.727
[Batch 200] loss: 1.425
[Batch 300] loss: 1.699
[Batch 400] loss: 1.471
[Batch 500] loss: 1.702
[Batch 600] loss: 1.374
[Batch 700] loss: 1.497
Accuracy of the network on the 10000 test images: 41 %
Epoch 6
-------------------------------
[Batch 100] loss: 1.365
[Batch 200] loss: 1.664
[Batch 300] loss: 1.528
[Batch 400] loss: 1.444
[Batch 500] loss: 1.623
[Batch 600] loss: 1.382
[Batch 700] loss: 1.896
Accuracy of the network on the 10000 test images: 44 %
Epoch 7
-------------------------------
[Batch 100] loss: 1.783
[Batch 200] loss: 1.728
[Batch 300] loss: 1.500
[Batch 400] loss: 1.522
[Batch 500] loss: 1.400
[Batch 600] loss: 1.552
[Batch 700] loss: 1.482
Accuracy of the network on the 10000 test images: 44 %
Epoch 8
-------------------------------
[Batch 100] loss: 1.572
[Batch 200] loss: 1.088
[Batch 300] loss: 1.555
[Batch 400] loss: 1.380
[Batch 500] loss: 1.774
[Batch 600] loss: 1.589
[Batch 700] loss: 1.500
Accuracy of the network on the 10000 test images: 45 %
Epoch 9
-------------------------------
[Batch 100] loss: 1.411
[Batch 200] loss: 1.696
[Batch 300] loss: 1.494
[Batch 400] loss: 1.454
[Batch 500] loss: 1.401
[Batch 600] loss: 1.552
[Batch 700] loss: 1.766
Accuracy of the network on the 10000 test images: 48 %
Epoch 10
-------------------------------
[Batch 100] loss: 1.431
[Batch 200] loss: 1.309
[Batch 300] loss: 1.555
[Batch 400] loss: 1.436
[Batch 500] loss: 1.485
[Batch 600] loss: 1.440
[Batch 700] loss: 1.373
Accuracy of the network on the 10000 test images: 47 %
Done!
With BN:
Epoch 1
-------------------------------
[Batch 100] loss: 1.571
[Batch 200] loss: 1.588
[Batch 300] loss: 1.443
[Batch 400] loss: 1.439
[Batch 500] loss: 1.209
[Batch 600] loss: 1.205
[Batch 700] loss: 0.996
Accuracy of the network on the 10000 test images: 55 %
Epoch 2
-------------------------------
[Batch 100] loss: 1.134
[Batch 200] loss: 1.395
[Batch 300] loss: 1.279
[Batch 400] loss: 1.043
[Batch 500] loss: 1.000
[Batch 600] loss: 1.141
[Batch 700] loss: 1.191
Accuracy of the network on the 10000 test images: 59 %
Epoch 3
-------------------------------
[Batch 100] loss: 1.456
[Batch 200] loss: 0.928
[Batch 300] loss: 0.987
[Batch 400] loss: 1.119
[Batch 500] loss: 1.186
[Batch 600] loss: 1.055
[Batch 700] loss: 0.952
Accuracy of the network on the 10000 test images: 62 %
Epoch 4
-------------------------------
[Batch 100] loss: 0.956
[Batch 200] loss: 0.979
[Batch 300] loss: 0.830
[Batch 400] loss: 1.061
[Batch 500] loss: 0.885
[Batch 600] loss: 0.904
[Batch 700] loss: 0.807
Accuracy of the network on the 10000 test images: 61 %
Epoch 5
-------------------------------
[Batch 100] loss: 0.843
[Batch 200] loss: 0.854
[Batch 300] loss: 0.993
[Batch 400] loss: 1.025
[Batch 500] loss: 0.898
[Batch 600] loss: 1.075
[Batch 700] loss: 0.654
Accuracy of the network on the 10000 test images: 63 %
Epoch 6
-------------------------------
[Batch 100] loss: 0.623
[Batch 200] loss: 0.704
[Batch 300] loss: 0.821
[Batch 400] loss: 1.147
[Batch 500] loss: 0.761
[Batch 600] loss: 1.032
[Batch 700] loss: 0.852
Accuracy of the network on the 10000 test images: 64 %
Epoch 7
-------------------------------
[Batch 100] loss: 0.718
[Batch 200] loss: 0.882
[Batch 300] loss: 0.855
[Batch 400] loss: 0.818
[Batch 500] loss: 0.888
[Batch 600] loss: 0.576
[Batch 700] loss: 0.963
Accuracy of the network on the 10000 test images: 65 %
Epoch 8
-------------------------------
[Batch 100] loss: 0.706
[Batch 200] loss: 0.515
[Batch 300] loss: 0.742
[Batch 400] loss: 0.491
[Batch 500] loss: 0.714
[Batch 600] loss: 0.878
[Batch 700] loss: 0.821
Accuracy of the network on the 10000 test images: 66 %
Epoch 9
-------------------------------
[Batch 100] loss: 0.814
[Batch 200] loss: 0.968
[Batch 300] loss: 0.729
[Batch 400] loss: 0.838
[Batch 500] loss: 0.649
[Batch 600] loss: 0.664
[Batch 700] loss: 0.692
Accuracy of the network on the 10000 test images: 67 %
Epoch 10
-------------------------------
[Batch 100] loss: 0.792
[Batch 200] loss: 0.560
[Batch 300] loss: 0.698
[Batch 400] loss: 0.857
[Batch 500] loss: 0.815
[Batch 600] loss: 0.853
[Batch 700] loss: 0.724
Accuracy of the network on the 10000 test images: 66 %
Done!
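The training script switches to `model.train()` before each training pass and to `model.eval()` before testing. A minimal sketch of why this matters for BN: in eval mode a sample's output depends only on the running statistics, while in train mode it also depends on the other samples in the batch. The data, batch sizes and number of warm-up steps are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
bn.train()
for _ in range(20):                      # accumulate running statistics
    bn(torch.randn(32, 4) * 3 + 1)

sample = torch.randn(1, 4)
others = torch.randn(7, 4)
batch = torch.cat([sample, others])      # the same sample inside a batch of 8

bn.eval()
with torch.no_grad():
    alone = bn(sample)                   # eval mode: uses running_mean / running_var
    in_batch = bn(batch)[:1]
print(torch.allclose(alone, in_batch))   # the result does not depend on the batch

bn.train()
with torch.no_grad():
    in_batch_train = bn(batch)[:1]       # train mode: depends on the other 7 samples
print(torch.allclose(alone, in_batch_train))
```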