Knowledge Distillation for Classification Networks [Code Included]

Summary: Notes

Knowledge distillation is one approach to model compression, though it is arguably a form of pseudo-compression: rather than shrinking a network directly, it "compresses" a well-performing teacher network into a weaker student network. Equivalently, you can think of it as the student learning under the teacher's guidance in order to improve its own performance.


Knowledge distillation is more of an idea than a packaged technique: unlike other compression methods, there is no off-the-shelf library for it, so you have to implement it yourself for your particular task and scenario. Distillation also comes in "offline" and "online" flavors. The former sets up a teacher-student (T-S) pair for KD training, while the latter can be described as a kind of self-learning in which the student acts as its own teacher.
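
One common flavor of the online/self-distillation idea is to keep an exponential-moving-average (EMA) copy of the student and let it act as the teacher. A minimal sketch under that assumption (ema_teacher, update_ema, and the momentum m are illustrative names, not code from this article; student_model stands for whichever student you train):

import copy
import torch

# EMA self-distillation sketch: the "teacher" is a slowly updated copy of the student
ema_teacher = copy.deepcopy(student_model).eval()
for p in ema_teacher.parameters():
    p.requires_grad_(False)  # the EMA teacher is never trained directly

def update_ema(student, teacher, m=0.99):
    # after each optimizer step, drift the teacher weights toward the student's
    with torch.no_grad():
        for ps, pt in zip(student.parameters(), teacher.parameters()):
            pt.mul_(m).add_(ps, alpha=1 - m)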


Distillation can also be divided into logit distillation and feature distillation. The former builds a loss between the final outputs of the two networks, while the latter builds a loss between intermediate feature maps somewhere inside the networks.
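
To make the feature-level variant concrete, here is a minimal sketch (hypothetical, not this article's code: feat_t and feat_s stand for intermediate feature maps taken from the teacher and student, e.g. via forward hooks, and the adapter's channel counts 256 -> 512 are placeholders):

import torch.nn as nn
import torch.nn.functional as F

adapter = nn.Conv2d(256, 512, kernel_size=1)  # project student channels to the teacher's

def feature_kd_loss(feat_t, feat_s):
    feat_s = adapter(feat_s)
    if feat_s.shape[-2:] != feat_t.shape[-2:]:
        feat_s = F.interpolate(feat_s, size=feat_t.shape[-2:])  # match spatial size
    return F.mse_loss(feat_s, feat_t.detach())  # teacher features are fixed targets

This term is added to the training loss alongside (or instead of) the logit term, and the adapter's parameters are trained together with the student.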


This article uses handwritten digits as the example, with resnet18 as the teacher and resnet50 as the student. [You may wonder: isn't resnet50 stronger than resnet18, so why is resnet50 the student? In my tests, resnet18 actually achieved higher accuracy than resnet50 on this handwritten-digit data. My guess is that at this low resolution, although resnet50's loss keeps decreasing, the deeper network loses more feature information and degrades more noticeably.] You could also try a resnet teacher with a mobilenet student [I trained this combination and found the improvement for mobilenet was small].


Note: this article does not modify the models or refine the distillation scheme; it only demonstrates the effect. If you are interested in more fine-grained distillation, you can explore it on your own. [I will cover KD training for object detection in a future post.]



Teacher training code


Parameter description:

teacher_model: the teacher network

train_loader: training set

test_loader: test set

loss_func: loss function

epochs: number of training epochs

def teacher_train(teacher_model, train_loader, test_loader, loss_func, epochs):
    # optimizer_teacher and device are defined in the main script below
    teacher_model.cuda()
    for i in range(epochs):
        # train: switch back to train mode each epoch (eval mode is set below)
        teacher_model.train()
        for data, label in train_loader:
            data = data.to(device)
            label = label.to(device)
            output = teacher_model(data)
            loss = loss_func(output, label)
            optimizer_teacher.zero_grad()
            loss.backward()
            optimizer_teacher.step()
        print("loss: ", loss)  # loss of the last batch in this epoch
        # eval
        correct = 0
        teacher_model.eval()
        for test_data, test_label in test_loader:
            test_data = test_data.to(device)
            test_label = test_label.to(device)
            with torch.no_grad():
                output = teacher_model(test_data)
                _, pred = torch.max(output, dim=1)
                correct += float(torch.sum(pred == test_label))
        print('test_acc:{}'.format(correct / len(test_loader.dataset)))
    return teacher_model

Training results (I only trained for 5 epochs):

teacher model train
loss:  tensor(0.0891, device='cuda:0', grad_fn=<NllLossBackward>)
test_acc:0.9845
loss:  tensor(0.0132, device='cuda:0', grad_fn=<NllLossBackward>)
test_acc:0.9865
loss:  tensor(0.0019, device='cuda:0', grad_fn=<NllLossBackward>)
test_acc:0.9909
loss:  tensor(0.0042, device='cuda:0', grad_fn=<NllLossBackward>)
test_acc:0.9909
loss:  tensor(0.0034, device='cuda:0', grad_fn=<NllLossBackward>)
test_acc:0.9917
teacher model trained finished!

Student training without KD


Parameter description:

student_model: the student network

train_loader: training set

test_loader: test set

loss_func: loss function

epochs: number of training epochs

def student_train(student_model, train_loader, test_loader, loss_func, epochs):
    # optimizer_student and device are defined in the main script below
    student_model.cuda()
    for i in range(epochs):
        # train: switch back to train mode each epoch (eval mode is set below)
        student_model.train()
        for data, label in train_loader:
            data = data.to(device)
            label = label.to(device)
            output = student_model(data)
            loss = loss_func(output, label)
            optimizer_student.zero_grad()
            loss.backward()
            optimizer_student.step()
        print("student loss: ", loss)  # loss of the last batch in this epoch
        # eval
        correct = 0
        student_model.eval()
        for test_data, test_label in test_loader:
            test_data = test_data.to(device)
            test_label = test_label.to(device)
            with torch.no_grad():
                output = student_model(test_data)
                _, pred = torch.max(output, dim=1)
                correct += float(torch.sum(pred == test_label))
        print('student test_acc:{}'.format(correct / len(test_loader.dataset)))

Results without KD training:

student model ready train
student loss:  tensor(0.1876, device='cuda:0', grad_fn=<NllLossBackward>)
student test_acc:0.9588
student loss:  tensor(0.0219, device='cuda:0', grad_fn=<NllLossBackward>)
student test_acc:0.9737
student loss:  tensor(0.0588, device='cuda:0', grad_fn=<NllLossBackward>)
student test_acc:0.9812
student loss:  tensor(0.0024, device='cuda:0', grad_fn=<NllLossBackward>)
student test_acc:0.9853
student loss:  tensor(0.0022, device='cuda:0', grad_fn=<NllLossBackward>)
student test_acc:0.9814
 student model trained finished!

KD training code


Parameter description:

teacher_model: the already-trained teacher

student_model: the student network to be distilled

train_loader: training set

test_loader: test set

def KD_train(teacher_model, student_model, train_loader, test_loader, loss_func, epochs):
    teacher_model.eval()  # the teacher stays frozen during KD
    student_model.cuda()
    for i in range(epochs):
        student_model.train()  # back to train mode after the previous epoch's eval
        for data, labels in train_loader:
            data = data.to(device)
            labels = labels.to(device)
            with torch.no_grad():  # no gradients needed through the teacher
                teacher_output = teacher_model(data)
            student_output = student_model(data)
            soft_loss = KD_loss(teacher_output, student_output)
            hard_loss = loss_func(student_output, labels)  # ordinary CE against the labels
            loss = hard_loss + alpha * soft_loss
            optimizer_student.zero_grad()
            loss.backward()
            optimizer_student.step()
        print("KD loss: ", loss)  # loss of the last batch in this epoch
        # eval
        student_model.eval()
        ACC = 0
        for data, labels in test_loader:
            with torch.no_grad():
                data = data.to(device)
                labels = labels.to(device)
                output = student_model(data)
                _, pred = torch.max(output, dim=1)
                ACC += float(torch.sum(pred == labels))
        print('KD test_acc:{}'.format(ACC / len(test_loader.dataset)))

In the code, teacher_output is the teacher network's output and student_output is the student's output; the KD_loss defined between them is as follows:


KD_loss code:

Temp is the temperature coefficient, with a default of 2 [experiment with it on your own dataset].


alpha is the coefficient balancing the hard and soft losses [default 0.5; likewise, adjust it for your own situation].


The loss function uses KL divergence; you can also change it to cross-entropy (a sketch of that variant follows the code below).

Temp = 2.  # temperature coefficient
alpha = 0.5  # weight of the soft loss
def KD_loss(p, q):  # p: teacher logits, q: student logits
    pt = F.softmax(p / Temp, dim=1)      # softened teacher distribution
    ps = F.log_softmax(q / Temp, dim=1)  # softened student log-distribution
    # note: reduction='batchmean' matches the mathematical KL definition;
    # the results below were obtained with 'mean'
    return nn.KLDivLoss(reduction='mean')(ps, pt) * (Temp**2)
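
As mentioned, the KL term can be swapped for a cross-entropy against the softened teacher distribution. A minimal sketch of that variant (soft_ce_loss is a hypothetical name; Temp as above):

def soft_ce_loss(p, q):  # p: teacher logits, q: student logits
    pt = F.softmax(p / Temp, dim=1)
    log_ps = F.log_softmax(q / Temp, dim=1)
    # cross-entropy against the soft targets: -sum_k pt_k * log(ps_k), averaged over the batch
    return -(pt * log_ps).sum(dim=1).mean() * (Temp**2)

Since the teacher's entropy does not depend on the student, this differs from the KL version only by a constant (up to the reduction used), so the student sees essentially the same gradients.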

Student results after KD training:

KD loss:  tensor(0.2580, device='cuda:0', grad_fn=<AddBackward0>)
KD test_acc:0.9753
KD loss:  tensor(0.1686, device='cuda:0', grad_fn=<AddBackward0>)
KD test_acc:0.9748
KD loss:  tensor(0.0827, device='cuda:0', grad_fn=<AddBackward0>)
KD test_acc:0.9849
KD loss:  tensor(0.0098, device='cuda:0', grad_fn=<AddBackward0>)
KD test_acc:0.9865
KD loss:  tensor(0.0114, device='cuda:0', grad_fn=<AddBackward0>)
KD test_acc:0.988

You can see the student improves slightly after KD training [mainly because handwritten digits are too easy to train on; even a little training yields high accuracy]. The gain might be more obvious on another dataset [e.g. a cats-vs-dogs dataset; try it yourself].


To swap in a different teacher or student, you only need to replace the teacher_model and student_model networks in the code, e.g. as sketched below.
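
For example, a sketch of the resnet-teacher / mobilenet-student pairing mentioned earlier (assuming the same num_classes, lr, and SGD setup as in the full code below):

from torch.optim import SGD
from torchvision.models import resnet18, mobilenet_v2

teacher_model = resnet18(num_classes=10)
student_model = mobilenet_v2(num_classes=10)  # mobilenet as the student
optimizer_student = SGD(student_model.parameters(), lr=0.01, momentum=0.9)  # rebuild over the new parameters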


Full code


KD for object detection is trickier; I will cover it another time.

import torch
from torch.optim import Adam, SGD
import torch.nn.functional as F
import torch.nn as nn
from torchvision.models import resnet50, resnet34, resnet18, MobileNetV2
import torchvision
import torchvision.transforms as transforms

Temp = 2.  # temperature coefficient
alpha = 0.5  # weight of the soft loss

def KD_loss(p, q):  # p: teacher logits, q: student logits
    pt = F.softmax(p / Temp, dim=1)      # softened teacher distribution
    ps = F.log_softmax(q / Temp, dim=1)  # softened student log-distribution
    # note: reduction='batchmean' matches the mathematical KL definition;
    # the results above were obtained with 'mean'
    return nn.KLDivLoss(reduction='mean')(ps, pt) * (Temp**2)

def teacher_train(teacher_model, train_loader, test_loader, loss_func, epochs):
    # optimizer_teacher and device are defined in the main block below
    teacher_model.cuda()
    for i in range(epochs):
        # train: switch back to train mode each epoch (eval mode is set below)
        teacher_model.train()
        for data, label in train_loader:
            data = data.to(device)
            label = label.to(device)
            output = teacher_model(data)
            loss = loss_func(output, label)
            optimizer_teacher.zero_grad()
            loss.backward()
            optimizer_teacher.step()
        print("loss: ", loss)  # loss of the last batch in this epoch
        # eval
        correct = 0
        teacher_model.eval()
        for test_data, test_label in test_loader:
            test_data = test_data.to(device)
            test_label = test_label.to(device)
            with torch.no_grad():
                output = teacher_model(test_data)
                _, pred = torch.max(output, dim=1)
                correct += float(torch.sum(pred == test_label))
        print('test_acc:{}'.format(correct / len(test_loader.dataset)))
    return teacher_model

def student_train(student_model, train_loader, test_loader, loss_func, epochs):
    # optimizer_student and device are defined in the main block below
    student_model.cuda()
    for i in range(epochs):
        # train: switch back to train mode each epoch (eval mode is set below)
        student_model.train()
        for data, label in train_loader:
            data = data.to(device)
            label = label.to(device)
            output = student_model(data)
            loss = loss_func(output, label)
            optimizer_student.zero_grad()
            loss.backward()
            optimizer_student.step()
        print("student loss: ", loss)  # loss of the last batch in this epoch
        # eval
        correct = 0
        student_model.eval()
        for test_data, test_label in test_loader:
            test_data = test_data.to(device)
            test_label = test_label.to(device)
            with torch.no_grad():
                output = student_model(test_data)
                _, pred = torch.max(output, dim=1)
                correct += float(torch.sum(pred == test_label))
        print('student test_acc:{}'.format(correct / len(test_loader.dataset)))

def KD_train(teacher_model, student_model, train_loader, test_loader, loss_func, epochs):
    teacher_model.eval()  # the teacher stays frozen during KD
    student_model.cuda()
    for i in range(epochs):
        student_model.train()  # back to train mode after the previous epoch's eval
        for data, labels in train_loader:
            data = data.to(device)
            labels = labels.to(device)
            with torch.no_grad():  # no gradients needed through the teacher
                teacher_output = teacher_model(data)
            student_output = student_model(data)
            soft_loss = KD_loss(teacher_output, student_output)
            hard_loss = loss_func(student_output, labels)  # ordinary CE against the labels
            loss = hard_loss + alpha * soft_loss
            optimizer_student.zero_grad()
            loss.backward()
            optimizer_student.step()
        print("KD loss: ", loss)  # loss of the last batch in this epoch
        # eval
        student_model.eval()
        ACC = 0
        for data, labels in test_loader:
            with torch.no_grad():
                data = data.to(device)
                labels = labels.to(device)
                output = student_model(data)
                _, pred = torch.max(output, dim=1)
                ACC += float(torch.sum(pred == labels))
        print('KD test_acc:{}'.format(ACC / len(test_loader.dataset)))

def do_train(teacher_model, student_model, train_loader, test_loader, loss_func, epochs):
    # train the teacher first
    print("teacher model train")
    Teacher = teacher_train(teacher_model, train_loader, test_loader, loss_func, epochs)
    print("teacher model trained finished!")
    # uncomment to run the no-KD student baseline:
    # print("\n student model ready train")
    # student_train(student_model, train_loader, test_loader, loss_func, epochs)
    # print("\n student model trained finished!")
    print("\n KD model ready train")
    KD_train(Teacher, student_model, train_loader, test_loader, loss_func, epochs)

if __name__=="__main__":
    # prepare the dataset
    batch_size = 64
    train_dataset = torchvision.datasets.MNIST('./data/', train=True, download=True,
                                               transform=transforms.Compose([
                                                   transforms.Resize(28),
                                                   transforms.ToTensor(),
                                                   # replicate the single gray channel to 3 channels for ResNet
                                                   transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                   transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                                               ])
                                               )
    test_dataset = torchvision.datasets.MNIST('./data/', train=False, download=True,
                                              transform=transforms.Compose([
                                                  transforms.Resize(28),  # ResNet's default input is 224*224; MNIST is kept at 28
                                                  transforms.ToTensor(),
                                                  transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                  transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                                              ])
                                              )
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    sample, label = next(iter(train_loader))
    print(sample.shape)
    print("batch labels: ", label)
    num_classes = 10
    lr = 0.01
    epochs = 5
    device = torch.device('cuda:0')
    teacher_model = resnet18(num_classes=num_classes)
    student_model = resnet50(num_classes=num_classes)
    optimizer_teacher = SGD(teacher_model.parameters(), lr=lr, momentum=0.9)
    optimizer_student = SGD(student_model.parameters(), lr=lr, momentum=0.9)
    loss_function = nn.CrossEntropyLoss()
    do_train(teacher_model, student_model, train_loader, test_loader, loss_function, epochs)

