1. Temperature Control
In my previous article, 知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍, I introduced temperature control in knowledge distillation. Here I show what different temperatures do to the logits, with a concrete visualization of the differences between temperature values.
```python
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
```
```python
# Set the logits for each class
logits = np.array([-5, 2, 7, 9])
labels = ['cat', 'dog', 'donkey', 'horse']
```
- Standard softmax (T = 1)
For the standard softmax function, the probability of a given class is computed as:
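$$q_i = \frac{\exp(z_i)}{\sum_{j}\exp(z_j)}$$

where $z_i$ is the logit of class $i$ and $q_i$ is the resulting probability.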
```python
# Compute the standard softmax and plot it
# The standard softmax is simply the T = 1 case
softmax_1 = np.exp(logits) / sum(np.exp(logits))
plt.plot(softmax_1, label="softmax_1")
plt.legend()
plt.show()
```
- Knowledge-distillation softmax (T = k)
In knowledge distillation the temperature is usually greater than 1. The larger T is, the smoother the output class probabilities become; the smaller T is, the larger the gaps between class probabilities and the sharper the curve.
For the knowledge-distillation softmax, the probability of a given class is computed as:
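$$q_i = \frac{\exp(z_i / T)}{\sum_{j}\exp(z_j / T)}$$

where $T$ is the temperature; setting $T = 1$ recovers the standard softmax.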
```python
# Vary the temperature to show the difference in the output probabilities
T = 0.6
softmax_06 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_06, label="softmax_06")

T = 0.8
softmax_08 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_08, label="softmax_08")

# Keep softmax_1 for comparison with the other temperatures
softmax_1 = np.exp(logits) / sum(np.exp(logits))
plt.plot(softmax_1, label="softmax_1")

T = 3
softmax_3 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_3, label="softmax_3")

T = 5
softmax_5 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_5, label="softmax_5")

T = 10
softmax_10 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_10, label="softmax_10")

T = 100
softmax_100 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_100, label="softmax_100")

plt.xticks(np.arange(4), labels=labels)
plt.legend()
plt.show()
```
2. Learnable Parameters
```python
import random

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
```
```python
# Fix the random seeds so the results are reproducible
def SetSeed(SEED=42):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

SetSeed()
```
```python
# Plotting helper: every few epochs, draw the data points and the fitted line
def Drawing(points, params, figsize=(8, 4)):
    # points: [[x1, y1], [x2, y2], ...]
    # params: [k, b]
    k = params[0].item()
    b = params[1].item()
    x = np.linspace(-5., 5., 100)
    y = k*x + b
    plt.figure(figsize=figsize)
    # points: scatter the data points
    plt.scatter(points[:, 0], points[:, 1], marker="^", c="blue")
    # params: draw the current fitted line
    plt.plot(x, y, c="red")
    # Figure formatting
    plt.title("line fit")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()
    plt.close()
```
```python
# Training settings
epoch = 40
learning_rate = 3e-2

# Prepare the data
points = torch.tensor([[-1., -1.], [3., 2.]], dtype=torch.float32)
targets = points[:, 1]
inputs = points[:, 0]
print("inputs:", inputs, "targets:", targets)

# Two learnable parameters
k = nn.Parameter(torch.randn(1), requires_grad=True)
b = nn.Parameter(torch.randn(1), requires_grad=True)
params = [k, b]
print(params)

# Optimizer and loss function
optimizer = optim.Adam(params, lr=learning_rate)
criterion = nn.MSELoss()

# Train the two parameters
loss_lists = []
for i in range(epoch):
    optimizer.zero_grad()
    outputs = inputs*k + b
    loss = criterion(outputs, targets)
    loss_lists.append(loss.item())
    loss.backward()
    optimizer.step()
    if (i+1) % 4 == 0:
        Drawing(points, [k, b])
        # print("outputs:", outputs)
        # print("k:", k)
        # print("b:", b)

# Inspect the trained parameters
print("k:", k)
print("b:", b)
```
Output:
```
# Output before the figures are shown
inputs: tensor([-1., 3.]) targets: tensor([-1., 2.])
[Parameter containing:
tensor([-0.0223], requires_grad=True), Parameter containing:
tensor([0.3827], requires_grad=True)]

# Output after the figures are shown
k: Parameter containing:
tensor([0.7866], requires_grad=True)
b: Parameter containing:
tensor([-0.3185], requires_grad=True)
```
3. Knowledge Distillation
My idea here is to build two neural networks, a large one and a small one, and compare the small network's performance before and after knowledge distillation. PS: the large network here could also be replaced with a CNN model.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import tqdm
from thop import profile
```
```python
# Define the teacher and student networks
class Teacher(nn.Module):
    def __init__(self, num_classes=10):
        super(Teacher, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, num_classes)
        )

    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)


class Student(nn.Module):
    def __init__(self, num_classes=10):
        super(Student, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 20),
            nn.ReLU(),
            # nn.Linear(20, 20),
            # nn.ReLU(),
            nn.Linear(20, num_classes)
        )

    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)


# Measure the parameter count and FLOPs of the two networks
def test_flops_params():
    x = torch.rand([1, 1, 28, 28])
    # x = x.reshape(1, 1, -1)
    T_network = Teacher()
    S_network = Student()
    t_flops, t_params = profile(T_network, inputs=(x, ))
    print('t_flops:{}e-3G'.format(t_flops / 1000000), 't_params:{}M'.format(t_params / 1000000))
    s_flops, s_params = profile(S_network, inputs=(x, ))
    print('s_flops:{}e-3G'.format(s_flops / 1000000), 's_params:{}M'.format(s_params / 1000000))

test_flops_params()
```
Output: the teacher network has well over a hundred times (roughly 150x) the parameters and FLOPs of the student network.
```
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.Teacher'>. Treat it as zero Macs and zero Params.
t_flops:2.3928e-3G t_params:2.39521M
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.Student'>. Treat it as zero Macs and zero Params.
s_flops:0.01588e-3G s_params:0.01591M
```
```python
# Hyperparameters
epoch_size = 5
batch_size = 128
learning_rate = 1e-4

# Download the training set
train_data = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.1307], std=[0.1307])
                            ]),
                            download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Download the test set
test_data = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.1307], std=[0.1307])
                           ]),
                           download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# Define the models
device = torch.device('cuda')
t_model = Teacher().to(device)
s_model = Student().to(device)

# Define the loss function (the optimizers are defined per model below)
criterion = nn.CrossEntropyLoss().to(device)

# Training loop
def train_one_epoch(model, criterion, optimizer, dataloader):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        outputs = model(image)
        loss = criterion(outputs, targets)
        train_loss += loss.item()

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count correct predictions
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                        correct / total, correct, total)
            print(train_info)

# Evaluation loop
def validate(model, criterion, dataloader):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_info = "Test ==> loss:{}, acc:{} ({}/{})"\
        .format(test_loss/len(dataloader), correct/total, correct, total)
    print(test_info)
```
Check how the teacher network performs:
```python
# Optimizer for the teacher network
t_optimizer = optim.Adam(t_model.parameters(), lr=learning_rate)

# Train the teacher network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(t_model, criterion, t_optimizer, train_loader)
    validate(t_model, criterion, test_loader)

# After training, save the teacher's weights so they can be reloaded when needed
torch.save(t_model.state_dict(), "t_model.mdl")
```
Output of training the teacher network:
```
[Epoch:0]
batch:0/469, loss:2.451192855834961, acc:0.046875 (6/128)
batch:93/469, loss:0.9547713903036523, acc:0.7139295212765957 (8590/12032)
batch:186/469, loss:0.6837506957232633, acc:0.7931149732620321 (18984/23936)
batch:279/469, loss:0.5656270824904953, acc:0.8287388392857142 (29702/35840)
batch:372/469, loss:0.49507016175234286, acc:0.8506618632707775 (40614/47744)
batch:465/469, loss:0.44825181595245656, acc:0.864538626609442 (51568/59648)
Test ==> loss:0.1785062526977515, acc:0.9461 (9461/10000)
[Epoch:1]
batch:0/469, loss:0.2513591945171356, acc:0.9375 (120/128)
batch:93/469, loss:0.22307029120782587, acc:0.9340093085106383 (11238/12032)
batch:186/469, loss:0.2112935145988184, acc:0.9371657754010695 (22432/23936)
batch:279/469, loss:0.2037411120853254, acc:0.9394252232142857 (33669/35840)
batch:372/469, loss:0.19692633960186318, acc:0.9414167225201072 (44947/47744)
batch:465/469, loss:0.19179708627352388, acc:0.9429821620171673 (56247/59648)
Test ==> loss:0.11782180916376506, acc:0.9622 (9622/10000)
[Epoch:2]
batch:0/469, loss:0.15417274832725525, acc:0.9453125 (121/128)
batch:93/469, loss:0.14990339277589576, acc:0.9549534574468085 (11490/12032)
batch:186/469, loss:0.14525708044196833, acc:0.9562583556149733 (22889/23936)
batch:279/469, loss:0.14779337254752006, acc:0.9551060267857143 (34231/35840)
batch:372/469, loss:0.1445239940433496, acc:0.9564133713136729 (45663/47744)
batch:465/469, loss:0.14156085961846068, acc:0.9572659603004292 (57099/59648)
Test ==> loss:0.09912304252480404, acc:0.9695 (9695/10000)
[Epoch:3]
batch:0/469, loss:0.09023044258356094, acc:0.984375 (126/128)
batch:93/469, loss:0.11060939039638702, acc:0.9670877659574468 (11636/12032)
batch:186/469, loss:0.11260852741605458, acc:0.9668699866310161 (23143/23936)
batch:279/469, loss:0.11275576776159661, acc:0.9667410714285715 (34648/35840)
batch:372/469, loss:0.11253649257023597, acc:0.9668440013404825 (46161/47744)
batch:465/469, loss:0.11281515193839314, acc:0.9665873122317596 (57655/59648)
Test ==> loss:0.0813662743231258, acc:0.9734 (9734/10000)
[Epoch:4]
batch:0/469, loss:0.10590803623199463, acc:0.9765625 (125/128)
batch:93/469, loss:0.0938354417523171, acc:0.9718251329787234 (11693/12032)
batch:186/469, loss:0.09741261341474591, acc:0.9707971256684492 (23237/23936)
batch:279/469, loss:0.0959280665631273, acc:0.9712332589285714 (34809/35840)
batch:372/469, loss:0.09434855888140745, acc:0.9716823056300268 (46392/47744)
batch:465/469, loss:0.09377776978481481, acc:0.9719521190987125 (57975/59648)
Test ==> loss:0.07517792291562014, acc:0.975 (9750/10000)
```
The steps above have successfully trained a teacher model with a fairly large number of parameters. Now, what happens if we simply train the student network on its own?
Training the student network on its own, and its output:
```python
# Optimizer for the student network
s_optimizer = optim.Adam(s_model.parameters(), lr=learning_rate)

# Train the student network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(s_model, criterion, s_optimizer, train_loader)
    validate(s_model, criterion, test_loader)
```
```
[Epoch:0]
batch:0/469, loss:2.435654878616333, acc:0.0859375 (11/128)
batch:93/469, loss:1.8272213606124228, acc:0.4320146276595745 (5198/12032)
batch:186/469, loss:1.431586041170008, acc:0.59295621657754 (14193/23936)
batch:279/469, loss:1.1955812790564129, acc:0.6700892857142857 (24016/35840)
batch:372/469, loss:1.0443579082354784, acc:0.7167392761394102 (34220/47744)
batch:465/469, loss:0.9376831678233944, acc:0.7478205472103004 (44606/59648)
Test ==> loss:0.46302026283891895, acc:0.8882 (8882/10000)
[Epoch:1]
batch:0/469, loss:0.5130173563957214, acc:0.890625 (114/128)
batch:93/469, loss:0.45758569145456274, acc:0.8769946808510638 (10552/12032)
batch:186/469, loss:0.44705012376933173, acc:0.8820604946524064 (21113/23936)
batch:279/469, loss:0.43086895591446334, acc:0.8864676339285714 (31771/35840)
batch:372/469, loss:0.41966651623754014, acc:0.8885723860589813 (42424/47744)
batch:465/469, loss:0.4076770568993982, acc:0.8913123658798283 (53165/59648)
Test ==> loss:0.3379575525280796, acc:0.9091 (9091/10000)
[Epoch:2]
batch:0/469, loss:0.46021485328674316, acc:0.8671875 (111/128)
batch:93/469, loss:0.35463421569859727, acc:0.8988530585106383 (10815/12032)
batch:186/469, loss:0.3480192156717739, acc:0.9024481951871658 (21601/23936)
batch:279/469, loss:0.34022124212767396, acc:0.9047154017857143 (32425/35840)
batch:372/469, loss:0.33286028996549405, acc:0.9072553619302949 (43316/47744)
batch:465/469, loss:0.32942312831147036, acc:0.9081109173819742 (54167/59648)
Test ==> loss:0.291702199491519, acc:0.9215 (9215/10000)
[Epoch:3]
batch:0/469, loss:0.2687709629535675, acc:0.9453125 (121/128)
batch:93/469, loss:0.29896670643319473, acc:0.9164727393617021 (11027/12032)
batch:186/469, loss:0.3032062678413595, acc:0.9152322860962567 (21907/23936)
batch:279/469, loss:0.2976516788559301, acc:0.9162946428571429 (32840/35840)
batch:372/469, loss:0.2963846751735933, acc:0.9160941689008043 (43738/47744)
batch:465/469, loss:0.29447377907999595, acc:0.9167784334763949 (54684/59648)
Test ==> loss:0.2693275752701337, acc:0.9252 (9252/10000)
[Epoch:4]
batch:0/469, loss:0.21400471031665802, acc:0.9296875 (119/128)
batch:93/469, loss:0.2811283932087269, acc:0.922124335106383 (11095/12032)
batch:186/469, loss:0.2739176594796665, acc:0.9235461229946524 (22106/23936)
batch:279/469, loss:0.27122129941625256, acc:0.9234933035714286 (33098/35840)
batch:372/469, loss:0.2737213251737743, acc:0.9226290214477212 (44050/47744)
batch:465/469, loss:0.27158979208172646, acc:0.9234173819742489 (55080/59648)
Test ==> loss:0.25467539342898354, acc:0.9293 (9293/10000)
```
From the results above, the student network's best test accuracy is 0.9293, while the teacher network reaches 0.9750, so the teacher is clearly better than the student. Next, let's apply knowledge distillation and see how much the teacher can improve the student.
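The total loss used in the distillation run below is a weighted sum of the hard loss against the ground-truth labels and the soft loss against the teacher's temperature-scaled outputs (this mirrors the code that follows; note it omits the $T^2$ scaling of the classic formulation):

$$L = \alpha \cdot \mathrm{CE}(z_s, y) + (1 - \alpha) \cdot \mathrm{KL}\Big(\mathrm{softmax}(z_s/T)\,\big\|\,\mathrm{softmax}(z_t/T)\Big)$$

where $z_s$ and $z_t$ are the student and teacher logits, $T$ is the temperature `temp`, and $\alpha$ weights the two terms.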
```python
# KD training loop
def train_one_epoch_kd(s_model, t_model, hard_loss, soft_loss, optimizer, dataloader):
    s_model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_preds = t_model(image)

        # Student predictions
        student_preds = s_model(image)

        # Loss against the ground-truth labels: hard loss
        student_loss = hard_loss(student_preds, targets)

        # Loss against the teacher outputs: soft loss
        # Note: nn.KLDivLoss expects log-probabilities as its first argument;
        # passing softmax here (as in the original run) is what makes the training loss negative below.
        ditillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )

        # Total loss: weighted sum of the hard loss and the soft loss
        loss = alpha * student_loss + (1-alpha) * ditillation_loss
        train_loss += loss.item()

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count correct predictions
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                        correct / total, correct, total)
            print(train_info)
```
```python
# KD hyperparameters
temp = 7      # distillation temperature
alpha = 0.3   # weighting coefficient

# Prepare a fresh student model, optimizer and loss functions
kd_model = Student().to(device)
hard_loss = nn.CrossEntropyLoss()                 # applies softmax internally
soft_loss = nn.KLDivLoss(reduction="batchmean")   # no softmax inside, so the temperature can be applied manually

# Optimizer for the distilled student network
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)

# Train the student network with knowledge distillation
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch_kd(kd_model, t_model, hard_loss, soft_loss, kd_optimizer, train_loader)
    validate(kd_model, criterion, test_loader)
```
```
[Epoch:0]
batch:0/469, loss:-0.7217433452606201, acc:0.0625 (8/128)
batch:93/469, loss:-0.8849591688906893, acc:0.36826795212765956 (4431/12032)
batch:186/469, loss:-0.9996793939468057, acc:0.5336313502673797 (12773/23936)
batch:279/469, loss:-1.0756173231772015, acc:0.6284319196428572 (22523/35840)
batch:372/469, loss:-1.1279368988630278, acc:0.6846724195710456 (32689/47744)
batch:465/469, loss:-1.164531718252043, acc:0.7206444474248928 (42985/59648)
Test ==> loss:0.4428194357624537, acc:0.8843 (8843/10000)
[Epoch:1]
batch:0/469, loss:-1.308181643486023, acc:0.8828125 (113/128)
batch:93/469, loss:-1.333051804532396, acc:0.8811502659574468 (10602/12032)
batch:186/469, loss:-1.3385243995941896, acc:0.8837316176470589 (21153/23936)
batch:279/469, loss:-1.342562198638916, acc:0.8844029017857142 (31697/35840)
batch:372/469, loss:-1.3486905705193732, acc:0.8871062332439679 (42354/47744)
batch:465/469, loss:-1.3531442113188714, acc:0.8896190987124464 (53064/59648)
Test ==> loss:0.32884856549244895, acc:0.907 (9070/10000)
[Epoch:2]
batch:0/469, loss:-1.3262287378311157, acc:0.859375 (110/128)
batch:93/469, loss:-1.3759282218649032, acc:0.901595744680851 (10848/12032)
batch:186/469, loss:-1.3805451450500896, acc:0.9045788770053476 (21652/23936)
batch:279/469, loss:-1.381890092577253, acc:0.903515625 (32382/35840)
batch:372/469, loss:-1.3839336115936811, acc:0.9044277815013405 (43181/47744)
batch:465/469, loss:-1.3856471659287874, acc:0.9056296942060086 (54019/59648)
Test ==> loss:0.30868444205084933, acc:0.9142 (9142/10000)
[Epoch:3]
batch:0/469, loss:-1.3871097564697266, acc:0.90625 (116/128)
batch:93/469, loss:-1.398324375456952, acc:0.9108211436170213 (10959/12032)
batch:186/469, loss:-1.4016364086120523, acc:0.9119318181818182 (21828/23936)
batch:279/469, loss:-1.4019242635795048, acc:0.9117466517857142 (32677/35840)
batch:372/469, loss:-1.4023852696687222, acc:0.9129314678284183 (43587/47744)
batch:465/469, loss:-1.403283029666786, acc:0.9131907188841202 (54470/59648)
Test ==> loss:0.2895406449708757, acc:0.9191 (9191/10000)
[Epoch:4]
batch:0/469, loss:-1.4002737998962402, acc:0.8828125 (113/128)
batch:93/469, loss:-1.4150411626125903, acc:0.9187998670212766 (11055/12032)
batch:186/469, loss:-1.415715930933621, acc:0.9188669786096256 (21994/23936)
batch:279/469, loss:-1.4160895236900874, acc:0.9184709821428572 (32918/35840)
batch:372/469, loss:-1.4162878402116792, acc:0.9184400134048257 (43850/47744)
batch:465/469, loss:-1.4160627477158805, acc:0.9182202253218884 (54770/59648)
Test ==> loss:0.2826786540165732, acc:0.921 (9210/10000)
```
Unfortunately, in this short experiment knowledge distillation did not show its advantage...
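One implementation detail worth noting: nn.KLDivLoss expects log-probabilities as its first argument, and Hinton's original formulation also scales the soft loss by T² so its gradients stay comparable to the hard loss. Below is a minimal sketch of that variant of the loss computation; the hypothetical helper `kd_loss` is mine and was not run here, so the results above do not reflect it.

```python
import torch.nn as nn
import torch.nn.functional as F

soft_loss = nn.KLDivLoss(reduction="batchmean")

def kd_loss(student_preds, teacher_preds, targets, temp, alpha):
    # Hard loss against the ground-truth labels
    hard = F.cross_entropy(student_preds, targets)
    # Soft loss: log-probabilities for the student, probabilities for the teacher,
    # scaled by temp**2 as in Hinton et al. (2015)
    soft = soft_loss(
        F.log_softmax(student_preds / temp, dim=1),
        F.softmax(teacher_preds / temp, dim=1)
    ) * (temp ** 2)
    return alpha * hard + (1 - alpha) * soft
```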
4. Learnable Parameters for Knowledge Distillation
The experiment above used a fixed weight alpha to combine the distillation loss and the student loss. So could we instead make the weights learnable parameters and adjust the weighting of the two losses online during training? The experiments below explore this idea.
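Concretely, the idea (as implemented in the code below) is to make the weight itself an `nn.Parameter` that is updated by its own optimizer alongside the student network:

$$L(\theta, \alpha) = \alpha \cdot L_{hard}(\theta) + (1 - \alpha) \cdot L_{soft}(\theta)$$

with one optimizer stepping the student parameters $\theta$ and a second optimizer, with a smaller learning rate, stepping $\alpha$.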
Since we want learnable parameters that adjust the weights online, the training function needs to be redefined:
```python
# KD training loop with learnable weights
def train_one_epoch_kd(s_model, t_model, hard_loss, soft_loss, optimizer, parms_optimizer, dataloader):
    s_model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_preds = t_model(image)

        # Student predictions
        student_preds = s_model(image)

        # Loss against the ground-truth labels: hard loss
        student_loss = hard_loss(student_preds, targets)

        # Loss against the teacher outputs: soft loss
        ditillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )

        # Total loss: weighted sum of the hard loss and the soft loss
        # In my tests a single learnable weight was enough; using two actually worked worse
        # loss = alpha * student_loss + gama * ditillation_loss
        loss = alpha * student_loss + (1 - alpha) * ditillation_loss
        train_loss += loss.item()

        # Backward pass and parameter updates
        optimizer.zero_grad()
        parms_optimizer.zero_grad()
        loss.backward()
        optimizer.step()         # update the network parameters
        parms_optimizer.step()   # update the learnable weight

        # Count correct predictions
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                        correct / total, correct, total)
            print(train_info)
```
Core code for knowledge distillation with learnable weight parameters:
```python
# KD hyperparameters
temp = 7  # distillation temperature
alpha = nn.Parameter(torch.tensor(0.3), requires_grad=True)  # weighting coefficient
gama = nn.Parameter(torch.tensor(0.7), requires_grad=True)   # weighting coefficient
# params = [alpha, gama]  # two learnable weights
params = [alpha, ]        # a single learnable weight
# print(params)
# [Parameter containing:
# tensor(0.3000, requires_grad=True)]

# Prepare the models, optimizers and loss functions
t_model = Teacher().to(device)
t_model.load_state_dict(torch.load("./t_model.mdl"))  # load the trained Teacher model
kd_model = Student().to(device)                        # the Student model to be trained
hard_loss = nn.CrossEntropyLoss()                      # applies softmax internally
soft_loss = nn.KLDivLoss(reduction="batchmean")        # no softmax inside, so the temperature can be applied manually

# Optimizers for the distilled student network
# kd_optimizer = optim.Adam([
#     {'params': kd_model.parameters()},
#     {'params': alpha, 'lr': 3e-4}
# ], lr=learning_rate)
# In my run the single-optimizer code above failed (alpha was treated as a non-leaf Tensor),
# so two separate optimizers are used instead
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)
parms_optimizer = optim.Adam(params, lr=3e-5)

# Train the student network with knowledge distillation
for epoch in range(epoch_size):
    print("[Epoch:{}]\n".format(epoch),
          "[params:alpha:{},gama:{}]".format(alpha.item(), gama.item()))  # .item() prints the raw float values
    train_one_epoch_kd(kd_model, t_model, hard_loss, soft_loss, kd_optimizer, parms_optimizer, train_loader)
    validate(kd_model, criterion, test_loader)
```
Output:
```
[Epoch:0]
 [params:alpha:0.30000001192092896,gama:0.699999988079071]
batch:0/469, loss:-0.6614086031913757, acc:0.0859375 (11/128)
batch:93/469, loss:-0.8821925444805876, acc:0.3480718085106383 (4188/12032)
batch:186/469, loss:-0.9977140796375784, acc:0.5088987299465241 (12181/23936)
batch:279/469, loss:-1.073494029045105, acc:0.6039341517857143 (21645/35840)
batch:372/469, loss:-1.1280986780135944, acc:0.6661151139410187 (31803/47744)
batch:465/469, loss:-1.1685462095195132, acc:0.7057571083690987 (42097/59648)
Test ==> loss:0.46250298619270325, acc:0.8812 (8812/10000)
[Epoch:1]
 [params:alpha:0.2877715528011322,gama:0.699999988079071]
batch:0/469, loss:-1.3155369758605957, acc:0.8203125 (105/128)
batch:93/469, loss:-1.3573756332093097, acc:0.8779089095744681 (10563/12032)
batch:186/469, loss:-1.367437780859636, acc:0.8814756016042781 (21099/23936)
batch:279/469, loss:-1.3752499222755432, acc:0.8830357142857143 (31648/35840)
batch:372/469, loss:-1.3843675500266355, acc:0.8863103217158177 (42316/47744)
batch:465/469, loss:-1.391783864713022, acc:0.8887640826180258 (53013/59648)
Test ==> loss:0.3379697990191134, acc:0.9072 (9072/10000)
[Epoch:2]
 [params:alpha:0.2754247784614563,gama:0.699999988079071]
batch:0/469, loss:-1.4017305374145508, acc:0.9140625 (117/128)
batch:93/469, loss:-1.4349181017977126, acc:0.9048371010638298 (10887/12032)
batch:186/469, loss:-1.440221704901221, acc:0.9034508689839572 (21625/23936)
batch:279/469, loss:-1.4445673780781882, acc:0.9034040178571429 (32378/35840)
batch:372/469, loss:-1.450090474801153, acc:0.9046791219839142 (43193/47744)
batch:465/469, loss:-1.456252665990412, acc:0.905730284334764 (54025/59648)
Test ==> loss:0.2985522945093203, acc:0.9149 (9149/10000)
[Epoch:3]
 [params:alpha:0.2624278962612152,gama:0.699999988079071]
batch:0/469, loss:-1.4728816747665405, acc:0.9140625 (117/128)
batch:93/469, loss:-1.4891292845949213, acc:0.9151429521276596 (11011/12032)
batch:186/469, loss:-1.4935716301362145, acc:0.9136864973262032 (21870/23936)
batch:279/469, loss:-1.4979333805186408, acc:0.9130301339285715 (32723/35840)
batch:372/469, loss:-1.5017808774840735, acc:0.9121565013404825 (43550/47744)
batch:465/469, loss:-1.5070314509674203, acc:0.9135930793991416 (54494/59648)
Test ==> loss:0.28384389652858805, acc:0.9218 (9218/10000)
[Epoch:4]
 [params:alpha:0.2490001767873764,gama:0.699999988079071]
batch:0/469, loss:-1.536271572113037, acc:0.921875 (118/128)
batch:93/469, loss:-1.536556749901873, acc:0.914311835106383 (11001/12032)
batch:186/469, loss:-1.5403413033102924, acc:0.915817179144385 (21921/23936)
batch:279/469, loss:-1.5454376697540284, acc:0.9179129464285715 (32898/35840)
batch:372/469, loss:-1.5496390562594415, acc:0.9183352882037533 (43845/47744)
batch:465/469, loss:-1.5540172580485692, acc:0.9185387607296137 (54789/59648)
Test ==> loss:0.29046644185540044, acc:0.9234 (9234/10000)
```
I compared training with two learnable parameters against training with one. In my tests, a single learnable parameter performed better over the first 5 epochs than two, and it also converged faster; moreover, adjusting the weight online gave better results than manually setting a fixed distillation weight.
Unfortunately, though, distillation with online weight adjustment still did not beat simply training the small network directly. My guess is that the main reason is that MNIST is too simple: it is only a 10-class digit classification task, so even a small model trained directly tends to do reasonably well. I will try a harder dataset when I get the chance.