【29】Knowledge Distillation: testing knowledge distillation and using learnable parameters to assist in training the Student model


1. Temperature Control


In my previous article, 知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍, I introduced the temperature control used in knowledge distillation. Here I apply different temperatures to a set of logits and visualize the results, to show concretely how the temperatures differ.


import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# logits produced for each class
logits = np.array([-5, 2, 7, 9])
labels = ['cat', 'dog', 'donkey', 'horse']


  • Standard softmax (T = 1)

For the standard softmax function, the probability assigned to a class $i$ is computed as:

$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$

# compute the standard softmax and plot it
# the standard softmax is simply the T = 1 case
softmax_1 = np.exp(logits) / sum(np.exp(logits))
plt.plot(softmax_1, label="softmax_1")
plt.legend()
plt.show()

(figure: softmax output for T = 1)

  • Knowledge-distillation softmax (T = k)

In knowledge distillation, the temperature coefficient is usually greater than 1. The larger T is, the smoother (softer) the output class probabilities become; the smaller T is, the larger the gaps between class probabilities and the sharper the curve.

For the knowledge-distillation softmax, the probability assigned to a class $i$ is computed as:

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

# compare the output probabilities under several different temperatures
T = 0.6
softmax_06 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_06, label="softmax_06")
T = 0.8
softmax_08 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_08, label="softmax_08")
# keep softmax_1 (T = 1) for comparison with the other temperatures
softmax_1 = np.exp(logits) / sum(np.exp(logits))
plt.plot(softmax_1, label="softmax_1")
T = 3
softmax_3 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_3, label="softmax_3")
T = 5
softmax_5 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_5, label="softmax_5")
T = 10
softmax_10 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_10, label="softmax_10")
T = 100
softmax_100 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_100, label="softmax_100")
plt.xticks(np.arange(4), labels=labels)
plt.legend()
plt.show()

(figure: softmax outputs for the different temperatures)
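As a rough hand check (approximate values computed from the formula above, not taken from the figure): for these logits, the T = 1 softmax is about [0.000, 0.001, 0.119, 0.880], while at T = 10 it flattens to roughly [0.096, 0.194, 0.320, 0.390]. The ranking of the classes is preserved, but the information carried by the non-maximum classes (the "dark knowledge") becomes much more visible.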


2. Learnable Parameters


Before the distillation experiments, a small warm-up: fit a straight line y = kx + b to two data points using two learnable parameters (nn.Parameter) and an optimizer. The same mechanism is reused in Section 4 to learn the distillation loss weight.

import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# set random seeds so that results are reproducible
def SetSeed(SEED=42):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
SetSeed()
# plotting helper: every few epochs, draw the data points and the current fitted line
def Drawing(points, params, figsize=(8, 4)):
    # points: [[x1,y1],[x2,y2]...]
    # params: [k, b]
    k = params[0].item()
    b = params[1].item()
    x = np.linspace(-5.,5.,100)
    y = k*x+b
    plt.figure(figsize=figsize)
    # scatter the data points
    plt.scatter(points[:, 0], points[:, 1], marker="^", c="blue")
    # draw the fitted line from params
    plt.plot(x, y, c="red")
    # figure formatting
    plt.title("line fit")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()
    plt.close()
# training hyperparameters
epoch = 40
learning_rate = 3e-2
# prepare the data
points = torch.tensor([[-1., -1.], [3., 2.]], dtype=torch.float32)
targets = points[:, 1]
inputs = points[:, 0]
print("inputs:",inputs, "targets:",targets)
# two learnable parameters
k = nn.Parameter(torch.randn(1), requires_grad=True)
b = nn.Parameter(torch.randn(1), requires_grad=True)
params = [k, b]
print(params)
# optimizer and loss function
optimizer = optim.Adam(params, lr=learning_rate)
criterion = nn.MSELoss()
# train the two parameters
loss_lists = []
for i in range(epoch):
    optimizer.zero_grad()
    outputs = inputs*k + b
    loss = criterion(outputs, targets)
    loss_lists.append(loss.item())
    loss.backward()
    optimizer.step()
    if (i+1) % 4 == 0:
        Drawing(points, [k,b])
#         print("outputs:",outputs)
#         print("k:", k)
#         print("b:", b)
# inspect the trained parameters
print("k:", k)
print("b:", b)


Output:

# output printed before the figures
inputs: tensor([-1.,  3.]) targets: tensor([-1.,  2.])
[Parameter containing:
tensor([-0.0223], requires_grad=True), Parameter containing:
tensor([0.3827], requires_grad=True)]
# output printed after the figures
k: Parameter containing:
tensor([0.7866], requires_grad=True)
b: Parameter containing:
tensor([-0.3185], requires_grad=True)

(figures: line-fit plots drawn every 4 epochs)


3. Knowledge Distillation


My idea here is to build two neural networks, one large and one small, and compare the small network's performance before and after knowledge distillation. (P.S. The large teacher network here could also be replaced with a CNN; a rough sketch of such a teacher is included in the code below.)

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import tqdm
from thop import profile
# define the teacher and student networks
class Teacher(nn.Module):
    def __init__(self, num_classes=10):
        super(Teacher, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, num_classes)
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)
class Student(nn.Module):
    def __init__(self, num_classes=10):
        super(Student, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 20),
            nn.ReLU(),
#             nn.Linear(20, 20),
#             nn.ReLU(),
            nn.Linear(20, num_classes)
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)
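# (Optional) As noted above, the MLP teacher could also be replaced by a CNN.
# This is only a rough sketch of mine and is NOT used in the experiments below.
class ConvTeacher(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvTeacher, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 1x28x28 -> 32x14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)   # 32x14x14 -> 64x7x7
        )
        self.classifier = nn.Linear(64 * 7 * 7, num_classes)
    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.view(x.size(0), -1))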
# measure the parameter count and FLOPs of both networks
def test_flops_params():
    x = torch.rand([1, 1, 28, 28])
#     x = x.reshape(1, 1, -1)
    T_network = Teacher()
    S_network = Student()
    t_flops, t_params = profile(T_network, inputs=(x, ))
    print('t_flops:{}e-3G'.format(t_flops / 1000000), 't_params:{}M'.format(t_params / 1000000))
    s_flops, s_params = profile(S_network, inputs=(x, ))
    print('s_flops:{}e-3G'.format(s_flops / 1000000), 's_params:{}M'.format(s_params / 1000000))
test_flops_params()


Output: as shown below, the teacher network's parameter count and FLOPs are both roughly 150 times those of the student network.

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.Teacher'>. Treat it as zero Macs and zero Params.
t_flops:2.3928e-3G t_params:2.39521M
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.Student'>. Treat it as zero Macs and zero Params.
s_flops:0.01588e-3G s_params:0.01591M
# hyperparameters
epoch_size = 5
batch_size = 128
learning_rate = 1e-4
# download the training set
train_data = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# download the test set
test_data = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
# define the models
device = torch.device('cuda')
t_model = Teacher().to(device)
s_model = Student().to(device)
# define the loss function (the optimizers are created per model below)
criterion = nn.CrossEntropyLoss().to(device)
# training loop (one epoch)
def train_one_epoch(model, criterion, optimizer, dataloader):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        outputs = model(image)
        loss = criterion(outputs, targets)
        train_loss += loss.item()
        # backprop and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # count correct predictions
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)
# evaluation loop
def validate(model, criterion, dataloader):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    test_info = "Test ==> loss:{}, acc:{} ({}/{})"\
          .format(test_loss/len(dataloader), correct/total, correct, total)
    print(test_info)


Now let's see how the teacher network performs:

# optimizer for the teacher network
t_optimizer = optim.Adam(t_model.parameters(), lr=learning_rate)
# train the teacher network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(t_model, criterion, t_optimizer, train_loader)
    validate(t_model, criterion, test_loader)
# after training, save the teacher's weights so they can be reloaded later when needed
torch.save(t_model.state_dict(), "t_model.mdl")


Output from training the teacher network:

[Epoch:0]
batch:0/469, loss:2.451192855834961, acc:0.046875 (6/128)
batch:93/469, loss:0.9547713903036523, acc:0.7139295212765957 (8590/12032)
batch:186/469, loss:0.6837506957232633, acc:0.7931149732620321 (18984/23936)
batch:279/469, loss:0.5656270824904953, acc:0.8287388392857142 (29702/35840)
batch:372/469, loss:0.49507016175234286, acc:0.8506618632707775 (40614/47744)
batch:465/469, loss:0.44825181595245656, acc:0.864538626609442 (51568/59648)
Test ==> loss:0.1785062526977515, acc:0.9461 (9461/10000)
[Epoch:1]
batch:0/469, loss:0.2513591945171356, acc:0.9375 (120/128)
batch:93/469, loss:0.22307029120782587, acc:0.9340093085106383 (11238/12032)
batch:186/469, loss:0.2112935145988184, acc:0.9371657754010695 (22432/23936)
batch:279/469, loss:0.2037411120853254, acc:0.9394252232142857 (33669/35840)
batch:372/469, loss:0.19692633960186318, acc:0.9414167225201072 (44947/47744)
batch:465/469, loss:0.19179708627352388, acc:0.9429821620171673 (56247/59648)
Test ==> loss:0.11782180916376506, acc:0.9622 (9622/10000)
[Epoch:2]
batch:0/469, loss:0.15417274832725525, acc:0.9453125 (121/128)
batch:93/469, loss:0.14990339277589576, acc:0.9549534574468085 (11490/12032)
batch:186/469, loss:0.14525708044196833, acc:0.9562583556149733 (22889/23936)
batch:279/469, loss:0.14779337254752006, acc:0.9551060267857143 (34231/35840)
batch:372/469, loss:0.1445239940433496, acc:0.9564133713136729 (45663/47744)
batch:465/469, loss:0.14156085961846068, acc:0.9572659603004292 (57099/59648)
Test ==> loss:0.09912304252480404, acc:0.9695 (9695/10000)
[Epoch:3]
batch:0/469, loss:0.09023044258356094, acc:0.984375 (126/128)
batch:93/469, loss:0.11060939039638702, acc:0.9670877659574468 (11636/12032)
batch:186/469, loss:0.11260852741605458, acc:0.9668699866310161 (23143/23936)
batch:279/469, loss:0.11275576776159661, acc:0.9667410714285715 (34648/35840)
batch:372/469, loss:0.11253649257023597, acc:0.9668440013404825 (46161/47744)
batch:465/469, loss:0.11281515193839314, acc:0.9665873122317596 (57655/59648)
Test ==> loss:0.0813662743231258, acc:0.9734 (9734/10000)
[Epoch:4]
batch:0/469, loss:0.10590803623199463, acc:0.9765625 (125/128)
batch:93/469, loss:0.0938354417523171, acc:0.9718251329787234 (11693/12032)
batch:186/469, loss:0.09741261341474591, acc:0.9707971256684492 (23237/23936)
batch:279/469, loss:0.0959280665631273, acc:0.9712332589285714 (34809/35840)
batch:372/469, loss:0.09434855888140745, acc:0.9716823056300268 (46392/47744)
batch:465/469, loss:0.09377776978481481, acc:0.9719521190987125 (57975/59648)
Test ==> loss:0.07517792291562014, acc:0.975 (9750/10000)


So far we have successfully trained a teacher model with a fairly large number of parameters. What happens if we simply train a student network directly?


Training the student network directly, and its output:

# optimizer for the student network
s_optimizer = optim.Adam(s_model.parameters(), lr=learning_rate)
# train the student network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(s_model, criterion, s_optimizer, train_loader)
    validate(s_model, criterion, test_loader)
[Epoch:0]
batch:0/469, loss:2.435654878616333, acc:0.0859375 (11/128)
batch:93/469, loss:1.8272213606124228, acc:0.4320146276595745 (5198/12032)
batch:186/469, loss:1.431586041170008, acc:0.59295621657754 (14193/23936)
batch:279/469, loss:1.1955812790564129, acc:0.6700892857142857 (24016/35840)
batch:372/469, loss:1.0443579082354784, acc:0.7167392761394102 (34220/47744)
batch:465/469, loss:0.9376831678233944, acc:0.7478205472103004 (44606/59648)
Test ==> loss:0.46302026283891895, acc:0.8882 (8882/10000)
[Epoch:1]
batch:0/469, loss:0.5130173563957214, acc:0.890625 (114/128)
batch:93/469, loss:0.45758569145456274, acc:0.8769946808510638 (10552/12032)
batch:186/469, loss:0.44705012376933173, acc:0.8820604946524064 (21113/23936)
batch:279/469, loss:0.43086895591446334, acc:0.8864676339285714 (31771/35840)
batch:372/469, loss:0.41966651623754014, acc:0.8885723860589813 (42424/47744)
batch:465/469, loss:0.4076770568993982, acc:0.8913123658798283 (53165/59648)
Test ==> loss:0.3379575525280796, acc:0.9091 (9091/10000)
[Epoch:2]
batch:0/469, loss:0.46021485328674316, acc:0.8671875 (111/128)
batch:93/469, loss:0.35463421569859727, acc:0.8988530585106383 (10815/12032)
batch:186/469, loss:0.3480192156717739, acc:0.9024481951871658 (21601/23936)
batch:279/469, loss:0.34022124212767396, acc:0.9047154017857143 (32425/35840)
batch:372/469, loss:0.33286028996549405, acc:0.9072553619302949 (43316/47744)
batch:465/469, loss:0.32942312831147036, acc:0.9081109173819742 (54167/59648)
Test ==> loss:0.291702199491519, acc:0.9215 (9215/10000)
[Epoch:3]
batch:0/469, loss:0.2687709629535675, acc:0.9453125 (121/128)
batch:93/469, loss:0.29896670643319473, acc:0.9164727393617021 (11027/12032)
batch:186/469, loss:0.3032062678413595, acc:0.9152322860962567 (21907/23936)
batch:279/469, loss:0.2976516788559301, acc:0.9162946428571429 (32840/35840)
batch:372/469, loss:0.2963846751735933, acc:0.9160941689008043 (43738/47744)
batch:465/469, loss:0.29447377907999595, acc:0.9167784334763949 (54684/59648)
Test ==> loss:0.2693275752701337, acc:0.9252 (9252/10000)
[Epoch:4]
batch:0/469, loss:0.21400471031665802, acc:0.9296875 (119/128)
batch:93/469, loss:0.2811283932087269, acc:0.922124335106383 (11095/12032)
batch:186/469, loss:0.2739176594796665, acc:0.9235461229946524 (22106/23936)
batch:279/469, loss:0.27122129941625256, acc:0.9234933035714286 (33098/35840)
batch:372/469, loss:0.2737213251737743, acc:0.9226290214477212 (44050/47744)
batch:465/469, loss:0.27158979208172646, acc:0.9234173819742489 (55080/59648)
Test ==> loss:0.25467539342898354, acc:0.9293 (9293/10000)


From the results above, the student network's best test accuracy is 0.9293, while the teacher network's best is 0.9750, so the teacher is clearly the stronger model. Next, let's apply knowledge distillation and see how much the teacher can improve the student.
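For reference, the distillation objective implemented below is the weighted sum $L = \alpha \cdot L_{hard} + (1 - \alpha) \cdot L_{soft}$, where $L_{hard}$ is the cross-entropy between the student's predictions and the true labels, and $L_{soft}$ compares the temperature-softened student and teacher outputs.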

# training loop with knowledge distillation (one epoch)
def train_one_epoch_kd(s_model, t_model, hard_loss, soft_loss, optimizer, dataloader):
    s_model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        # teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_preds = t_model(image)
        # student predictions
        student_preds = s_model(image)
        # hard loss: against the ground-truth labels
        student_loss = hard_loss(student_preds, targets)
        # soft loss: against the teacher's temperature-softened outputs
        distillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        # total loss: weighted sum of the hard loss and the soft loss
        loss = alpha * student_loss + (1-alpha) * distillation_loss
        train_loss += loss.item()
        # backprop and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # count correct predictions
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)
# knowledge-distillation hyperparameters
temp = 7      # distillation temperature
alpha = 0.3   # weight on the hard loss
# prepare a fresh student model / loss functions
kd_model = Student().to(device)
hard_loss = nn.CrossEntropyLoss()                # applies softmax internally
soft_loss = nn.KLDivLoss(reduction="batchmean")  # no built-in softmax, so we can apply our own temperature
# optimizer for the distilled student
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)
# train the student network with knowledge distillation
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch_kd(kd_model, t_model, hard_loss, soft_loss, kd_optimizer, train_loader)
    validate(kd_model, criterion, test_loader)
[Epoch:0]
batch:0/469, loss:-0.7217433452606201, acc:0.0625 (8/128)
batch:93/469, loss:-0.8849591688906893, acc:0.36826795212765956 (4431/12032)
batch:186/469, loss:-0.9996793939468057, acc:0.5336313502673797 (12773/23936)
batch:279/469, loss:-1.0756173231772015, acc:0.6284319196428572 (22523/35840)
batch:372/469, loss:-1.1279368988630278, acc:0.6846724195710456 (32689/47744)
batch:465/469, loss:-1.164531718252043, acc:0.7206444474248928 (42985/59648)
Test ==> loss:0.4428194357624537, acc:0.8843 (8843/10000)
[Epoch:1]
batch:0/469, loss:-1.308181643486023, acc:0.8828125 (113/128)
batch:93/469, loss:-1.333051804532396, acc:0.8811502659574468 (10602/12032)
batch:186/469, loss:-1.3385243995941896, acc:0.8837316176470589 (21153/23936)
batch:279/469, loss:-1.342562198638916, acc:0.8844029017857142 (31697/35840)
batch:372/469, loss:-1.3486905705193732, acc:0.8871062332439679 (42354/47744)
batch:465/469, loss:-1.3531442113188714, acc:0.8896190987124464 (53064/59648)
Test ==> loss:0.32884856549244895, acc:0.907 (9070/10000)
[Epoch:2]
batch:0/469, loss:-1.3262287378311157, acc:0.859375 (110/128)
batch:93/469, loss:-1.3759282218649032, acc:0.901595744680851 (10848/12032)
batch:186/469, loss:-1.3805451450500896, acc:0.9045788770053476 (21652/23936)
batch:279/469, loss:-1.381890092577253, acc:0.903515625 (32382/35840)
batch:372/469, loss:-1.3839336115936811, acc:0.9044277815013405 (43181/47744)
batch:465/469, loss:-1.3856471659287874, acc:0.9056296942060086 (54019/59648)
Test ==> loss:0.30868444205084933, acc:0.9142 (9142/10000)
[Epoch:3]
batch:0/469, loss:-1.3871097564697266, acc:0.90625 (116/128)
batch:93/469, loss:-1.398324375456952, acc:0.9108211436170213 (10959/12032)
batch:186/469, loss:-1.4016364086120523, acc:0.9119318181818182 (21828/23936)
batch:279/469, loss:-1.4019242635795048, acc:0.9117466517857142 (32677/35840)
batch:372/469, loss:-1.4023852696687222, acc:0.9129314678284183 (43587/47744)
batch:465/469, loss:-1.403283029666786, acc:0.9131907188841202 (54470/59648)
Test ==> loss:0.2895406449708757, acc:0.9191 (9191/10000)
[Epoch:4]
batch:0/469, loss:-1.4002737998962402, acc:0.8828125 (113/128)
batch:93/469, loss:-1.4150411626125903, acc:0.9187998670212766 (11055/12032)
batch:186/469, loss:-1.415715930933621, acc:0.9188669786096256 (21994/23936)
batch:279/469, loss:-1.4160895236900874, acc:0.9184709821428572 (32918/35840)
batch:372/469, loss:-1.4162878402116792, acc:0.9184400134048257 (43850/47744)
batch:465/469, loss:-1.4160627477158805, acc:0.9182202253218884 (54770/59648)
Test ==> loss:0.2826786540165732, acc:0.921 (9210/10000)


Unfortunately, this short test did not show any advantage from knowledge distillation…
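One thing worth double-checking here (this is my own assumption and was not tested in the runs above): nn.KLDivLoss expects log-probabilities as its first argument, and the standard formulation from Hinton et al. additionally multiplies the soft loss by $T^2$ so that its gradient magnitude stays comparable to the hard loss. A minimal sketch of that variant, reusing the identifiers from the code above:

# hedged sketch: the standard soft-loss form with log-probabilities and T^2 scaling
distillation_loss = soft_loss(
    F.log_softmax(student_preds / temp, dim=1),   # KLDivLoss input should be log-probabilities
    F.softmax(teacher_preds / temp, dim=1)        # the target stays as probabilities
) * (temp ** 2)                                   # T^2 compensates for the softened-gradient scaling
loss = alpha * student_loss + (1 - alpha) * distillation_loss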


4. Learnable Parameters for Knowledge Distillation


The experiment above used a fixed weight alpha to combine the distillation loss and the student loss. Could we instead define learnable parameters and let the loss weights be adjusted online during training? The experiment below is built around this idea.


Since I want learnable parameters to adjust the weights online, the training function needs to be redefined:

# training loop with knowledge distillation and a learnable loss weight (one epoch)
def train_one_epoch_kd(s_model, t_model, hard_loss, soft_loss, optimizer, parms_optimizer, dataloader):
    s_model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        # teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_preds = t_model(image)
        # student predictions
        student_preds = s_model(image)
        # hard loss: against the ground-truth labels
        student_loss = hard_loss(student_preds, targets)
        # soft loss: against the teacher's temperature-softened outputs
        distillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        # total loss: weighted sum of the hard loss and the soft loss
        # in my tests one learnable parameter was enough; using two actually worked worse
        # loss = alpha * student_loss + gama * distillation_loss
        loss = alpha * student_loss + (1 - alpha) * distillation_loss
        train_loss += loss.item()
        # backprop and update
        optimizer.zero_grad()  
        parms_optimizer.zero_grad() 
        loss.backward()
        optimizer.step()  # update the network parameters
        parms_optimizer.step()  # update the learnable loss weight(s)
        # count correct predictions
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)


Core code for knowledge distillation with learnable weight parameters:

# knowledge-distillation hyperparameters
temp = 7      # distillation temperature
alpha = nn.Parameter(torch.tensor(0.3), requires_grad=True)   # learnable weight coefficient
gama  = nn.Parameter(torch.tensor(0.7), requires_grad=True)   # learnable weight coefficient (only used in the two-parameter variant)
# params = [alpha, gama]     # two learnable parameters
params = [alpha, ]           # one learnable parameter
# print(params)
#  [Parameter containing:
#    tensor(0.3000, requires_grad=True)]
# prepare a fresh teacher/student model and the loss functions
t_model = Teacher().to(device)
t_model.load_state_dict(torch.load("./t_model.mdl"))   # load the trained Teacher weights
kd_model = Student().to(device)      # the Student model to be trained
hard_loss = nn.CrossEntropyLoss()                # applies softmax internally
soft_loss = nn.KLDivLoss(reduction="batchmean")  # no built-in softmax, so we can apply our own temperature
# define the optimizer(s) for the distilled student network
# kd_optimizer = optim.Adam([
#     {'params': kd_model.parameters()},
#     {'params': alpha, 'lr':3e-4}
# ], lr=learning_rate)
# the combined optimizer above did not work in my test, so two separate optimizers are used instead
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)
parms_optimizer = optim.Adam(params, lr=3e-5)
# train the student network with knowledge distillation and the learnable weight
for epoch in range(epoch_size):
    print("[Epoch:{}]\n".format(epoch), "[params:alpha:{},gama:{}]".format(alpha, gama))
    train_one_epoch_kd(kd_model, t_model, hard_loss, soft_loss, kd_optimizer, parms_optimizer, train_loader)
    validate(kd_model, criterion, test_loader)


Output:

[Epoch:0]
 [params:alpha:0.30000001192092896,gama:0.699999988079071]
batch:0/469, loss:-0.6614086031913757, acc:0.0859375 (11/128)
batch:93/469, loss:-0.8821925444805876, acc:0.3480718085106383 (4188/12032)
batch:186/469, loss:-0.9977140796375784, acc:0.5088987299465241 (12181/23936)
batch:279/469, loss:-1.073494029045105, acc:0.6039341517857143 (21645/35840)
batch:372/469, loss:-1.1280986780135944, acc:0.6661151139410187 (31803/47744)
batch:465/469, loss:-1.1685462095195132, acc:0.7057571083690987 (42097/59648)
Test ==> loss:0.46250298619270325, acc:0.8812 (8812/10000)
[Epoch:1]
 [params:alpha:0.2877715528011322,gama:0.699999988079071]
batch:0/469, loss:-1.3155369758605957, acc:0.8203125 (105/128)
batch:93/469, loss:-1.3573756332093097, acc:0.8779089095744681 (10563/12032)
batch:186/469, loss:-1.367437780859636, acc:0.8814756016042781 (21099/23936)
batch:279/469, loss:-1.3752499222755432, acc:0.8830357142857143 (31648/35840)
batch:372/469, loss:-1.3843675500266355, acc:0.8863103217158177 (42316/47744)
batch:465/469, loss:-1.391783864713022, acc:0.8887640826180258 (53013/59648)
Test ==> loss:0.3379697990191134, acc:0.9072 (9072/10000)
[Epoch:2]
 [params:alpha:0.2754247784614563,gama:0.699999988079071]
batch:0/469, loss:-1.4017305374145508, acc:0.9140625 (117/128)
batch:93/469, loss:-1.4349181017977126, acc:0.9048371010638298 (10887/12032)
batch:186/469, loss:-1.440221704901221, acc:0.9034508689839572 (21625/23936)
batch:279/469, loss:-1.4445673780781882, acc:0.9034040178571429 (32378/35840)
batch:372/469, loss:-1.450090474801153, acc:0.9046791219839142 (43193/47744)
batch:465/469, loss:-1.456252665990412, acc:0.905730284334764 (54025/59648)
Test ==> loss:0.2985522945093203, acc:0.9149 (9149/10000)
[Epoch:3]
 [params:alpha:0.2624278962612152,gama:0.699999988079071]
batch:0/469, loss:-1.4728816747665405, acc:0.9140625 (117/128)
batch:93/469, loss:-1.4891292845949213, acc:0.9151429521276596 (11011/12032)
batch:186/469, loss:-1.4935716301362145, acc:0.9136864973262032 (21870/23936)
batch:279/469, loss:-1.4979333805186408, acc:0.9130301339285715 (32723/35840)
batch:372/469, loss:-1.5017808774840735, acc:0.9121565013404825 (43550/47744)
batch:465/469, loss:-1.5070314509674203, acc:0.9135930793991416 (54494/59648)
Test ==> loss:0.28384389652858805, acc:0.9218 (9218/10000)
[Epoch:4]
 [params:alpha:0.2490001767873764,gama:0.699999988079071]
batch:0/469, loss:-1.536271572113037, acc:0.921875 (118/128)
batch:93/469, loss:-1.536556749901873, acc:0.914311835106383 (11001/12032)
batch:186/469, loss:-1.5403413033102924, acc:0.915817179144385 (21921/23936)
batch:279/469, loss:-1.5454376697540284, acc:0.9179129464285715 (32898/35840)
batch:372/469, loss:-1.5496390562594415, acc:0.9183352882037533 (43845/47744)
batch:465/469, loss:-1.5540172580485692, acc:0.9185387607296137 (54789/59648)
Test ==> loss:0.29046644185540044, acc:0.9234 (9234/10000)


Here I compared training with two learnable parameters against training with one. In my tests, a single learnable parameter performed better over the first 5 epochs and converged faster than two; and adjusting the weight online also gave slightly better results than the manually fixed distillation weight.


Unfortunately, distillation with online weight adjustment still did not beat simply training the small network directly. My guess is that the main reason is that the dataset, MNIST, is too easy: it is only a 10-class digit classification task, so a small model trained directly already performs quite well. I will try a harder dataset when I get the chance.

