【轻量化:蒸馏】都2023年了,你还不会蒸馏操作,难怪你面试不通过!

简介: 【轻量化:蒸馏】都2023年了,你还不会蒸馏操作,难怪你面试不通过!

导读

  深度学习中的蒸馏机制是一种有效的模型压缩方法,可以将大型神经网络压缩为小型神经网络,同时可以最大限度的保持原始网络的性能。在实际应用中,蒸馏机制可以帮助我们在资源受限的设备上部署更加高效的神经网络模型,从而在保证准确性的前提下提升计算效率。本文将会深入探讨蒸馏机制的工作原理、应用场景和实现方法,并为读者提供实际案例和代码实现,帮助大家更好地理解和应用蒸馏机制。

前言

  在本文中的实验数据集为花朵数据集,该数据集类别数经过筛选总共为5类,数据集划分为训练集与测试集,训练:测试 = 500 : 150。


image.png   

在网络的选择方面,我们选择高精度的大模型为Vgg11,小模型为轻量级的shufflenet_v2_x0_5。这两个模型在torchvision\models中都可以找到,有现成的,不需要我们自己手撸代码。这两个权重的大小分别为:

         image.png

实操蒸馏

   常见的蒸馏有四大类:基础蒸馏、负责蒸馏、激励蒸馏和自适应蒸馏,下面我们会对这四大类蒸馏进行实操演示。

基础模块

   在实操之前,我们有一些代码是公共的,这里包括数据集的获取,网络的定义及调用,可以构建一个【init.py】文件进行定义基础公共代码部分:

ini

复制代码

import torch
import torchvision
from torch.utils.data import DataLoader
from shufflenetv2 import shufflenet_v2_x0_5
from vgg import vgg11
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_Data():
    train_root = 'D:/Data_ALL/datas/train/'
    test_root = 'D:/Data_ALL/datas/test/'
    # 将文件夹的内容载入dataset
    train_dataset = torchvision.datasets.ImageFolder(root=train_root, transform=torchvision.transforms.ToTensor())
    test_dataset = torchvision.datasets.ImageFolder(root=test_root, transform=torchvision.transforms.ToTensor())
    train_dataloader = DataLoader(train_dataset, batch_size=12,shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=8)
    train_num = len(train_dataset)
    test_num = len(test_dataset)
    return train_dataloader, test_dataloader, train_num, test_num
def get_net():
    model_t = vgg11(pretrained=False, num_classes=5).to(device)
    model_s = shufflenet_v2_x0_5(pretrained=False, num_classes=5).to(device)
    return model_t, model_s
def Soomth_temp(init_temp, global_step, decay_steps=0.07):
    temp = init_temp - (global_step * decay_steps)
    if temp < 1:
        temp = 1
    return temp

测试模块

   同时再构建【demo.py】函数对核心代码部分的main部分进行设定好损失函数以及优化器和读取init.py数据。

ini

复制代码

def test(dataloader):
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for test_image, test_label in dataloader:
            test_image, test_label = test_image.to(device), test_label.to(device)
            test_pred = model_s(test_image)
            test_loss += criterion(test_pred, test_label).item()
            correct += (test_pred.argmax(1) == test_label).type(torch.float).sum().item()
    test_loss /= test_num
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
if __name__ == "__main__":
    # 设置网络的部分
    model_t, model_s = get_net()
    # 设置数据集的部分
    train_dataloader, test_dataloader, train_num, test_num = get_Data()
    # 定义温度参数T
    T = 10
    # 定义交叉熵损失函数
    criterion = nn.CrossEntropyLoss()
    # 定义优化器
    optimizer = optim.Adam(net_s.parameters(), lr=0.001)
    # 设置超参数的部分
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.94)    # LR变动的学习率
    epochs = 100
    for epoch in range(epochs):
        train(train_dataloader)
        test(test_dataloader)

   需要注意的是,在训练结束后,我们需要使用小型模型进行预测,而不是使用带有温度参数的“软目标”进行预测。

基础蒸馏

   基础蒸馏是最简单的蒸馏机制,它通过让小型模型学习大型模型的预测结果来传递知识。在这种方法中,大型模型的输出被视为“软目标”,并被用作小型模型的训练目标。这种方法的优点是简单易用,但是它通常需要使用较高的温度参数,导致模型的精度不如其他方法。

css

复制代码

from torch import optim, nn
from init import *
import torch.nn.functional as F
# 定义训练函数
def train(dataloader):
    model_s.train()
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_s(inputs)
        # 对大型模型输出进行softmax转换并缩放温度参数
        teacher_outputs = model_t(inputs).detach()
        teacher_outputs = nn.functional.softmax(teacher_outputs / T, dim=1)
        # 计算损失函数,其中使用“软目标”作为辅助损失
        loss = criterion(outputs, labels) + nn.KLDivLoss()(nn.functional.log_softmax(outputs / T, dim=1),
                                                           teacher_outputs) * T * T
        # 反向传播更新参数
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss = loss.item()
            print(f"loss: {loss:>7f}")

   在训练函数中,我们首先将大型模型的输出进行softmax转换,并使用温度参数对其进行缩放,从而得到“软目标”的分布。然后,我们使用“软目标”作为辅助损失来帮助小型模型进行训练。最后,我们将两个损失函数相加,得到最终的损失函数,并使用反向传播更新参数。在训练过程中,我们使用KL散度来衡量小型模型输出分布与“软目标”分布之间的距离,从而提高小型模型的泛化能力。

负责蒸馏

   负责蒸馏是一种基于网络层次结构的蒸馏方法。它的基本思想是让小型模型学习大型模型中的“特定”层次信息,如中间层的特征图。这些信息可以被看作是大型模型的“中间目标”,并被用作小型模型的训练目标。这种方法的优点是可以更精确地传递知识,从而获得更高的模型精度。

ini

复制代码

def train(dataloader):
    alpha = 0.5
    model_s.train()
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_s(inputs)
        # 对大型模型输出进行softmax转换并缩放温度参数
        teacher_outputs = model_t(inputs).detach()
        teacher_outputs = nn.functional.softmax(teacher_outputs / T, dim=1)
        # 计算损失函数,其中使用“软目标”作为辅助损失
        loss = (1 - alpha) * criterion(outputs, labels) + alpha * nn.KLDivLoss()(
            nn.functional.log_softmax(outputs / T, dim=1), teacher_outputs) * T * T
        # 反向传播更新参数
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss = loss.item()
            print(f"loss: {loss:>7f}")

   我们使用VGG11作为大型模型,ShuffleNet V2 x0.5作为小型模型。在训练函数中,我们首先对大型模型的输出进行softmax转换,并使用温度参数对其进行缩放,从而得到“软目标”的分布。然后,我们对小型模型的输出进行相同的操作。接着,我们通过对所有层的输出计算KL散度来定义辅助损失函数,并将所有辅助损失函数相加,得到总辅助损失。最后,我们将交叉熵损失函数和总辅助损失加权相加,得到最终的损失函数,并使用反向传播更新参数。在训练过程中,我们使用KL散度来衡量小型模型输出分布与“软目标”分布之间的距离,并使用权重参数。

激励蒸馏

   激励蒸馏是一种基于激活函数的蒸馏方法,它通过让小型模型学习大型模型中的激活函数来传递知识。具体而言,我们可以通过让小型模型学习大型模型中的激活函数输出,来获得更加精确的模型压缩效果。这种方法的优点是可以更精确地控制模型的软硬程度,从而获得更好的模型压缩效果。

ini

复制代码

alpha = 0.5
def train(dataloader):
    model_s.train()
    model_t.eval()
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_s(inputs)
        # 获取大模型的中间层激活
        with torch.no_grad():
            teacher_outputs = model_t.features(inputs)  # 大家可以自行寻找
        # 获取小模型的中间层激活
        student_outputs = model_s.features(inputs)  # 大家可以自行寻找
        # 定义辅助损失函数
        loss = 0
        for i in range(len(teacher_outputs)):
            # 对大模型和小模型的中间层激活进行softmax转换并缩放温度参数
            teacher_activation = nn.functional.softmax(teacher_outputs[i] / T, dim=1)
            student_activation = nn.functional.softmax(student_outputs[i] / T, dim=1)
            # 计算中间层激活的KL散度作为辅助损失
            loss += nn.KLDivLoss()(student_activation, teacher_activation.detach()) * T * T
        # 计算主要损失函数
        loss += (1 - alpha) * criterion(outputs, labels)
        # 反向传播更新参数
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss = loss.item()
            print(f"loss: {loss:>7f}")

自适应蒸馏

   自适应蒸馏是一种基于动态温度控制的蒸馏方法,它通过动态调整温度参数来实现更好的模型压缩效果。具体而言,我们可以根据小型模型的训练进展来自适应地调整温度参数,从而达到更好的模型压缩效果。这种方法的优点是可以自适应地调整模型的软硬程度,从而获得更好的压缩效果。

ini

复制代码

alpha = 0.5
def train(dataloader):
    model_s.train()
    model_t.eval()
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_s(inputs)
        # 获取大模型的中间层激活
        with torch.no_grad():
            teacher_outputs = model_t.features(inputs)
        # 获取小模型的中间层激活
        student_outputs = model_s.features(inputs)
        # 定义辅助损失函数
        loss = 0
        for i in range(len(teacher_outputs)):
            # 计算大模型和小模型中间层激活的L2范数,作为动态的温度参数
            T = torch.norm(teacher_outputs[i] - student_outputs[i], p=2) / torch.numel(teacher_outputs[i])
            # 对大模型和小模型的中间层激活进行softmax转换并缩放温度参数
            teacher_activation = nn.functional.softmax(teacher_outputs[i] / T, dim=1)
            student_activation = nn.functional.softmax(student_outputs[i] / T, dim=1)
            # 计算中间层激活的KL散度作为辅助损失
            loss += nn.KLDivLoss()(student_activation, teacher_activation.detach()) * T * T
        # 计算主要损失函数
        loss += criterion(outputs, labels)
        # 反向传播更新参数
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss = loss.item()
            print(f"loss: {loss:>7f}")

   在上述示例中,我们使用了 PyTorch 的 vgg11shufflenet_v2_x0_5 模型,并通过 features 属性获取了这些模型的中间层激活。我们通过计算大模型和小模型中间层激活的L2范数来动态地确定温度参数T。

结语

   蒸馏机制是一种有效的深度学习模型压缩技术,它可以将一个复杂的模型压缩成一个简单的模型,同时保持高精度。在实际应用中,蒸馏机制已经被广泛应用于各种场景,例如移动端深度学习模型的压缩、模型部署以及高效模型训练等方面。不同的蒸馏机制适用于不同的应用场景,需要根据具体的应用需求进行选择


相关文章
|
缓存 NoSQL 关系型数据库
面试必问:Redis 如何实现库存扣减操作?
面试必问:Redis 如何实现库存扣减操作?
1394 7
面试必问:Redis 如何实现库存扣减操作?
|
JavaScript 前端开发 程序员
|
监控 安全 Java
【Java面试】多线程操作中常用的几个方法
【Java面试】多线程操作中常用的几个方法
74 0
|
Java Spring
面试突击82:SpringBoot 中如何操作事务?
面试突击82:SpringBoot 中如何操作事务?
327 0
|
存储 算法 C语言
C++模拟面试:从数组“紧凑”操作说开来
C++模拟面试:从数组“紧凑”操作说开来
195 0
C++模拟面试:从数组“紧凑”操作说开来
|
存储 安全 Java
java 中操作字符串都有哪些类?它们之间有什么区别?面试篇(第七天)
String:对字符串不进行重复操作时选择用String。 StringBuilder:在单线程中对字符串进行重复操作时选择用StringBuilder。 StringBuffer:在多线程中对字符串进行重复操作时选择用StringBuffer。
145 0
java 中操作字符串都有哪些类?它们之间有什么区别?面试篇(第七天)
|
机器学习/深度学习 Java 开发工具
面试/工作必备的vim基础及快捷键操作
面试/工作必备的vim基础及快捷键操作
203 0
面试/工作必备的vim基础及快捷键操作
|
开发工具 git
Git工作/面试必知必会操作-命令行篇(五)
Git工作/面试必知必会操作-命令行篇(五)
117 0
Git工作/面试必知必会操作-命令行篇(五)
|
开发工具 git
Git工作/面试必知必会操作-命令行篇(四)
Git工作/面试必知必会操作-命令行篇(四)
128 0
Git工作/面试必知必会操作-命令行篇(四)
|
Linux 网络安全 开发工具
Git工作/面试必知必会操作-命令行篇(三)
Git工作/面试必知必会操作-命令行篇(三)
162 0
Git工作/面试必知必会操作-命令行篇(三)