[32] Joint Distillation Test with Multiple Teacher Networks

Summary: [32] A test of joint knowledge distillation with multiple teacher networks.

In my earlier knowledge-distillation example, I built a large teacher MLP to distill a small student MLP. Readers of that post may remember that the result was not great: the distilled student actually ended up worse than a student trained directly, which has bothered me ever since. P.S. here are the two earlier posts:

1. "Knowledge Distillation | Algorithm Principles and Other Extensions of Knowledge Distillation"

2. "[29] Knowledge Distillation Test, and Using Learnable Parameters to Assist Distillation Training of the Student Model"


So my thought was: if one teacher is not enough, add another, and make the second one a convolutional network. With both a convolutional teacher and a plain fully connected teacher, the student might pick up a bit more useful signal and improve.


Without further ado, here is the test as run in a Jupyter notebook:

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import alexnet
import random
import tqdm
import numpy as np
from thop import profile
# Set random seeds so the results are reproducible
def SetSeed(SEED=42):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
SetSeed()
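
Seeding alone does not always make GPU runs fully deterministic; if exact reproducibility matters, cuDNN's nondeterministic kernel selection can also be disabled. An optional sketch, not part of the run below:

# Optional: make cuDNN deterministic as well (may slow training down slightly)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False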


1. Defining the networks


#  Define a CNN teacher network (parameters not carefully tuned; loosely modeled on AlexNet)
class CNN_Teacher(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN_Teacher, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(256, num_classes),
            nn.Dropout(p=0.5)
        )
    def forward(self, x):
        x = self.model(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
# Define the MLP teacher network
class Teacher(nn.Module):
    def __init__(self, num_classes=10):
        super(Teacher, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, num_classes)
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)
# Define the student network
class Student(nn.Module):
    def __init__(self, num_classes=10):
        super(Student, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 20),
            nn.ReLU(),
            nn.Linear(20, num_classes)
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)
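
The thop package is imported above but not used afterwards; it offers a quick way to compare the sizes of the three networks. A minimal sketch, assuming a single MNIST-shaped dummy input (1x1x28x28):

# Rough size comparison of the three networks with thop.profile (optional sketch)
dummy = torch.randn(1, 1, 28, 28)
for name, net in [("CNN teacher", CNN_Teacher()), ("MLP teacher", Teacher()), ("Student", Student())]:
    ops, params = profile(net, inputs=(dummy,), verbose=False)
    print("{}: {:.2f}M ops, {:.2f}K params".format(name, ops / 1e6, params / 1e3))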


2. Dataset configuration


# Hyperparameters
epoch_size = 10
batch_size = 128
learning_rate = 1e-4
# Download the training set
train_data = datasets.MNIST('./', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Download the test set
test_data = datasets.MNIST('./', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
# Instantiate the models
device = torch.device('cuda')
t_model = Teacher().to(device)
s_model = Student().to(device)
# Define the loss function (optimizers are created per model below)
criterion = nn.CrossEntropyLoss().to(device)

The MNIST dataset is not on this server, so it gets downloaded first:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/9912422 [00:00<?, ?it/s]
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/28881 [00:00<?, ?it/s]
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/1648877 [00:00<?, ?it/s]
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
  0%|          | 0/4542 [00:00<?, ?it/s]
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
/home/fs/anaconda3/envs/yolo/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
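
As an optional sanity check (not part of the original notebook), one batch can be pulled from the train loader to confirm the tensor shapes the networks will see:

# Optional sanity check: inspect one batch from the training loader
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # expected: torch.Size([128, 1, 28, 28]) torch.Size([128])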


3. Training and testing


# Training loop for one epoch
def train_one_epoch(model, criterion, optimizer, dataloader):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        outputs = model(image)
        loss = criterion(outputs, targets)
        train_loss += loss.item()
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Count correct predictions
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)
# Evaluation loop
def validate(model, criterion, dataloader):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    test_info = "Test ==> loss:{}, acc:{} ({}/{})"\
          .format(test_loss/len(dataloader), correct/total, correct, total)
    print(test_info)


3.1 Training the MLP teacher network

# Define the optimizer for the teacher network
t_optimizer = optim.Adam(t_model.parameters(), lr=learning_rate)
# Train the teacher network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(t_model, criterion, t_optimizer, train_loader)
    validate(t_model, criterion, test_loader)
# After training, save the teacher's weights so they can be reloaded later when needed
torch.save(t_model.state_dict(), "t_model.mdl")


Output:


[Epoch:0]
batch:0/469, loss:2.3867623805999756, acc:0.0703125 (9/128)
batch:93/469, loss:0.9396106711727508, acc:0.7176695478723404 (8635/12032)
batch:186/469, loss:0.6751111858190699, acc:0.7959558823529411 (19052/23936)
batch:279/469, loss:0.560909423657826, acc:0.8306919642857142 (29772/35840)
batch:372/469, loss:0.4910400741141859, acc:0.8519604557640751 (40676/47744)
batch:465/469, loss:0.44459033963698175, acc:0.8668689645922747 (51707/59648)
Test ==> loss:0.18145106082098394, acc:0.9461 (9461/10000)
[Epoch:1]
batch:0/469, loss:0.20238660275936127, acc:0.9296875 (119/128)
batch:93/469, loss:0.2128744441619579, acc:0.9353390957446809 (11254/12032)
batch:186/469, loss:0.21077364432142381, acc:0.9370404411764706 (22429/23936)
batch:279/469, loss:0.20751248164368527, acc:0.937890625 (33614/35840)
batch:372/469, loss:0.20383650300170397, acc:0.938861427613941 (44825/47744)
batch:465/469, loss:0.19661776378526963, acc:0.9409368293991416 (56125/59648)
Test ==> loss:0.1174718544146494, acc:0.9613 (9613/10000)
[Epoch:2]
batch:0/469, loss:0.14401938021183014, acc:0.9296875 (119/128)
batch:93/469, loss:0.15022155472097246, acc:0.9551196808510638 (11492/12032)
batch:186/469, loss:0.15136209516761137, acc:0.9552139037433155 (22864/23936)
batch:279/469, loss:0.14719437270292213, acc:0.9559709821428571 (34262/35840)
batch:372/469, loss:0.14265749993017468, acc:0.9573349530831099 (45707/47744)
batch:465/469, loss:0.14064400804592816, acc:0.9580036212446352 (57143/59648)
Test ==> loss:0.09666164789961863, acc:0.9707 (9707/10000)
[Epoch:3]
batch:0/469, loss:0.0973980501294136, acc:0.96875 (124/128)
batch:93/469, loss:0.12336844725019121, acc:0.9634308510638298 (11592/12032)
batch:186/469, loss:0.11866759833566008, acc:0.9649481951871658 (23097/23936)
batch:279/469, loss:0.11617265337013773, acc:0.9653738839285714 (34599/35840)
batch:372/469, loss:0.114620619051498, acc:0.9653359584450402 (46089/47744)
batch:465/469, loss:0.11367744497571125, acc:0.9656484710300429 (57599/59648)
Test ==> loss:0.08241578724966207, acc:0.9736 (9736/10000)
[Epoch:4]
batch:0/469, loss:0.1517380326986313, acc:0.953125 (122/128)
batch:93/469, loss:0.09908725943495618, acc:0.9703291223404256 (11675/12032)
batch:186/469, loss:0.1005439937792041, acc:0.9687917780748663 (23189/23936)
batch:279/469, loss:0.09571810397984726, acc:0.9707868303571429 (34793/35840)
batch:372/469, loss:0.09503172322029083, acc:0.9709701742627346 (46358/47744)
batch:465/469, loss:0.0932354472793415, acc:0.9714491684549357 (57945/59648)
Test ==> loss:0.07466217042005892, acc:0.9777 (9777/10000)
[Epoch:5]
batch:0/469, loss:0.11435826867818832, acc:0.9453125 (121/128)
batch:93/469, loss:0.07882857798261846, acc:0.9756482712765957 (11739/12032)
batch:186/469, loss:0.07980489694976552, acc:0.9756433823529411 (23353/23936)
batch:279/469, loss:0.07940430943854153, acc:0.9759486607142858 (34978/35840)
batch:372/469, loss:0.07858294055824624, acc:0.9759341487935657 (46595/47744)
batch:465/469, loss:0.07859047989924578, acc:0.9759589592274678 (58214/59648)
Test ==> loss:0.06638586930812726, acc:0.9793 (9793/10000)
[Epoch:6]
batch:0/469, loss:0.041755419224500656, acc:0.9921875 (127/128)
batch:93/469, loss:0.072672658381944, acc:0.9777260638297872 (11764/12032)
batch:186/469, loss:0.07089516308337929, acc:0.9781918449197861 (23414/23936)
batch:279/469, loss:0.07038291455579124, acc:0.9781808035714286 (35058/35840)
batch:372/469, loss:0.07037656083852852, acc:0.9786988941018767 (46727/47744)
batch:465/469, loss:0.07013101020604053, acc:0.9785743025751072 (58370/59648)
Test ==> loss:0.06481085321276531, acc:0.9798 (9798/10000)
[Epoch:7]
batch:0/469, loss:0.08536697179079056, acc:0.9765625 (125/128)
batch:93/469, loss:0.05673933652368315, acc:0.9819647606382979 (11815/12032)
batch:186/469, loss:0.05650453631801003, acc:0.9825785427807486 (23519/23936)
batch:279/469, loss:0.05738637866951259, acc:0.9823381696428571 (35207/35840)
batch:372/469, loss:0.058486481106043584, acc:0.9817987600536193 (46875/47744)
batch:465/469, loss:0.06104880005676827, acc:0.9808040504291845 (58503/59648)
Test ==> loss:0.057878647279583764, acc:0.9822 (9822/10000)
[Epoch:8]
batch:0/469, loss:0.14456196129322052, acc:0.953125 (122/128)
batch:93/469, loss:0.05808326327539188, acc:0.9830452127659575 (11828/12032)
batch:186/469, loss:0.054870715339912134, acc:0.9829127673796791 (23527/23936)
batch:279/469, loss:0.05355608092754015, acc:0.9828125 (35224/35840)
batch:372/469, loss:0.052837840473683846, acc:0.9833277479892761 (46948/47744)
batch:465/469, loss:0.052799447139829016, acc:0.9832517435622318 (58649/59648)
Test ==> loss:0.05802279706054096, acc:0.9823 (9823/10000)
[Epoch:9]
batch:0/469, loss:0.02145610749721527, acc:0.9921875 (127/128)
batch:93/469, loss:0.04951975690795386, acc:0.9847074468085106 (11848/12032)
batch:186/469, loss:0.04819783547464858, acc:0.9857536764705882 (23595/23936)
batch:279/469, loss:0.04857771968735116, acc:0.9849051339285714 (35299/35840)
batch:372/469, loss:0.04838609968838919, acc:0.9847729557640751 (47017/47744)
batch:465/469, loss:0.04821777474155146, acc:0.9848779506437768 (58746/59648)
Test ==> loss:0.051948129977512206, acc:0.9838 (9838/10000)
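
The teacher's weights were saved above with torch.save; when needed later they can be loaded back into a fresh model instead of retraining. A minimal sketch:

# Reload a saved teacher checkpoint when needed (sketch)
t_model = Teacher().to(device)
t_model.load_state_dict(torch.load("t_model.mdl", map_location=device))
t_model.eval()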


3.2 Training the CNN teacher network

# Build the CNN teacher network
ct_model = CNN_Teacher().to(device)
# Define the optimizer for the CNN teacher network
ct_optimizer = optim.Adam(ct_model.parameters(), lr=learning_rate)
# Train the CNN teacher network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(ct_model, criterion, ct_optimizer, train_loader)
    validate(ct_model, criterion, test_loader)
# After training, save the CNN teacher's weights so they can be reloaded later when needed
torch.save(ct_model.state_dict(), "ct_model.mdl")


Output:

[Epoch:0]
batch:0/469, loss:2.424455404281616, acc:0.0859375 (11/128)
batch:93/469, loss:1.961557946306594, acc:0.3346908244680851 (4027/12032)
batch:186/469, loss:1.7907984537236832, acc:0.3976019385026738 (9517/23936)
batch:279/469, loss:1.6823635867663793, acc:0.4309151785714286 (15444/35840)
batch:372/469, loss:1.5959120215423626, acc:0.4554289544235925 (21744/47744)
batch:465/469, loss:1.5276136935524673, acc:0.4724718347639485 (28182/59648)
Test ==> loss:0.7926474839826173, acc:0.93 (9300/10000)
[Epoch:1]
batch:0/469, loss:1.2736332416534424, acc:0.546875 (70/128)
batch:93/469, loss:1.2054280004602798, acc:0.5482047872340425 (6596/12032)
batch:186/469, loss:1.172185863721817, acc:0.5583639705882353 (13365/23936)
batch:279/469, loss:1.157212756574154, acc:0.559375 (20048/35840)
batch:372/469, loss:1.1450653734539533, acc:0.5607406166219839 (26772/47744)
batch:465/469, loss:1.1353451458681294, acc:0.5617791040772532 (33509/59648)
Test ==> loss:0.47211516431615325, acc:0.9597 (9597/10000)
[Epoch:2]
batch:0/469, loss:1.0209004878997803, acc:0.59375 (76/128)
batch:93/469, loss:1.0488138236898057, acc:0.58203125 (7003/12032)
batch:186/469, loss:1.0366387905921528, acc:0.5831801470588235 (13959/23936)
batch:279/469, loss:1.0379441861595426, acc:0.5789620535714286 (20750/35840)
batch:372/469, loss:1.03275924745258, acc:0.5796539879356568 (27675/47744)
batch:465/469, loss:1.0288827585560059, acc:0.5789800160944206 (34535/59648)
Test ==> loss:0.332201937331429, acc:0.9632 (9632/10000)
[Epoch:3]
batch:0/469, loss:0.9517718553543091, acc:0.6171875 (79/128)
batch:93/469, loss:0.9978312501247893, acc:0.5810339095744681 (6991/12032)
batch:186/469, loss:1.0023587601070099, acc:0.5776654411764706 (13827/23936)
batch:279/469, loss:0.9987784639000893, acc:0.5787667410714286 (20743/35840)
batch:372/469, loss:0.9959872414535236, acc:0.5788790214477212 (27638/47744)
batch:465/469, loss:0.991175599492159, acc:0.5813606491416309 (34677/59648)
Test ==> loss:0.2696448630547222, acc:0.9666 (9666/10000)
[Epoch:4]
batch:0/469, loss:0.7853888869285583, acc:0.6875 (88/128)
batch:93/469, loss:0.9590603220970073, acc:0.5880152925531915 (7075/12032)
batch:186/469, loss:0.9617624403958652, acc:0.5840157085561497 (13979/23936)
batch:279/469, loss:0.9641852915287018, acc:0.5849330357142857 (20964/35840)
batch:372/469, loss:0.9647335381354467, acc:0.5839896112600537 (27882/47744)
batch:465/469, loss:0.9653625727467271, acc:0.583506571888412 (34805/59648)
Test ==> loss:0.2166517706988733, acc:0.9722 (9722/10000)
[Epoch:5]
batch:0/469, loss:0.8839341402053833, acc:0.59375 (76/128)
batch:93/469, loss:0.9494094988133045, acc:0.5852726063829787 (7042/12032)
batch:186/469, loss:0.9495056312989424, acc:0.5849348262032086 (14001/23936)
batch:279/469, loss:0.9469595074653625, acc:0.5874720982142857 (21055/35840)
batch:372/469, loss:0.9451027883281656, acc:0.5875502680965148 (28052/47744)
batch:465/469, loss:0.9445154733412255, acc:0.5877648873390557 (35059/59648)
Test ==> loss:0.1983407870689525, acc:0.9739 (9739/10000)
[Epoch:6]
batch:0/469, loss:0.9596455693244934, acc:0.59375 (76/128)
batch:93/469, loss:0.9371558269287678, acc:0.5885970744680851 (7082/12032)
batch:186/469, loss:0.933486777193406, acc:0.5869401737967914 (14049/23936)
batch:279/469, loss:0.9356063387223652, acc:0.5876953125 (21063/35840)
batch:372/469, loss:0.9321824889080774, acc:0.5887231903485255 (28108/47744)
batch:465/469, loss:0.9308745358379102, acc:0.588955203862661 (35130/59648)
Test ==> loss:0.15833796409866477, acc:0.975 (9750/10000)
[Epoch:7]
batch:0/469, loss:0.9392691254615784, acc:0.609375 (78/128)
batch:93/469, loss:0.9271161625994012, acc:0.5857712765957447 (7048/12032)
batch:186/469, loss:0.9292944392418478, acc:0.5856450534759359 (14018/23936)
batch:279/469, loss:0.9243043863347599, acc:0.5865234375 (21021/35840)
batch:372/469, loss:0.9234551397469344, acc:0.5877178284182306 (28060/47744)
batch:465/469, loss:0.9230064044950346, acc:0.5870775214592274 (35018/59648)
Test ==> loss:0.14993384857720968, acc:0.9756 (9756/10000)
[Epoch:8]
batch:0/469, loss:0.963413655757904, acc:0.5703125 (73/128)
batch:93/469, loss:0.9203153658420482, acc:0.5882646276595744 (7078/12032)
batch:186/469, loss:0.9135349348267132, acc:0.5940006684491979 (14218/23936)
batch:279/469, loss:0.9185146927833557, acc:0.5903459821428572 (21158/35840)
batch:372/469, loss:0.9155782517095673, acc:0.5907548592493298 (28205/47744)
batch:465/469, loss:0.9142030574733095, acc:0.5917884924892703 (35299/59648)
Test ==> loss:0.142572170005569, acc:0.9755 (9755/10000)
[Epoch:9]
batch:0/469, loss:0.8696589469909668, acc:0.6171875 (79/128)
batch:93/469, loss:0.9138033472477122, acc:0.5925864361702128 (7130/12032)
batch:186/469, loss:0.9100717417696581, acc:0.5940842245989305 (14220/23936)
batch:279/469, loss:0.9118563358272825, acc:0.5920479910714286 (21219/35840)
batch:372/469, loss:0.9089114867330557, acc:0.592912198391421 (28308/47744)
batch:465/469, loss:0.9042545945859263, acc:0.5938170600858369 (35420/59648)
Test ==> loss:0.13352179668749434, acc:0.9753 (9753/10000)


3.3 Training the student network

# Build the student network
s_model = Student().to(device)
# Define the optimizer for the student network
s_optimizer = optim.Adam(s_model.parameters(), lr=learning_rate)
# Train the student network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(s_model, criterion, s_optimizer, train_loader)
    validate(s_model, criterion, test_loader)


Output:


[Epoch:0]
batch:0/469, loss:2.357004404067993, acc:0.1875 (24/128)
batch:93/469, loss:1.7883413776438286, acc:0.4625166223404255 (5565/12032)
batch:186/469, loss:1.4309818958216172, acc:0.5944184491978609 (14228/23936)
batch:279/469, loss:1.1971364739750112, acc:0.6667131696428571 (23895/35840)
batch:372/469, loss:1.0451311646453816, acc:0.7116915214477212 (33979/47744)
batch:465/469, loss:0.9389211163884069, acc:0.7416845493562232 (44240/59648)
Test ==> loss:0.46114580276646194, acc:0.8833 (8833/10000)
[Epoch:1]
batch:0/469, loss:0.5253818035125732, acc:0.84375 (108/128)
batch:93/469, loss:0.45607277180286165, acc:0.8744182180851063 (10521/12032)
batch:186/469, loss:0.44302085488237797, acc:0.8799715909090909 (21063/23936)
batch:279/469, loss:0.42559995853475163, acc:0.8859095982142857 (31751/35840)
batch:372/469, loss:0.4132335932261183, acc:0.8879649798927614 (42395/47744)
batch:465/469, loss:0.4041708303943212, acc:0.8905914699570815 (53122/59648)
Test ==> loss:0.34390119095391863, acc:0.9063 (9063/10000)
[Epoch:2]
batch:0/469, loss:0.27973929047584534, acc:0.9140625 (117/128)
batch:93/469, loss:0.3445089725737876, acc:0.9063331117021277 (10905/12032)
batch:186/469, loss:0.33749938792085904, acc:0.9072526737967914 (21716/23936)
batch:279/469, loss:0.3341189743684871, acc:0.9078125 (32536/35840)
batch:372/469, loss:0.33311483519646184, acc:0.9080722184986595 (43355/47744)
batch:465/469, loss:0.32733086157638114, acc:0.9092509388412017 (54235/59648)
Test ==> loss:0.29618216881269144, acc:0.9179 (9179/10000)
[Epoch:3]
batch:0/469, loss:0.3462095856666565, acc:0.8984375 (115/128)
batch:93/469, loss:0.3092869266550592, acc:0.9137300531914894 (10994/12032)
batch:186/469, loss:0.30020128891748543, acc:0.9171958556149733 (21954/23936)
batch:279/469, loss:0.2974837499537638, acc:0.9169084821428571 (32862/35840)
batch:372/469, loss:0.29631629350517774, acc:0.9173927613941019 (43800/47744)
batch:465/469, loss:0.29146103778366367, acc:0.9190584763948498 (54820/59648)
Test ==> loss:0.2715968393449542, acc:0.9234 (9234/10000)
[Epoch:4]
batch:0/469, loss:0.3141135573387146, acc:0.890625 (114/128)
batch:93/469, loss:0.2594475950649444, acc:0.9300199468085106 (11190/12032)
batch:186/469, loss:0.26877123388377105, acc:0.9263870320855615 (22174/23936)
batch:279/469, loss:0.27071290723979474, acc:0.9251674107142858 (33158/35840)
batch:372/469, loss:0.2679654088560442, acc:0.9255403820375335 (44189/47744)
batch:465/469, loss:0.2687629308119864, acc:0.9247250536480687 (55158/59648)
Test ==> loss:0.25750901823556877, acc:0.9282 (9282/10000)
[Epoch:5]
batch:0/469, loss:0.17168357968330383, acc:0.96875 (124/128)
batch:93/469, loss:0.2522462420165539, acc:0.9288563829787234 (11176/12032)
batch:186/469, loss:0.2581011697171844, acc:0.9273897058823529 (22198/23936)
batch:279/469, loss:0.2550680588132569, acc:0.9282924107142857 (33270/35840)
batch:372/469, loss:0.2549391207722173, acc:0.9280956769436998 (44311/47744)
batch:465/469, loss:0.25167093422791476, acc:0.9286145386266095 (55390/59648)
Test ==> loss:0.24905249646192865, acc:0.9306 (9306/10000)
[Epoch:6]
batch:0/469, loss:0.26825278997421265, acc:0.90625 (116/128)
batch:93/469, loss:0.2381658841796378, acc:0.9311003989361702 (11203/12032)
batch:186/469, loss:0.23873208303821278, acc:0.9309408422459893 (22283/23936)
batch:279/469, loss:0.24032189510762691, acc:0.9308035714285714 (33360/35840)
batch:372/469, loss:0.23933973307184495, acc:0.9313421581769437 (44466/47744)
batch:465/469, loss:0.23892121092124557, acc:0.931749597639485 (55577/59648)
Test ==> loss:0.23393861119505727, acc:0.9336 (9336/10000)
[Epoch:7]
batch:0/469, loss:0.3488951325416565, acc:0.921875 (118/128)
batch:93/469, loss:0.23248760267458063, acc:0.9334275265957447 (11231/12032)
batch:186/469, loss:0.22965810611286266, acc:0.9336981951871658 (22349/23936)
batch:279/469, loss:0.22631887934569803, acc:0.9351841517857142 (33517/35840)
batch:372/469, loss:0.22711910750846762, acc:0.9349028150134048 (44636/47744)
batch:465/469, loss:0.2285995162991495, acc:0.9348175965665236 (55760/59648)
Test ==> loss:0.22850063462046127, acc:0.9344 (9344/10000)
[Epoch:8]
batch:0/469, loss:0.2651296854019165, acc:0.90625 (116/128)
batch:93/469, loss:0.22513935580215555, acc:0.9361702127659575 (11264/12032)
batch:186/469, loss:0.22959564474814717, acc:0.9356199866310161 (22395/23936)
batch:279/469, loss:0.2243422551612769, acc:0.9376116071428572 (33604/35840)
batch:372/469, loss:0.2194721238742565, acc:0.9387567024128687 (44820/47744)
batch:465/469, loss:0.22032797501258583, acc:0.9381538358369099 (55959/59648)
Test ==> loss:0.2206512165220478, acc:0.9351 (9351/10000)
[Epoch:9]
batch:0/469, loss:0.2756182849407196, acc:0.921875 (118/128)
batch:93/469, loss:0.2252512621752759, acc:0.9366688829787234 (11270/12032)
batch:186/469, loss:0.22236462188436384, acc:0.936956885026738 (22427/23936)
batch:279/469, loss:0.2201754414077316, acc:0.9374441964285715 (33598/35840)
batch:372/469, loss:0.21374527243123298, acc:0.9391756032171582 (44840/47744)
batch:465/469, loss:0.21209755683789439, acc:0.9397465128755365 (56054/59648)
Test ==> loss:0.21237921054604686, acc:0.9375 (9375/10000)

We now have two teacher networks: an MLP with 0.9838 accuracy and a CNN with 0.9753 accuracy. Training the student network directly gives 0.9375. The next thing to try is having both teacher networks jointly distill the student.


3.4 Multi-teacher joint distillation training

Below, the training function is rebuilt for multi-teacher knowledge distillation:


# Training loop for one epoch with two teachers
def train_one_epoch_kd(s_model, t_model, ct_model, hard_loss, soft_loss, optimizer, parms_optimizer, dataloader):
    s_model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_1_preds = t_model(image)
            teacher_2_preds = ct_model(image)
        # Student prediction
        student_preds = s_model(image)
        # Hard loss against the ground-truth labels
        student_loss = hard_loss(student_preds, targets)
        # Soft losses against the two teachers' outputs
        distillation_1_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_1_preds / temp, dim=1)
        )
        distillation_2_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_2_preds / temp, dim=1)
        )
        # The total loss is a weighted sum of the hard loss and the two soft losses
        # Optionally, use learnable weights for the distillation terms instead of fixed ones
#         loss = alpha * distillation_1_loss + \
#                 gama * distillation_2_loss + \
#                 (1 - alpha - gama) * student_loss
        loss = 0.2 * distillation_1_loss + \
                0.2 * distillation_2_loss + \
                0.6 * student_loss
        train_loss += loss.item()
        # Backward pass and parameter update
        optimizer.zero_grad()
#         parms_optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#         parms_optimizer.step()
        # Count correct predictions
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)
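
One detail worth flagging: nn.KLDivLoss expects log-probabilities as its first argument, so the textbook Hinton-style soft loss usually applies log_softmax to the student logits and multiplies the term by T^2 so its gradient scale stays comparable to the hard loss. The function above is left exactly as it was run; a sketch of the more standard formulation, for reference:

# Standard KD soft loss (sketch): log_softmax on the student, softmax on the teacher, scaled by T^2
def kd_soft_loss(student_logits, teacher_logits, T):
    return F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)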


Joint distillation training

# Knowledge-distillation hyperparameters
temp = 7      # distillation temperature
alpha = nn.Parameter(torch.tensor(0.35), requires_grad=True)   # weight coefficient
gama  = nn.Parameter(torch.tensor(0.35), requires_grad=True)   # weight coefficient
params = [alpha, gama]     # two learnable weight parameters
# params = [alpha, ]           # a single learnable weight parameter
print(params)
# Loss functions for the new student model
hard_loss = nn.CrossEntropyLoss()                # applies softmax internally
soft_loss = nn.KLDivLoss(reduction="batchmean")  # does not apply softmax itself (so the temperature can be applied manually)
# Build the student model to be distilled
kd_model = Student().to(device)
# Optimizers for the distilled student and the learnable weights
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)
parms_optimizer = optim.Adam(params, lr=3e-5)
# Train the student network with knowledge distillation
for epoch in range(epoch_size):
    print("[Epoch:{}]\n".format(epoch), 
          "[weight | alpha:{},gama:{},st:{}]"
              .format(alpha, gama, 1-alpha-gama))
    train_one_epoch_kd(kd_model, t_model, ct_model, hard_loss, soft_loss, kd_optimizer, parms_optimizer, train_loader)
    validate(kd_model, criterion, test_loader)


Output:


[Parameter containing:
tensor(0.2000, requires_grad=True), Parameter containing:
tensor(0.2000, requires_grad=True)]
[Epoch:0]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:1.070181131362915, acc:0.1484375 (19/128)
batch:93/469, loss:0.7189208091573512, acc:0.3718417553191489 (4474/12032)
batch:186/469, loss:0.4570929299701344, acc:0.5301637700534759 (12690/23936)
batch:279/469, loss:0.2827999550317015, acc:0.6219587053571428 (22291/35840)
batch:372/469, loss:0.16214995130137527, acc:0.6787449731903485 (32406/47744)
batch:465/469, loss:0.07729014530458164, acc:0.7165034871244635 (42738/59648)
Test ==> loss:0.4550707121438618, acc:0.883 (8830/10000)
[Epoch:1]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.21972870826721191, acc:0.7890625 (101/128)
batch:93/469, loss:-0.3207302990745991, acc:0.8866356382978723 (10668/12032)
batch:186/469, loss:-0.3327708426006338, acc:0.8868231951871658 (21227/23936)
batch:279/469, loss:-0.34525985323957037, acc:0.8899832589285714 (31897/35840)
batch:372/469, loss:-0.3539618340797782, acc:0.891504691689008 (42564/47744)
batch:465/469, loss:-0.36126882240751784, acc:0.8931732832618026 (53276/59648)
Test ==> loss:0.320476306578781, acc:0.9118 (9118/10000)
[Epoch:2]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.42595481872558594, acc:0.8984375 (115/128)
batch:93/469, loss:-0.4049103111028671, acc:0.9043384308510638 (10881/12032)
batch:186/469, loss:-0.4099318265596176, acc:0.9061246657754011 (21689/23936)
batch:279/469, loss:-0.41254837651337894, acc:0.9076450892857143 (32530/35840)
batch:372/469, loss:-0.4177393170208458, acc:0.9097687667560321 (43436/47744)
batch:465/469, loss:-0.4201628497997579, acc:0.91037419527897 (54302/59648)
Test ==> loss:0.2765346652344812, acc:0.9203 (9203/10000)
[Epoch:3]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.47607481479644775, acc:0.9375 (120/128)
batch:93/469, loss:-0.4405997947175452, acc:0.9193816489361702 (11062/12032)
batch:186/469, loss:-0.44154514938114797, acc:0.9194100935828877 (22007/23936)
batch:279/469, loss:-0.44397371081369263, acc:0.9194754464285714 (32954/35840)
batch:372/469, loss:-0.4449578410179302, acc:0.919110254691689 (43882/47744)
batch:465/469, loss:-0.44588555446765965, acc:0.9195111319742489 (54847/59648)
Test ==> loss:0.25252004533628875, acc:0.9255 (9255/10000)
[Epoch:4]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.46281740069389343, acc:0.9296875 (119/128)
batch:93/469, loss:-0.4547458785645505, acc:0.921376329787234 (11086/12032)
batch:186/469, loss:-0.46090967977110714, acc:0.9243816844919787 (22126/23936)
batch:279/469, loss:-0.4611783997288772, acc:0.9245814732142857 (33137/35840)
batch:372/469, loss:-0.46058367931810845, acc:0.9252262064343163 (44174/47744)
batch:465/469, loss:-0.4612702508596903, acc:0.9251944742489271 (55186/59648)
Test ==> loss:0.23789097774255125, acc:0.9292 (9292/10000)
[Epoch:5]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.48149824142456055, acc:0.9375 (120/128)
batch:93/469, loss:-0.47065415629681123, acc:0.9292719414893617 (11181/12032)
batch:186/469, loss:-0.46971190995711054, acc:0.9272225935828877 (22194/23936)
batch:279/469, loss:-0.4715161386345114, acc:0.928515625 (33278/35840)
batch:372/469, loss:-0.472420722326069, acc:0.9290591487935657 (44357/47744)
batch:465/469, loss:-0.47325622069733336, acc:0.9294527896995708 (55440/59648)
Test ==> loss:0.22742461035900477, acc:0.9326 (9326/10000)
[Epoch:6]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.5146299600601196, acc:0.96875 (124/128)
batch:93/469, loss:-0.47914040120358165, acc:0.932845744680851 (11224/12032)
batch:186/469, loss:-0.4779144224317316, acc:0.931442179144385 (22295/23936)
batch:279/469, loss:-0.48003286900264874, acc:0.9315848214285715 (33388/35840)
batch:372/469, loss:-0.48055645328104973, acc:0.9321171246648794 (44503/47744)
batch:465/469, loss:-0.48185253974705805, acc:0.933074034334764 (55656/59648)
Test ==> loss:0.22472369784041296, acc:0.9336 (9336/10000)
[Epoch:7]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.44770073890686035, acc:0.90625 (116/128)
batch:93/469, loss:-0.4870964447234539, acc:0.9355053191489362 (11256/12032)
batch:186/469, loss:-0.48533328188294395, acc:0.9351186497326203 (22383/23936)
batch:279/469, loss:-0.48802188954183034, acc:0.9361607142857142 (33552/35840)
batch:372/469, loss:-0.4895784726251546, acc:0.9360338471849866 (44690/47744)
batch:465/469, loss:-0.4890587088759877, acc:0.9358570278969958 (55822/59648)
Test ==> loss:0.2101580535026291, acc:0.9362 (9362/10000)
[Epoch:8]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.4669145345687866, acc:0.9140625 (117/128)
batch:93/469, loss:-0.4905587165279591, acc:0.9385804521276596 (11293/12032)
batch:186/469, loss:-0.4956321018264893, acc:0.9393382352941176 (22484/23936)
batch:279/469, loss:-0.4955592581204006, acc:0.9392299107142857 (33662/35840)
batch:372/469, loss:-0.4950189735870259, acc:0.9388195375335121 (44823/47744)
batch:465/469, loss:-0.49565435664592383, acc:0.9390423819742489 (56012/59648)
Test ==> loss:0.20754675633167918, acc:0.9388 (9388/10000)
[Epoch:9]
 [weight | alpha:0.20000000298023224,gama:0.20000000298023224,st:0.6000000238418579]
batch:0/469, loss:-0.5581856966018677, acc:0.984375 (126/128)
batch:93/469, loss:-0.5014005094132525, acc:0.9423204787234043 (11338/12032)
batch:186/469, loss:-0.49853202844048566, acc:0.9412182486631016 (22529/23936)
batch:279/469, loss:-0.5006332156913621, acc:0.9414341517857143 (33741/35840)
batch:372/469, loss:-0.5006915050441394, acc:0.9413329423592494 (44943/47744)
batch:465/469, loss:-0.5015091910126895, acc:0.9409703594420601 (56127/59648)
Test ==> loss:0.20188309327711032, acc:0.9411 (9411/10000)


In the end, the jointly distilled student reaches 0.9411 accuracy, better than the 0.9375 obtained by training the student model on its own.


As you can see, with suitable hyperparameters, joint distillation does beat training the student network directly. The catch is exactly that: the hyperparameters have to be tuned well. Tuning is a painful business; tuned badly, knowledge distillation simply does not deliver, and sometimes it even hurts, leaving the result worse than not using it at all.


Also, I originally used two learnable weight parameters to balance the CNN teacher and the MLP teacher, but that did not turn out better than fixed weights. After some testing, giving each teacher a weight of about 0.2 worked best.
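
If the learnable-weight route is revisited, one option is to parameterize all three coefficients through a softmax so they stay positive and sum to 1, instead of optimizing alpha and gama independently. A minimal sketch of that variant (hypothetical, not what was run above):

# Learnable mixing weights constrained to sum to 1 via softmax (sketch)
mix_logits = nn.Parameter(torch.zeros(3))        # order: [teacher 1, teacher 2, hard label]
parms_optimizer = optim.Adam([mix_logits], lr=3e-5)
# inside the training loop:
w = F.softmax(mix_logits, dim=0)
loss = w[0] * distillation_1_loss + w[1] * distillation_2_loss + w[2] * student_loss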


The takeaway from this exercise: hyperparameter tuning is a real craft, and deep learning really does feel like alchemy at times.

