【Pytorch】(八)Batch Normalization

简介:

炼丹神trick,nb的BN算法总结

torch.nn模块的BN类

pytorch的torch.nn模块中有几个BN类:nn.BatchNorm1d,nn.BatchNorm2d,nn.BatchNorm3d

主要参数有:

  • num_features:特征数
  • eps=1e-05:$\epsilon$,防止分母为0

    • momentum=0.1:均值和方差滑动平局的动量值
    • affine=True: 是否仿射变换
  • track_running_stats=True:是否计算均值和方差的滑动平均。

通常情况下除了num_features其他默认即可。

如果要深究track_running_stats的取值,有以下两种情况:
(1)track_running_stats=True
训练阶段model.train():BN用训练集当前批次的均值和方差计算,并计算均值和方差的滑动平均
测试阶段model.eval():BN用训练阶段得到的均值和方差的滑动平均计算

(2)track_running_stats=False
训练阶段model.train():BN用训练集当前批次的均值和方差计算,不计算均值和方差的滑动平均
测试阶段model.eval():BN用测试集当前批次的均值和方差计算

备注:训练阶段model.eval(),测试阶段model.train()这种错误的设置我们不考虑。

nn.BatchNorm1d

对2D或3D输入(带有可选附加通道尺寸的一小批1D输入)应用批量标准化。可用于全连接层。

Input: (N, C)

Output: (N, C)

import torch
import torch.nn as nn
m = nn.BatchNorm1d(100)
input_1d = torch.randn(64, 100)
output = m(input_1d)
print(output.size())

输出:

torch.Size([64, 100])

Input:(N, C, L)

Output:(N, C, L)

m = nn.BatchNorm1d(100)
input_1d = torch.randn(64, 100,2)
output = m(input_1d)
print(output.size())

输出:

torch.Size([64, 100, 2])

nn.BatchNorm2d

在4D输入(带有附加通道尺寸的2D输入的小批量)上应用批量标准化,可用于卷积层。

Input: (N, C, H, W)

Output: (N, C, H, W)

m = nn.BatchNorm2d(3)
input_2d = torch.randn(32, 3, 64, 64)
output = m(input_2d)
print(output.size())

输出:

torch.Size([32, 3, 64, 64)

nn.BatchNorm3d

在5D输入上应用批量标准化(带有附加通道尺寸的一小批3D输入)

Input: (N, C, D, H, W)

Output: (N, C, D, H, W)

m = nn.BatchNorm3d(3)
input_3d = torch.randn(64,3,64,64,100)
output = m(input_3d)
print(output.size())

输出:

torch.Size([64,3,64,64,100)

以上BN类需指定特征数。新版本Pytorch的nn.LazyBatchNorm1d,nn.LazyBatchNorm2d,nn.LazyBatchNorm3d,则能从input.size(1)推断出特征数,无需指定。

LeNet-5 + BN

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms


class Flatten(nn.Module):
    '''新版本的pytorch可直接使用nn.Flatten'''
    def forward(self, x):
        return x.flatten(1)


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv_bn_act = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.BatchNorm2d(6),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU(True),
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(True),
            nn.Linear(84, 10)
        )

        self.conv_act = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(True),
            nn.Linear(120, 84),
            nn.ReLU(True),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        x = self.conv_bn_act(x)
        return x


def train_loop(dataloader, model, loss_fn, optimizer, device):
    for i, data in enumerate(dataloader, 0):
        # 获取输入
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # 计算预测值和损失
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)

        # 反向传播优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('[Batch%4d] loss: %.3f' % (i + 1, loss.item()))



def test_loop(dataloader, model, device):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))


if __name__ == '__main__':
    # 设备
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    # 数据集
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # 标准化图像数据

    trainset = datasets.CIFAR10(root='./cifar10_data', train=True,
                                download=True, transform=transform)
    # 使用num_workers个子进程进行数据加载
    trainloader = DataLoader(trainset, batch_size=64,
                             shuffle=True, num_workers=2)

    testset = datasets.CIFAR10(root='./cifar10_data', train=False,
                               download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=64,
                            shuffle=False, num_workers=2)
    # 超参数
    lr = 0.01  # 选较大的学习率0.001->0.01
    epochs = 10
    # 模型实例
    model = LeNet5().to(device)
    # 损失函数实例
    loss_fn = nn.CrossEntropyLoss()
    # 优化器实例
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        model.train()
        train_loop(trainloader, model, loss_fn, optimizer, device=device)
        model.eval()
        test_loop(testloader, model, device=device)
    print("Done!")




不使用BN:

Epoch 1
-------------------------------
[Batch 100] loss: 1.904
[Batch 200] loss: 2.112
[Batch 300] loss: 1.721
[Batch 400] loss: 1.797
[Batch 500] loss: 1.863
[Batch 600] loss: 1.757
[Batch 700] loss: 1.891
Accuracy of the network on the 10000 test images: 33 %
Epoch 2
-------------------------------
[Batch 100] loss: 1.673
[Batch 200] loss: 1.708
[Batch 300] loss: 1.686
[Batch 400] loss: 1.736
[Batch 500] loss: 1.548
[Batch 600] loss: 1.646
[Batch 700] loss: 1.800
Accuracy of the network on the 10000 test images: 36 %
Epoch 3
-------------------------------
[Batch 100] loss: 1.754
[Batch 200] loss: 1.568
[Batch 300] loss: 1.581
[Batch 400] loss: 1.609
[Batch 500] loss: 1.700
[Batch 600] loss: 1.845
[Batch 700] loss: 1.626
Accuracy of the network on the 10000 test images: 39 %
Epoch 4
-------------------------------
[Batch 100] loss: 1.699
[Batch 200] loss: 1.585
[Batch 300] loss: 1.840
[Batch 400] loss: 1.688
[Batch 500] loss: 1.412
[Batch 600] loss: 1.569
[Batch 700] loss: 1.587
Accuracy of the network on the 10000 test images: 42 %
Epoch 5
-------------------------------
[Batch 100] loss: 1.727
[Batch 200] loss: 1.425
[Batch 300] loss: 1.699
[Batch 400] loss: 1.471
[Batch 500] loss: 1.702
[Batch 600] loss: 1.374
[Batch 700] loss: 1.497
Accuracy of the network on the 10000 test images: 41 %
Epoch 6
-------------------------------
[Batch 100] loss: 1.365
[Batch 200] loss: 1.664
[Batch 300] loss: 1.528
[Batch 400] loss: 1.444
[Batch 500] loss: 1.623
[Batch 600] loss: 1.382
[Batch 700] loss: 1.896
Accuracy of the network on the 10000 test images: 44 %
Epoch 7
-------------------------------
[Batch 100] loss: 1.783
[Batch 200] loss: 1.728
[Batch 300] loss: 1.500
[Batch 400] loss: 1.522
[Batch 500] loss: 1.400
[Batch 600] loss: 1.552
[Batch 700] loss: 1.482
Accuracy of the network on the 10000 test images: 44 %
Epoch 8
-------------------------------
[Batch 100] loss: 1.572
[Batch 200] loss: 1.088
[Batch 300] loss: 1.555
[Batch 400] loss: 1.380
[Batch 500] loss: 1.774
[Batch 600] loss: 1.589
[Batch 700] loss: 1.500
Accuracy of the network on the 10000 test images: 45 %
Epoch 9
-------------------------------
[Batch 100] loss: 1.411
[Batch 200] loss: 1.696
[Batch 300] loss: 1.494
[Batch 400] loss: 1.454
[Batch 500] loss: 1.401
[Batch 600] loss: 1.552
[Batch 700] loss: 1.766
Accuracy of the network on the 10000 test images: 48 %
Epoch 10
-------------------------------
[Batch 100] loss: 1.431
[Batch 200] loss: 1.309
[Batch 300] loss: 1.555
[Batch 400] loss: 1.436
[Batch 500] loss: 1.485
[Batch 600] loss: 1.440
[Batch 700] loss: 1.373
Accuracy of the network on the 10000 test images: 47 %
Done!

使用BN:

Epoch 1
-------------------------------
[Batch 100] loss: 1.571
[Batch 200] loss: 1.588
[Batch 300] loss: 1.443
[Batch 400] loss: 1.439
[Batch 500] loss: 1.209
[Batch 600] loss: 1.205
[Batch 700] loss: 0.996
Accuracy of the network on the 10000 test images: 55 %
Epoch 2
-------------------------------
[Batch 100] loss: 1.134
[Batch 200] loss: 1.395
[Batch 300] loss: 1.279
[Batch 400] loss: 1.043
[Batch 500] loss: 1.000
[Batch 600] loss: 1.141
[Batch 700] loss: 1.191
Accuracy of the network on the 10000 test images: 59 %
Epoch 3
-------------------------------
[Batch 100] loss: 1.456
[Batch 200] loss: 0.928
[Batch 300] loss: 0.987
[Batch 400] loss: 1.119
[Batch 500] loss: 1.186
[Batch 600] loss: 1.055
[Batch 700] loss: 0.952
Accuracy of the network on the 10000 test images: 62 %
Epoch 4
-------------------------------
[Batch 100] loss: 0.956
[Batch 200] loss: 0.979
[Batch 300] loss: 0.830
[Batch 400] loss: 1.061
[Batch 500] loss: 0.885
[Batch 600] loss: 0.904
[Batch 700] loss: 0.807
Accuracy of the network on the 10000 test images: 61 %
Epoch 5
-------------------------------
[Batch 100] loss: 0.843
[Batch 200] loss: 0.854
[Batch 300] loss: 0.993
[Batch 400] loss: 1.025
[Batch 500] loss: 0.898
[Batch 600] loss: 1.075
[Batch 700] loss: 0.654
Accuracy of the network on the 10000 test images: 63 %
Epoch 6
-------------------------------
[Batch 100] loss: 0.623
[Batch 200] loss: 0.704
[Batch 300] loss: 0.821
[Batch 400] loss: 1.147
[Batch 500] loss: 0.761
[Batch 600] loss: 1.032
[Batch 700] loss: 0.852
Accuracy of the network on the 10000 test images: 64 %
Epoch 7
-------------------------------
[Batch 100] loss: 0.718
[Batch 200] loss: 0.882
[Batch 300] loss: 0.855
[Batch 400] loss: 0.818
[Batch 500] loss: 0.888
[Batch 600] loss: 0.576
[Batch 700] loss: 0.963
Accuracy of the network on the 10000 test images: 65 %
Epoch 8
-------------------------------
[Batch 100] loss: 0.706
[Batch 200] loss: 0.515
[Batch 300] loss: 0.742
[Batch 400] loss: 0.491
[Batch 500] loss: 0.714
[Batch 600] loss: 0.878
[Batch 700] loss: 0.821
Accuracy of the network on the 10000 test images: 66 %
Epoch 9
-------------------------------
[Batch 100] loss: 0.814
[Batch 200] loss: 0.968
[Batch 300] loss: 0.729
[Batch 400] loss: 0.838
[Batch 500] loss: 0.649
[Batch 600] loss: 0.664
[Batch 700] loss: 0.692
Accuracy of the network on the 10000 test images: 67 %
Epoch 10
-------------------------------
[Batch 100] loss: 0.792
[Batch 200] loss: 0.560
[Batch 300] loss: 0.698
[Batch 400] loss: 0.857
[Batch 500] loss: 0.815
[Batch 600] loss: 0.853
[Batch 700] loss: 0.724
Accuracy of the network on the 10000 test images: 66 %
Done!
相关文章
|
1月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
17天前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
17天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】30. 神经网络中批量归一化层(batch normalization)的作用及其Pytorch实现
【从零开始学习深度学习】30. 神经网络中批量归一化层(batch normalization)的作用及其Pytorch实现
|
17天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
|
17天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】16. Pytorch中神经网络模型的构造方法:Module、Sequential、ModuleList、ModuleDict的区别
【从零开始学习深度学习】16. Pytorch中神经网络模型的构造方法:Module、Sequential、ModuleList、ModuleDict的区别
|
1月前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
36 1
|
17天前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
|
17天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】35. 门控循环神经网络之门控循环单元(gated recurrent unit,GRU)介绍、Pytorch实现GRU并进行训练预测
【从零开始学习深度学习】35. 门控循环神经网络之门控循环单元(gated recurrent unit,GRU)介绍、Pytorch实现GRU并进行训练预测
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
神经网络基本概念以及Pytorch实现,多线程编程面试题
神经网络基本概念以及Pytorch实现,多线程编程面试题
|
1月前
|
机器学习/深度学习 数据可视化 PyTorch
时空图神经网络ST-GNN的概念以及Pytorch实现
本文介绍了图神经网络(GNN)在处理各种领域中相互关联的图数据时的作用,如分子结构和社交网络。GNN与序列模型(如RNN)结合形成的时空图神经网络(ST-GNN)能捕捉时间和空间依赖性。文章通过图示和代码示例解释了GNN和ST-GNN的基本原理,展示了如何将GNN应用于股票市场的数据,尽管不推荐将其用于实际的股市预测。提供的PyTorch实现展示了如何将时间序列数据转换为图结构并训练ST-GNN模型。
48 1