【科普向】模型蒸馏和模型量化到底是什么???

简介: 在数字化快速发展的时代,人工智能(AI)技术已广泛应用,但大型深度学习模型对计算资源的需求日益增长,增加了部署成本并限制了其在资源有限环境下的应用。为此,研究人员提出了模型蒸馏和模型量化两种关键技术。模型蒸馏通过将大型教师模型的知识传递给小型学生模型,利用软标签指导训练,使学生模型在保持较高准确性的同时显著减少计算需求,特别适用于移动设备和嵌入式系统。模型量化则是通过降低模型权重的精度(如从32位浮点数到8位整数),大幅减少模型大小和计算量,提高运行速度,并能更好地适应低配置设备。量化分为后训练量化和量化感知训练等多种方法,各有优劣。

引言

在当今数字化快速发展的时代,人工智能(AI)技术已经渗透到我们生活的方方面面,从智能手机中的语音助手到自动驾驶车辆的安全系统,背后都离不开深度学习模型的支持。然而,随着这些模型变得越来越庞大和复杂,它们对计算资源的需求也日益增长,这不仅增加了部署成本,还限制了AI应用在资源有限环境下的广泛应用。例如,在移动设备或边缘计算场景中,由于硬件性能和功耗的限制,直接部署大型模型往往不可行。

为了解决这一问题,研究人员提出了两种关键技术——模型蒸馏(Model Distillation)与模型量化(Model Quantization),它们旨在通过不同的方式压缩复杂的深度学习模型,使得小型化后的模型能够在保持较高准确性的前提下,更高效地运行于各种平台上。这两种方法虽然侧重点不同,但都是为了实现同一个目标:让AI更加轻便、节能且易于部署

本文将以科普的形式向读者介绍模型蒸馏和模型量化的定义、工作原理及其应用场景,帮助大家理解这两项技术如何助力AI模型瘦身

模型蒸馏

模型蒸馏的概念

模型蒸馏(Model Distillation)是一种模型压缩和知识迁移的技术,旨在将一个大型、复杂且性能优异的教师模型(Teacher Model)中的知识传递给一个较小、计算效率更高的学生模型(Student Model),将复杂且大的模型作为Teacher,Student模型结构较为简单,用Teacher来辅助Student模型的训练,Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力,复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。

其核心思想是利用教师模型输出的软标签(soft targets)—— 即概率分布而非硬标签(hard labels),来指导学生模型的训练。通过这种方式,学生模型不仅学习到数据的类别信息,还能够捕捉到类别之间的相似性和关系,从而提升其泛化能力。

该方法的优势在于能够在不显著损失性能的情况下,显著减少模型大小和计算需求,特别适用于资源受限的设备,如移动设备和嵌入式系统。

image.png

主要步骤

image.png

模型蒸馏通常包括以下几个步骤。

  1. 训练教师模型(Teacher Model):首先训练一个性能优异但通常较为庞大的教师模型。教师模型可以是任何高性能的深度学习模型,如深层神经网络、卷积神经网络(CNN)、Transformer等。

  2. 生成软标签(Soft Targets):使用训练好的教师模型对训练数据进行预测,获得每个样本的概率分布。这些概率分布作为软标签,包含了类别之间的相对关系信息。

  3. 训练学生模型(Student Model):设计一个较小的学生模型,并使用软标签以及硬标签共同训练。训练过程中,通常采用一个损失函数的加权组合,例如,交叉熵损失(用于硬标签)与 Kullback-Leibler 散度损失(用于软标签)。

  4. 优化与调整:通过调整温度参数、损失函数权重等超参数,优化学生模型的性能,使其尽可能接近教师模型。

关键技术与方法

软标签与温度参数

传统的训练方法通常使用硬标签,即每个样本对应一个确定的类别标签。而在模型蒸馏中,教师模型输出的是概率分布(软标签),这些概率反映了教师模型对各类别的信心程度。通过引入温度系数(temperature),可以平滑或锐化这个概率分布,从而提供更丰富的梯度信息,帮助学生模型更好地学习。

而对于温度系数,我们可以这么理解,假设有一位老师讲课速度非常快,信息密度很高,学生可能有点难以跟上。这时如果老师放慢速度,简化信息,就会让学生更容易理解。在模型蒸馏中,温度参数起到的就是类似“调节讲课速度”的作用,帮助学生模型(小模型)更好地理解和学习教师模型(大模型)的知识。专业点说就是让模型输出更加平滑的概率分布,方便学生模型捕捉和学习教师模型的输出细节。

数学表达式为:

image.png

较高的温度会使得输出分布更加平滑,能够更好地揭示类别之间的相似性,从而提供更丰富的知识给学生模型。训练过程中,通常会同时调整温度参数来优化蒸馏效果。

损失函数设计

模型蒸馏的损失函数通常由两部分组成:

  • 硬标签损失:例如交叉熵损失,用于衡量学生模型预测与真实标签之间的差异。

  • 软标签损失:例如 Kullback-Leibler 散度,用于衡量学生模型预测与教师模型输出概率分布之间的差异。

总损失可以表示为:

image.png

通过加权组合这两部分损失,可以平衡学生模型对硬标签和软标签的学习。

多任务学习与蒸馏

在某些情况下,可以将模型蒸馏与多任务学习结合,通过同时优化多个任务来提升学生模型的表现。这种方法有助于学生模型在多个方面模仿教师模型的能力。

案例分享

以下是一个完整的示例代码,从头训练教师模型并进行模型蒸馏到学生模型,我们以 CIFAR-10 数据集为例。

训练教师模型

首先,我们加载数据集并训练一个教师模型

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义教师模型
teacher_model = resnet34(pretrained=False, num_classes=10).to(device)

# 教师模型训练
print("Training Teacher Model...")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=1e-3)

for epoch in range(5):  # 使用较少的epoch演示
    teacher_model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = teacher_model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Teacher Model Epoch [{epoch+1}/5], Loss: {total_loss/len(train_loader):.4f}")

# 保存教师模型
torch.save(teacher_model.state_dict(), 'teacher_model.pth')
print("Teacher Model Saved!")

训练学生模型

student_model = resnet18(pretrained=False, num_classes=10).to(device)

# 定义蒸馏损失函数
class DistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.7):
        super(DistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, true_labels):
        # 软目标损失
        soft_loss = self.kl_loss(
            nn.functional.log_softmax(student_logits / self.temperature, dim=1),
            nn.functional.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)
        # 硬目标损失
        hard_loss = self.ce_loss(student_logits, true_labels)
        # 总损失
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

distillation_loss = DistillationLoss(temperature=3.0, alpha=0.7)

# 加载教师模型权重
teacher_model.load_state_dict(torch.load('teacher_model.pth'))
teacher_model.eval()

# 蒸馏训练学生模型
print("Training Student Model with Distillation...")
optimizer = optim.Adam(student_model.parameters(), lr=1e-3)

for epoch in range(5):  # 使用较少的epoch演示
    student_model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # 学生模型预测
        student_logits = student_model(images)
        # 教师模型预测(无梯度)
        with torch.no_grad():
            teacher_logits = teacher_model(images)

        # 计算蒸馏损失
        loss = distillation_loss(student_logits, teacher_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Student Model Epoch [{epoch+1}/5], Loss: {total_loss/len(train_loader):.4f}")

# 测试学生模型性能
print("Testing Student Model...")
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = student_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Student Model Accuracy: {100 * correct / total:.2f}%")

# 保存学生模型
torch.save(student_model.state_dict(), 'student_model.pth')
print("Student Model Saved!")

模型量化

模型量化的概念

量化是一种将较大尺寸的模型(如 LLM 或任何深度学习模型)压缩为较小尺寸的方法,比如最开始训练出的权重是32位的浮点数,但是实际使用发现用16位来表示也几乎没有什么损失,但是模型文件大小降低一般,显存使用降低一半,处理器和内存之间的通信带宽要求也降低了,这意味着更低的成本、更高的收益。

image.png

这就像按照菜谱做菜,你需要确定每种食材的重量。你可以使用一个非常精确的电子秤,它可以精确到0.01克,这固然很好,因为你可以非常精确地知道每样食材的重量。但是,如果你只是做一顿家常便饭,实际上并不需要这么高的精度,你可以使用一个简单又便宜的秤,最小刻度是1克,虽然不那么精确,但是足以用来做一顿美味的晚餐。

image.png

左侧:基础模型大小计算(单位:GB),右侧:量化后的模型大小计算(单位:GB)在上图中,基础模型 Llama 3 8B 的大小为 32 GB。经过 Int8 量化后,大小减少到 8GB(减少了 75%)。使用 Int4 量化后,大小进一步减少到 4GB(减少约 90%)。这使模型大小大幅减少。

量化还有一个好处,那就是计算的更快

现代处理器中通常都包含了很多的低精度向量计算单元,模型可以充分利用这些硬件特性,执行更多的并行运算;同时低精度运算通常比高精度运算速度快,单次乘法、加法的耗时更短。这些好处还让模型得以运行在更低配置的机器上,比如没有高性能GPU的普通办公或家用电脑、手机等移动终端。

沿着这个思路,人们继续压缩出了8位、4位、2位的模型,体积更小,使用的计算资源更少。不过随着权重精度的降低,不同权重的值会越来越接近甚至相等,这会降低模型输出的准确度和精确度,模型的性能表现会出现不同程度的下降。

量化技术有很多不同的策略和技术细节,比如如动态量化、静态量化、对称量化、非对称量化等,对于大语言模型,通常采用静态量化的策略,在模型训练完成后,我们就对参数进行一次量化,模型运行时不再需要进行量化计算,这样可以方便地分发和部署。

量化的分类

根据不同的标准,量化方法可以被划分为多种类型:

按照量化时间点分类

  • 后训练量化(Post-Training Quantization, PTQ):这是指在模型训练完成后对模型进行量化的过程。PTQ简单易行,适用于已经训练好的模型,但可能会带来一定的精度损失。

  • 量化感知训练(Quantization-Aware Training, QAT):这种方法是在训练阶段引入量化机制,让模型在训练过程中“感知”到量化的影响,从而尽量减少量化带来的精度损失。虽然训练过程更为复杂且耗时较长,但它可以在保持较高精度的同时实现模型压缩。

按照量化粒度分类

  • Per-tensor量化:整个张量或层级共享相同的量化参数(scale和zero-point)。这种方式的优点是存储和计算效率较高,但可能导致精度损失。

  • Per-channel量化:每个通道或轴都有自己的量化参数。这种方式可以更准确地量化数据,因为每个通道可以根据自身特性调整动态范围,但会增加存储需求和计算复杂度。

  • Per-group量化:将数据分组处理,每组有自己的量化参数,介于上述两者之间。

按照量化后的数值范围分类

  • 二值量化(Binary Quantization):将权重限制在+1和-1两个值之间。

  • 三值量化(Ternary Quantization):允许使用三个离散值,通常是-1、0和+1。

  • 定点数量化(Fixed-Point Quantization):最常见的是INT8和INT4,它们分别用8位和4位整数表示权重。

  • 非均匀量化(Non-uniform Quantization):根据待量化参数的概率分布计算量化节点,以适应特定的数据分布模式。

按照是否线性映射分类

  • 线性量化(Linear Quantization):采用线性映射的方式将浮点数映射到整数范围内。它可以进一步细分为对称量化和非对称量化两种形式。

  • 非线性量化(Non-linear Quantization):例如对数量化,它不是简单的线性变换,而是基于某种函数关系来进行映射。

非对称量化的实现

此处以非对称量化为例。非对称量化方法将原始张量范围(Wmin, Wmax)中的值映射到量化张量范围(Qmin, Qmax)中的值。

image.png

  • Wmin, Wmax:原始张量的最小值和最大值(数据类型:FP32,32 位浮点)。在大多数现代 LLM 中,权重张量的默认数据类型是 FP32。

  • Qmin, Qmax: 量化张量的最小值和最大值(数据类型:INT8,8 位整数)。我们也可以选择其他数据类型,如 INT4、INT8、FP16 和 BF16 来进行量化。我们将在示例中使用 INT8。

  • 缩放值(S):在量化过程中,缩放值将原始张量的值缩小以获得量化后的张量。在反量化过程中,它将量化后的张量值放大以获得反量化值。缩放值的数据类型与原始张量相同,为 FP32。

  • 零点(Z):零点是量化张量范围中的一个非零值,它直接映射到原始张量范围中的值 0。零点的数据类型为 INT8,因为它位于量化张量范围内。

  • 量化:图中的“A”部分展示了量化过程,即 [Wmin, Wmax] -> [Qmin, Qmax] 的映射。

  • 反量化:图中的“B”部分展示了反量化过程,即 [Qmin, Qmax] -> [Wmin, Wmax] 的映射。

那么,我们如何从原始张量值导出量化后的张量值呢?这其实很简单。如果你还记得高中数学,你可以很容易理解下面的推导过程。让我们一步步来(建议在推导公式时参考上面的图表,以便更清晰地理解)。

image.png
image.png

细节1:如果Z值超出范围怎么办?解决方案:使用简单的if-else逻辑将Z值调整为Qmin,如果Z值小于Qmin;若Z值大于Qmax,则调整为Qmax。这个方法在图4的图A中有详细描述。

细节2:如果Q值超出范围怎么办?解决方案:在PyTorch中,有一个名为 clamp 的函数,它可以将值调整到特定范围内(在我们的示例中为-128到127)。因此,clamp函数会将Q值调整为Qmin如果它低于Qmin,将Q值调整为Qmax如果它高于Qmax。

image.png

模型蒸馏和模型量化对比

image.png

相关文章
|
9月前
|
人工智能 自动驾驶 机器人
ICLR 2024:模型选择驱动的鲁棒多模态模型推理
【2月更文挑战第24天】ICLR 2024:模型选择驱动的鲁棒多模态模型推理
97 1
ICLR 2024:模型选择驱动的鲁棒多模态模型推理
|
9月前
|
机器学习/深度学习 前端开发 PyTorch
【轻量化:蒸馏】都2023年了,你还不会蒸馏操作,难怪你面试不通过!
【轻量化:蒸馏】都2023年了,你还不会蒸馏操作,难怪你面试不通过!
116 0
【轻量化:蒸馏】都2023年了,你还不会蒸馏操作,难怪你面试不通过!
|
1月前
|
人工智能 机器人
LeCun 的世界模型初步实现!基于预训练视觉特征,看一眼任务就能零样本规划
纽约大学Gaoyue Zhou等人提出DINO World Model(DINO-WM),利用预训练视觉特征构建世界模型,实现零样本规划。该方法具备离线训练、测试时行为优化和任务无关性三大特性,通过预测未来补丁特征学习离线行为轨迹。实验表明,DINO-WM在迷宫导航、桌面推动等任务中表现出强大的泛化能力,无需依赖专家演示或奖励建模。论文地址:https://arxiv.org/pdf/2411.04983v1。
53 21
|
5月前
|
机器学习/深度学习 搜索推荐
CIKM 2024:LLM蒸馏到GNN,性能提升6.2%!Emory提出大模型蒸馏到文本图
【9月更文挑战第17天】在CIKM 2024会议上,Emory大学的研究人员提出了一种创新框架,将大型语言模型(LLM)的知识蒸馏到图神经网络(GNN)中,以克服文本图(TAGs)学习中的数据稀缺问题。该方法通过LLM生成文本推理,并训练解释器模型理解这些推理,再用学生模型模仿此过程。实验显示,在四个数据集上性能平均提升了6.2%,但依赖于LLM的质量和高性能。论文链接:https://arxiv.org/pdf/2402.12022
145 7
|
8月前
|
Python
技术心得:判别式模型vs.生成式模型
技术心得:判别式模型vs.生成式模型
44 0
|
9月前
|
机器学习/深度学习 人工智能 自然语言处理
论文介绍:自我对弈微调——将弱语言模型转化为强语言模型的新方法
【5月更文挑战第17天】论文《自我对弈微调》提出了一种新方法,名为SPIN,用于在无需额外人工标注数据的情况下增强大型语言模型(LLM)。SPIN利用自我对弈机制,让模型通过与自身历史版本交互生成自我训练数据,实现性能提升。该方法在多个基准数据集上表现出色,超越了传统监督微调和直接偏好优化。SPIN还为生成对抗网络研究提供了新思路,展示了自我对弈在强化学习和深度学习中的潜力。实验表明,SPIN有效提升了模型性能,为未来研究奠定了基础。[[arxiv](https://arxiv.org/abs/2401.01335v1)]
95 3
|
9月前
|
存储 并行计算 算法
大模型量化技术解析和应用
眼看人工智能含智能量越来越高含人量越来越低,是否开始担心自己要跟不上这趟高速列车了?内心是否也充满好奇:大模型背后的奥秘是什么?为何如此强大?它能为我所用吗?哪种技术最适合我的需求?
|
9月前
|
机器学习/深度学习 自然语言处理 计算机视觉
【大模型】小样本学习的概念及其在微调 LLM 中的应用
【5月更文挑战第5天】【大模型】小样本学习的概念及其在微调 LLM 中的应用
|
9月前
|
机器学习/深度学习 人工智能 关系型数据库
南京大学提出量化特征蒸馏方法QFD | 完美结合量化与蒸馏,让AI落地更进一步!!!
南京大学提出量化特征蒸馏方法QFD | 完美结合量化与蒸馏,让AI落地更进一步!!!
235 0
|
机器学习/深度学习 数据可视化 索引
斯坦福训练Transformer替代模型:1.7亿参数,能除偏、可控可解释性强
斯坦福训练Transformer替代模型:1.7亿参数,能除偏、可控可解释性强
156 2

热门文章

最新文章