PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践

本文涉及的产品
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,5000CU*H 3个月
简介: 本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。

在神经网络研究的前沿,我们正面临着模型精度与运行效率之间的权衡挑战。尽管架构优化、层融合和模型编译等技术已取得显著进展,但这些方法往往不足以同时满足边缘设备部署所需的模型尺寸和精度要求。

研究人员通常采用三种主要策略来实现模型压缩同时保持准确性:

  • 模型量化:通过降低模型权重的数值精度表示(例如将16位浮点数转换为8位整数),减少神经网络的内存占用和计算复杂度。
  • 模型剪枝:识别并移除训练好的神经网络中贡献较小的神经元或权重,以简化网络架构而不显著影响性能。
  • 知识蒸馏(又称教师-学生训练):训练一个更小、更高效的网络(学生模型)来复现更大、更复杂模型(教师模型)的软预测输出。软标签使学生模型获得更好的泛化能力,因为它们代表了类别相似性的高层次抽象理解,而非传统的独热编码表示。

本文将深入探讨模型量化的原理、主要量化技术类型以及如何使用PyTorch实现这些技术。

量化技术基础


量化是神经网络优化中最强大且实用的技术之一。它通过将模型数据(包括网络参数和激活值)从高精度浮点表示(通常为16位)转换为低精度表示(通常为8位整数),从而降低神经网络的计算和内存需求。这种转换带来多方面的优势:

  • GPU可利用更快速、更经济的8位计算单元(如NVIDIA GPU的Tensor Cores)执行卷积和矩阵乘法运算,显著提高计算吞吐量。
  • 对于受内存带宽限制的网络层,量化可显著降低数据传输需求,减少总体运行时间。这类层的运行瓶颈主要在数据读写而非计算本身,因此从带宽优化中获益最大。
  • 模型内存占用的减少不仅节省存储空间,还能减小参数更新大小,提高缓存利用率。
  • 数据从内存传输到计算单元的过程消耗能量。将精度从16位降至8位能使数据量减半,有效降低功耗。

将高精度数值映射至低精度表示有多种方法(如零点量化、绝对最大值量化等),本文不作深入讨论。对此感兴趣的读者可参考Hao Wu等人和Amir Gholani等人的相关技术论文。

量化方法体系


神经网络量化主要分为两种方法:

1、训练后量化 (PTQ)

PTQ在模型完成训练后应用,无需重新训练即可将模型转换为低精度表示。该方法使用校准数据集确定最优量化参数,通过收集模型激活的统计信息并计算适当的量化参数,以最小化浮点表示和量化表示之间的差异。

PTQ具有资源效率高、实现部署快速的优势,适用于无法重新训练的场景。然而,此类模型的准确度相对较低,需要精心校准和参数调优,因此更适合快速原型验证而非正式部署。

训练后量化可进一步细分为两种实现方式:

动态训练后量化

这种方法在推理过程中根据实时输入数据分布动态调整激活值的量化范围。

静态训练后量化

该方法引入额外的校准步骤,使用代表性数据集估计激活值范围。估计过程在完整精度下进行以最小化误差,随后将激活值缩减为低精度数据类型。

2、量化感知训练 (QAT)

QAT是一种在模型训练过程中模拟量化效应的方法。它通过引入"伪量化"操作来模拟低精度对权重和激活值的影响。本质上模型在量化约束条件下进行训练。网络在训练期间使用直通估计器(STE)等技术计算梯度,学习适应量化引入的噪声,从而在低精度环境中保持高性能。

QAT通常能获得更高的准确率,因为模型能在训练过程中适应量化效应,特别适用于对量化误差敏感的架构。但这也意味着需要额外的计算资源和训练时间,实现复杂度也相对较高。

量化感知训练原理


相比于PTQ在训练后应用量化,QAT的优势在于它在训练期间插入"伪量化"模块。这使模型能够"感知"量化噪声并学习如何补偿这种噪声,最终得到一个量化模型,其准确率与全精度对应版本非常接近。QAT工作流程如下:

准备阶段:用模拟量化的包装器替换网络中的敏感层(如卷积层、线性层、激活函数层)。在PyTorch中,这通过

prepare_qat

prepare_qat_fx

函数实现。

训练阶段:在每次前向传播中,权重和激活值都经过"伪量化"处理——即进行类似INT8/INT4精度的四舍五入和截断。反向传播采用STE技术,使梯度计算如同量化操作是恒等函数一样。

转换阶段:训练完成后,使用

convert

convert_fx

函数将伪量化模块替换为实际的量化运算核心。此时模型已准备好进行高效的

int8/int4

推理。

伪量化的数学基础


以下是量化过程的简化数学表达。

假设

x_float

为实值激活。均匀仿射量化使用:

 scale  = (x_max – x_min) / (q_max – q_min)  
 zeroPt = round(q_min – x_min / scale)  
 x_q    = clamp( round(x_float / scale) + zeroPt, q_min, q_max )  
 x_deq  = (x_q – zeroPt) * scale

在QAT期间,伪量化操作表示为:

 x_fake = (round(x_float/scale)+zeroPt – zeroPt) * scale

因此

x_fake

仍然是浮点数,但被限制在与

int8

张量相同的离散格点上。

梯度传播机制 — 直通估计器


训练前向传播(L)和后向传播(R)中的QAT伪量化算子

由于四舍五入操作不可微分,PyTorch采用如下近似:

 dL/dx_float ≈ dL/dx_fake

在反向传播中,伪量化模块被视为梯度计算的恒等函数,这使优化器能够调整上游权重以抵消量化产生的噪声。

这一过程引导网络权重自然地向整数中心靠拢,结合优化后的

scale

zeroPt

参数,最小化整体重建误差。

实践实现

PyTorch提供三种不同的量化模式:

1、Eager模式量化

这是一项Beta阶段功能。用户需要手动执行层融合并明确指定量化和反量化的位置。此外该模式仅支持模块API而不支持函数式API。

以下代码示例展示了从模型定义到QAT准备,再到最终

int8

转换的完整流程。

 import os, torch, torch.nn as nn, torch.optim as optim  

# 1. 使用QuantStub/DeQuantStub定义模型
class QATCNN(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.quant   = torch.quantization.QuantStub()  
        self.conv1   = nn.Conv2d(1, 16, 3, padding=1)  
        self.relu1   = nn.ReLU()  
        self.pool    = nn.MaxPool2d(2)  
        self.conv2   = nn.Conv2d(16, 32, 3, padding=1)  
        self.relu2   = nn.ReLU()  
        self.fc      = nn.Linear(32*14*14, 10)  
        self.dequant = torch.quantization.DeQuantStub()  

    def forward(self, x):  
        x = self.quant(x)  
        x = self.pool(self.relu1(self.conv1(x)))  
        x = self.relu2(self.conv2(x))  
        x = x.flatten(1)  
        x = self.fc(x)  
        return self.dequant(x)  

# 2. QAT准备
model = QATCNN()  
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')  
torch.quantization.prepare_qat(model, inplace=True)  

# 3. 微型训练循环
opt = optim.SGD(model.parameters(), lr=1e-2)  
crit = nn.CrossEntropyLoss()  
for _ in range(3):  
    inp = torch.randn(16,1,28,28)  
    tgt = torch.randint(0,10,(16,))  
    opt.zero_grad(); crit(model(inp), tgt).backward(); opt.step()  

# 4. 转换为真实的int8
model.eval()  
int8_model = torch.quantization.convert(model)  

# 5. 存储优势
torch.save(model.state_dict(), "fp32.pth")  
torch.save(int8_model.state_dict(), "int8.pth")  
mb = lambda p: os.path.getsize(p)/1e6  
 print(f"FP32: {mb('fp32.pth'):.2f} MB  vs  INT8: {mb('int8.pth'):.2f} MB")

预期结果:在类MNIST数据上,模型尺寸约减少4倍,精度损失不超过1%。

工作原理

torch.quantization.prepare_qat

函数递归地用

FakeQuantize

模块包装每个符合条件的层,默认的

FBGEMM

qconfig配置选择逐张量权重观察器和逐通道激活观察器,特别适合服务器/边缘CPU部署场景。

2、FX图模式量化

这是PyTorch中的自动化量化工作流,目前处于维护状态。它通过支持函数式API和自动化量化过程增强了Eager模式量化功能,但用户可能需要重构模型以确保兼容性。

需要注意的是,由于符号追踪的潜在限制,该方法可能不适用于任意模型结构,使用时需要熟悉

torch.fx

框架。使用此方法的代码示例如下:

 import torch, torchvision.models as models  
from torch.ao.quantization import get_default_qat_qconfig_mapping  
from torch.ao.quantization import prepare_qat_fx, convert_fx  

model = models.resnet18(weights=None)     # 或pretrained=True  
model.train()  

# 单行qconfig映射
qmap = get_default_qat_qconfig_mapping("fbgemm")  
# 图重写
model_prepared = prepare_qat_fx(model, qmap)  

# 微调几个周期
model_prepared.eval()  
 int8_resnet = convert_fx(model_prepared)

FX模式在图级别运行:

conv2d

batch_norm

relu

等算子会自动融合,从而在CPU上产生更高效的计算内核和更优的延迟性能。

3、PyTorch 2导出量化

PT2E (PyTorch 2 Export)特别适合将导出的计算图交付给C++运行时环境。这是PyTorch 2.1中发布的新一代全图模式量化工作流,专为

torch.export

捕获的模型设计。整个过程可通过几行代码实现:

 import torch  
from torch import nn  
from torch._export import capture_pre_autograd_graph  
from torch.ao.quantization.quantize_pt2e import (  
    prepare_qat_pt2e, convert_pt2e)  
from torch.ao.quantization.quantizer.xnnpack_quantizer import (  
    XNNPACKQuantizer, get_symmetric_quantization_config)  

class Tiny(nn.Module):  
    def __init__(self): super().__init__(); self.fc=nn.Linear(8,4)  
    def forward(self,x): return self.fc(x)  

ex_in = (torch.randn(2,8),)  
exported = torch.export.export_for_training(Tiny(), ex_in).module()  
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())  
qat_mod = prepare_qat_pt2e(exported, quantizer)  

# 微调模型...  
int8_mod = convert_pt2e(qat_mod)  
 torch.ao.quantization.move_exported_model_to_eval(int8_mod)

生成的计算图已准备好用于

torch::deploy

或提前(AOT)编译到移动端推理引擎中。

4、大语言模型Int4/Int8混合精度演示

虽然不属于正式API,但

torchao

/

torchtune

也提供了用于极致模型压缩的原型量化器:

 import torch  
from torchtune.models.llama3 import llama3  
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer  

model = llama3(vocab_size=4096, num_layers=16,  
               num_heads=16, num_kv_heads=4,  
               embed_dim=2048, max_seq_len=2048).cuda()  

qat_quant = Int8DynActInt4WeightQATQuantizer()  
model = qat_quant.prepare(model).train()  

#  ––– 简化微调过程 –––  
optim = torch.optim.AdamW(model.parameters(), 1e-4)  
lossf = torch.nn.CrossEntropyLoss()  
for _ in range(100):  
    ids   = torch.randint(0,4096,(2,128)).cuda()  
    label = torch.randint(0,4096,(2,128)).cuda()  
    loss  = lossf(model(ids), label)  
    optim.zero_grad(); loss.backward(); optim.step()  

model_quant = qat_quant.convert(model)  
 torch.save(model_quant.state_dict(),"llama3_int4int8.pth")

在这种配置下,模型激活以

int8

精度运行,权重以

int4

精度运行,在单个A100 GPU上可实现超过2倍的性能提升和约60%的内存降低,同时困惑度仅增加不到0.8个百分点。

有关

torchao

torchtune

进行LLM量化的更多信息,推荐阅读PyTorch官方博客的相关内容。

量化实践最佳策略

为在最小化精度损失的前提下最大化模型压缩效果,应遵循以下关键策略:

首先应使用PTQ技术进行初步量化尝试。若PTQ导致的精度损失低于2%,通常只需进行短期QAT微调(5-10个周期)即可获得理想效果。执行消融分析以识别对量化敏感的网络层是非常必要的,当发现某层量化后性能显著下降时,可考虑保留其原始精度。尽早融合操作(如

Conv + BN + ReLU

)能够稳定观察器量化范围并提高精度。

训练几个周期后,应当调用

torch.ao.quantization.disable_observer

函数并使用

freeze_bn_stats

冻结批量归一化统计数据,防止范围出现振荡。监控量化过程中的权重直方图分布(可通过

torch.ao.quantization.get_observer_state_dict()

或使用Netron工具)有助于发现异常值。在STE近似有效工作时,较小的学习率(不超过1e-3)可避免参数过度调整。

对于权重量化,逐通道量化方法相较于逐张量量化能将误差减半,是卷积层的推荐默认设置。如果模型准确率仍有显著下降,考虑采用混合精度策略,将首层和末层保持在

fp16

精度以保证安全。最后,根据目标硬件平台选择合适的量化配置:x86架构使用

FBGEMM

,ARM架构使用

QNNPACK/XNNPACK

总结

神经网络模型部署需要采取全面的优化策略——构建准确的模型通常是相对容易的部分,而真正的挑战在于实现高效的大规模部署。当标准的PTQ方法无法满足精度要求时,QAT技术提供了有效的解决方案。然而,成功部署量化模型需要充分考虑多方面因素,包括目标平台及其支持的操作集合。PyTorch凭借其成熟的QAT工具链,为用户提供了便捷灵活的模型量化能力,适用于从简单CNN到拥有数十亿参数的大型语言模型等各类深度学习应用场景。

https://avoid.overfit.cn/post/c4a82be1e3a84f79912849651c4f4714

Sahib Dhanjal

目录
相关文章
|
12月前
|
物联网 网络架构
PHATGOOSE:使用LoRA Experts创建低成本混合专家模型实现零样本泛化
这篇2月的新论文介绍了Post-Hoc Adaptive Tokenwise Gating Over an Ocean of Specialized Experts (PHATGOOSE),这是一种通过利用一组专门的PEFT模块(如LoRA)实现零样本泛化的新方法
130 0
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
ModernBERT:英伟达开源的新一代编码器模型,性能超越 SOTA,通过去除填充和序列打包减少计算浪费,提高训练和推理的效率
ModernBERT 是由英伟达和 HuggingFace 等机构联合开源的新一代编码器模型,支持长上下文处理,性能超越 SOTA,适合多种自然语言处理任务。
213 7
ModernBERT:英伟达开源的新一代编码器模型,性能超越 SOTA,通过去除填充和序列打包减少计算浪费,提高训练和推理的效率
|
7月前
|
人工智能 测试技术 数据处理
首个Mamba+Transformer混合架构多模态大模型来了,实现单卡千图推理
【10月更文挑战第18天】《LongLLaVA: Scaling Multi-modal LLMs to 1000 Images Efficiently via Hybrid Architecture》提出了一种新型多模态大模型LongLLaVA,结合了Mamba和Transformer架构,通过系统优化实现在单张A100 80GB GPU上处理近千张图像的突破。该模型在视频理解、高分辨率图像分析和多模态智能体任务中表现出色,显著提升了计算效率。
263 65
|
5月前
|
人工智能 物联网 C语言
SVDQuant:MIT 推出的扩散模型后训练的量化技术,能够将模型的权重和激活值量化至4位,减少内存占用并加速推理过程
SVDQuant是由MIT研究团队推出的扩散模型后训练量化技术,通过将模型的权重和激活值量化至4位,显著减少了内存占用并加速了推理过程。该技术引入了高精度的低秩分支来吸收量化过程中的异常值,支持多种架构,并能无缝集成低秩适配器(LoRAs),为资源受限设备上的大型扩散模型部署提供了有效的解决方案。
243 5
SVDQuant:MIT 推出的扩散模型后训练的量化技术,能够将模型的权重和激活值量化至4位,减少内存占用并加速推理过程
|
11月前
|
机器学习/深度学习 存储 自然语言处理
【机器学习】LoRA:大语言模型中低秩自适应分析
【机器学习】LoRA:大语言模型中低秩自适应分析
368 5
|
机器学习/深度学习 人工智能 算法
【CIKM 2023】扩散模型加速采样算法OLSS,大幅提升模型推理速度
近日,阿里云人工智能平台 PAI与华东师范大学陈岑副教授团队合作在深度学习顶级会议 CIKM 2023 上发表 OLSS (Optimal Linear Subspace Search) 算法,这是一种针对扩散模型的采样加速算法。在这篇论文中,扩散模型加速算法的本质被建模成线性子空间的扩张过程,给出了目前方法的统一分析,并基于此设计了新的加速算法,大幅度提升了扩散模型的生成速度。
|
11月前
|
机器学习/深度学习 数据采集 人工智能
【机器学习】CLIP模型在有限计算资源下的性能探究:从数据、架构到训练策略
【机器学习】CLIP模型在有限计算资源下的性能探究:从数据、架构到训练策略
514 0
|
11月前
|
机器学习/深度学习 自然语言处理 物联网
ICML 2024:脱离LoRA架构,训练参数大幅减少,新型傅立叶微调来了
【6月更文挑战第4天】在ICML 2024上,研究团队提出了傅立叶变换微调(FourierFT),一种减少训练参数的新方法,替代了依赖LoRA的微调。FourierFT通过学习权重变化矩阵的稀疏频谱系数,实现了LFMs的高效微调。在多项任务上,FourierFT展示出与LoRA相当或更优的性能,参数量却大幅减少,如在LLaMA2-7B模型上,仅需0.064M参数,对比LoRA的33.5M。广泛实验验证了其在NLP和CV任务上的效果,但未来还需探索其适用性和泛化能力。论文链接:[arxiv.org/abs/2405.03003](https://arxiv.org/abs/2405.03003)
212 0
|
12月前
|
机器学习/深度学习 并行计算 算法
模型压缩部署神技 | CNN与Transformer通用,让ConvNeXt精度几乎无损,速度提升40%
模型压缩部署神技 | CNN与Transformer通用,让ConvNeXt精度几乎无损,速度提升40%
214 0
|
机器学习/深度学习 存储 人工智能
模型推理加速系列 | 03:Pytorch模型量化实践并以ResNet18模型量化为例(附代码)
本文主要简要介绍Pytorch模型量化相关,并以ResNet18模型为例进行量化实践。