256KB内存约束下的设备端训练:算法与系统协同设计
Ji Lin, Ligeng Zhu, Wei-Ming Chen, Wei-Chen Wang, Chuang Gan, and Song Han. 2022. On-device training under 256KB memory. In Proceedings of the 36th International Conference on Neural Information Processing Systems (NIPS '22). Curran Associates Inc., Red Hook, NY, USA, Article 1667, 22941–22954.
引言与背景
物联网设备的爆炸式增长带来了一个根本性挑战:如何让这些资源极度受限的设备具备学习和适应能力。传统的机器学习范式依赖于云端训练,这不仅带来隐私风险,还增加了网络延迟和能耗。MIT和MIT-IBM Watson AI Lab的研究团队提出了一个突破性的解决方案,首次实现了在仅有256KB SRAM和1MB Flash的微控制器上进行深度神经网络训练。
现代深度学习框架的内存需求与微控制器资源之间存在着巨大鸿沟。即使是训练最小的模型(如MobileNetV2-w0.35,批次大小为1),PyTorch也需要303MB内存,TensorFlow需要652MB,而专为边缘设计的MNN框架仍需要41.5MB。相比之下,典型的微控制器如STM32F746仅有320KB SRAM,这种超过1000倍的差距使得现有框架无法直接应用。

图1 - 内存占用对比图:该图清晰展示了不同框架的内存需求阶梯式下降。从云端框架(PyTorch 303MB、TensorFlow 652MB)到边缘框架(MNN 41.5MB),再到本文提出的Tiny Training Engine通过一系列优化技术(量化感知缩放、稀疏层/张量更新、算子重排序)逐步将内存降至141KB,最终实现了2300倍的内存减少。图中用不同颜色标注了每个优化步骤带来的改进倍数。
真实量化图的优化挑战
量化图的本质差异
边缘设备上的神经网络必须量化以适应有限内存。对于标准的fp32线性层运算,其int8量化版本的数学表达为:
$$\bar{y}_{int8} = \text{cast2int8}[s_{fp32} \cdot (\bar{W}_{int8}\bar{x}_{int8} + \bar{b}_{int32})]$$
这里的关键在于理解真实量化图(real quantized graph)与伪量化图(fake quantized graph)的本质区别。

图2 - 真实量化图vs伪量化图对比:
- 上半部分(a)展示真实量化图:输入和权重都是int8格式(用深蓝色表示),卷积运算产生int32中间结果,偏置也是int32,通过缩放因子和类型转换最终输出int8。每个张量旁边标注了其数据类型和值范围。
- 下半部分(b)展示伪量化图(QAT中使用):虽然包含fake quantize操作来模拟量化效果,但所有张量实际都是fp32(用浅色表示),还包含BatchNorm层。这种设计用于模拟但无法提供实际的内存节省。
真实量化图的更新公式为:
$$\bar{W}'_{int8} = \text{cast2int8}(\bar{W}_{int8} - \alpha \cdot G_{\bar{W}})$$
其中$\alpha$是学习率,$G_{\bar{W}}$是权重的梯度。梯度计算也在int8中执行以提高计算效率。
梯度尺度失配问题
直接训练真实量化图会遇到严重的优化困难。量化图包含不同位精度的张量(int8、int32、fp32),且缺少批归一化层(已融合),导致梯度更新不稳定。

图3 - 权重/梯度范数比分析:该图展示了35个张量索引的$\log_{10}(|W|/|G|)$比值。蓝色线是fp32模型,呈现平滑模式;橙色线是未经处理的int8模型,在红框标注区域呈现剧烈的锯齿状波动(权重-偏置-权重-偏置交替);绿色线是应用QAS后的int8模型,成功对齐了fp32的模式。纵轴的对数尺度显示量化后比值可能相差几个数量级。
量化感知缩放(QAS)的数学原理
缩放因子的影响分析
考虑权重矩阵$W \in \mathbb{R}^{c_1 \times c_2}$的逐张量量化过程。为使量化后权重$\bar{W}$的最大幅值达到$2^7 - 1 = 127$,计算缩放率:
$$s_W = \frac{\max(|W|)}{127}$$
量化和反量化过程可表示为:
$$W = s_W \cdot \text{round}(W/s_W) \approx s_W \cdot \bar{W}$$
在反向传播时,梯度也会被相应缩放:
$$G_{\bar{W}} \approx s_W \cdot G_W$$
这导致权重与梯度的范数比例发生扭曲:
$$\frac{\|\bar{W}\|}{\|G_{\bar{W}}\|} \approx \frac{\|W/s_W\|}{\|s_W \cdot G_W\|} = s_W^{-2} \cdot \frac{\|W\|}{\|G_W\|}$$
QAS补偿机制
QAS通过精确的梯度补偿来纠正这种扭曲。对于包含权重和偏置的完整层,补偿公式为:
$$\tilde{G}_{\bar{W}} = G_{\bar{W}} \cdot s_W^{-2}$$
$$\tilde{G}_{\bar{b}} = G_{\bar{b}} \cdot s_W^{-2} \cdot s_x^{-2} = G_{\bar{b}} \cdot s^{-2}$$
其中$s = s_W \cdot s_x$是方程(1)中的组合缩放因子。这种补偿机制无需任何超参数调节,自动适应不同层的量化尺度。
内存高效的稀疏更新策略
更新成本分析
对于线性层$y = Wx + b$的反向传播,梯度计算涉及:
- 权重梯度:$G_W = G_y \cdot x^T$(需要保存激活$x$)
- 偏置梯度:$G_b = \sum G_y$(不需要激活)
- 输入梯度:$G_x = W^T \cdot G_y$(用于继续反向传播)

图4 - 四种更新范式对比:图中用不同颜色区分更新(深色)和冻结(浅色)的参数。展示了两个连续线性层的四种更新策略:(a)完全更新所有权重和偏置;(b)仅更新偏置,权重保持冻结;(c)稀疏层更新,选择性更新某些层;(d)稀疏张量更新,甚至可以只更新某一层的部分通道。
贡献分析与自动搜索

图5 - 贡献分析结果:
- (a)偏置更新贡献曲线:随着更新层数增加,精度提升在15层后趋于饱和(约8%提升)
- (b)权重更新贡献热图:纵轴是精度增益,横轴是层索引。颜色深度表示不同通道更新比例(1/8到全部)。后面的层(30-40)贡献更大,点卷积层(pw1)呈现峰值
- (c)搜索有效性验证:横轴是贡献总和,纵轴是实际下游精度,正相关关系(相关系数>0.8)验证了线性假设的合理性
优化问题的求解采用进化算法,搜索空间约为$10^{30}$但通过贡献分析可在10分钟内找到优质解。
Tiny Training Engine系统架构

图6 - TTE完整工作流程:
- 输入模型通过追踪获得前向计算图(蓝色节点)
- 编译时自动微分生成反向图,红色圆圈表示梯度下降算子
- 图剪枝阶段,浅蓝色节点(冻结权重相关)被移除
- 算子重排序,黄色标注的梯度更新算子与反向计算交错执行
- 最终部署到微控制器,接收传感器数据进行在线更新
内存优化技术

图7 - 算子重排序的内存生命周期分析:
- (a)原始反向图:内存占用呈现多个高峰(最高达384KB),因为所有梯度都要保存到最后统一更新
- (b)优化后反向图:通过原地梯度更新和算子融合,内存峰值降至约160KB(2.4倍减少)。图中标注了可融合的算子对,以及原地更新带来的内存节省区域
实验结果与分析
整体性能评估

图9 - 稀疏更新vs传统方法对比:三个子图分别展示MobileNetV2、ProxylessNAS和MCUNet在不同内存预算下的精度。每条曲线代表一种更新策略:
- 紫色虚线:仅更新分类器的精度下限(59-65%)
- 蓝色曲线:更新最后k个偏置,精度快速提升但很快饱和
- 橙色曲线:更新最后k层的完整权重,精度高但内存成本巨大
- 绿色曲线:本文的稀疏更新,在更小内存下达到更高精度
- 图中标注了关键内存阈值(50KB、75KB、100KB、150KB)和相应的内存减少倍数(4.5×到7.5×)
系统性能测量

图10 - 实测内存和延迟:
- (a)峰值内存对比:展示三个模型在不同方案下的实测SRAM占用。从完全更新(2939-3650KB,超出内存)到稀疏更新(326-560KB)再到稀疏更新+重排序(141-173KB),实现20-21倍减少
- (b)不同精度水平的内存优化:即使对于不同的稀疏程度(对应72.0%、73.4%、75.1%精度),重排序都能带来约3.2倍的一致改进
- (c)训练延迟:TF-Lite完全更新需要8501-13398ms(投影值,因为OOM),而TTE稀疏更新仅需373-546ms,实现23-25倍加速
更新方案分析

图11 - MCUNet更新方案可视化:
- (a)每层内存成本分解:橙色区域表示激活内存(前面层高),蓝色区域表示权重内存(后面层高),总成本在中间层(18-30)最低
- (b)最终更新方案:底部条形图显示了自动搜索得到的方案。前20层仅前向传播,21-42层更新偏置,6个特定层更新权重(部分层只更新1/8或1/4通道)
消融研究与深入分析
研究通过详尽的消融实验验证了各个组件的贡献。表1展示了不同优化器和精度设置下的结果,QAS将int8训练精度从64.9%提升到73.5%,完全匹配fp32基准。图8的训练曲线显示QAS显著改善了收敛性,训练和验证损失都更加稳定。
对于批大小和动量的研究(表4)揭示了有趣的发现:在单批次设置下,动量实际上有害(71.5% vs 72.3%),这与常规认知相反。这是因为单样本梯度的高方差使得动量积累了噪声而非有用信号。
附录:数学推导
A. 量化过程的数学分析
A.1 前向量化过程
对于权重矩阵$W \in \mathbb{R}^{c_1 \times c_2}$,逐通道量化的完整过程如下:
首先计算每个输出通道的缩放因子:
$$s_{W,i} = \frac{\max_j |W_{i,j}|}{127}, \quad i \in [1, c_2]$$
量化函数定义为:
$$Q(w, s) = \text{clip}(\text{round}(w/s), -128, 127)$$
其中clip函数确保值在int8范围内:
$$\text{clip}(x, a, b) = \begin{cases} a & \text{if } x < a \\ x & \text{if } a \leq x \leq b \\ b & \text{if } x > b \end{cases}$$
A.2 反向传播的梯度流
考虑量化操作的梯度。使用直通估计器(Straight-Through Estimator, STE):
$$\frac{\partial Q(w, s)}{\partial w} \approx \begin{cases} 1/s & \text{if } |w/s| \leq 127 \\ 0 & \text{otherwise} \end{cases}$$
这导致梯度流为:
$$\frac{\partial L}{\partial W} = \frac{\partial L}{\partial \bar{W}} \cdot \frac{\partial Q(W, s_W)}{\partial W} \approx \frac{1}{s_W} \cdot \frac{\partial L}{\partial \bar{W}}$$
A.3 QAS的理论推导
设原始浮点模型的最优学习率为$\alpha^*$。在梯度下降中,权重更新为:
$$W_{t+1} = W_t - \alpha^* \cdot G_W$$
对于量化模型,如果不进行补偿:
$$\bar{W}_{t+1} = \bar{W}_t - \alpha^* \cdot G_{\bar{W}} = \bar{W}_t - \alpha^* \cdot s_W \cdot G_W$$
由于$\bar{W} = W/s_W$,等效的权重空间更新为:
$$W_{t+1} = W_t - \alpha^* \cdot s_W^2 \cdot G_W$$
这相当于使用了$\alpha^* \cdot s_W^2$的有效学习率,通常$s_W \ll 1$,导致学习率过小。
QAS通过缩放梯度来补偿:
$$\tilde{G}_{\bar{W}} = s_W^{-2} \cdot G_{\bar{W}} = s_W^{-2} \cdot s_W \cdot G_W = s_W^{-1} \cdot G_W$$
使得有效更新变为:
$$W_{t+1} = W_t - \alpha^* \cdot s_W \cdot s_W^{-1} \cdot G_W = W_t - \alpha^* \cdot G_W$$
恢复了原始的更新动态。
B. 稀疏更新的内存复杂度分析
B.1 完整反向传播的内存需求
对于$L$层网络,设第$i$层有$n_i$个神经元。完整训练的内存需求包括:
- 前向激活缓存:$M{act} = \sum{i=1}^{L-1} n_i \cdot b$(批大小$b$)
- 权重参数:$M{weight} = \sum{i=1}^{L} n_{i-1} \cdot n_i \cdot \text{sizeof}(\text{dtype})$
- 梯度缓存:$M{grad} = M{weight}$(每个参数的梯度)
- 反向中间变量:$M{back} \approx M{act}$
总内存:$M{total} = M{act} + M{weight} + M{grad} + M{back} \approx 2M{act} + 2M_{weight}$
B.2 稀疏更新的内存节省
稀疏层更新(更新第$k$层到第$L$层):
$$M'_{act} = \sum_{i=k-1}^{L-1} n_i \cdot b$$
$$M'_{grad} = \sum_{i=k}^{L} n_{i-1} \cdot n_i \cdot \text{sizeof}(\text{dtype})$$
稀疏张量更新(更新比例$r \in (0,1]$):
$$M''_{grad} = r \cdot M'_{grad}$$
内存节省比例:
$$\rho = 1 - \frac{M'_{act} + M''_{grad} + M_{weight}}{M_{total}}$$
B.3 贡献分析的数学模型
定义层$i$权重$W_i$的贡献度量:
$$C(W_i) = \mathbb{E}_{(x,y) \sim D_{val}} \left[ \frac{\partial \mathcal{L}(f(x; \theta), y)}{\partial W_i} \cdot \Delta W_i \right]$$
其中$\Delta W_i$是该层的预期更新量。实践中通过有限差分近似:
$$C(W_i) \approx \text{Acc}(\theta \cup \{W_i\}) - \text{Acc}(\theta \setminus \{W_i\})$$
贡献的可加性假设:
$$C(\{W_i\}_{i \in S}) \approx \sum_{i \in S} C(W_i)$$
实验验证这一近似在稀疏更新场景下足够准确(图5c相关系数>0.8)。
C. 编译器优化的形式化描述
C.1 计算图表示
定义计算图$G = (V, E)$,其中:
- $V = V_f \cup V_b \cup V_u$:前向、反向、更新节点
- $E$:数据依赖边
内存生命周期函数:
$$L(v) = [\text{first\_use}(v), \text{last\_use}(v)]$$
峰值内存:
$$M_{peak} = \max_{t} \sum_{v: t \in L(v)} \text{size}(v)$$
C.2 算子重排序算法
目标:找到拓扑排序$\pi: V \rightarrow \mathbb{N}$,最小化$M_{peak}$。
约束条件:
$$\forall (u, v) \in E: \pi(u) < \pi(v)$$
对于梯度更新节点$v_u \in V_u$和对应的梯度计算节点$v_g \in V_b$,原地更新要求:
$$\pi(v_g) + 1 = \pi(v_u)$$
通过启发式算法(优先调度"关键路径"上的节点)近似求解。
C.3 图剪枝规则
给定冻结参数集$F$,剪枝规则:
- 直接剪枝:$\forall v \in V_b: \text{output}(v) \in F \Rightarrow \text{remove}(v)$
- 级联剪枝:$\forall v \in V: \text{out_degree}(v) = 0 \Rightarrow \text{remove}(v)$
- 激活剪枝:如果层$l$的权重不更新,则$\text{activation}_l$无需保存
这些规则迭代应用直到不动点。
D. 收敛性分析
D.1 量化训练的收敛保证
定理:在适当的学习率和QAS下,量化SGD收敛到稳定点的邻域。
证明概要:定义量化误差$\epsilon_q = W - s_W \cdot \bar{W}$,有界为$|\epsilon_q| \leq s_W/2$。
李雅普诺夫函数:
$$V(W) = \mathcal{L}(W) + \lambda \|\epsilon_q\|^2$$
在QAS下,期望下降:
$$\mathbb{E}[V(W_{t+1}) - V(W_t)] \leq -\eta \|\nabla \mathcal{L}(W_t)\|^2 + \mathcal{O}(s_W^2)$$
当$|\nabla \mathcal{L}(W_t)| = \mathcal{O}(s_W)$时达到平衡。
D.2 稀疏更新的近似界
稀疏更新可视为投影梯度下降:
$$W_{t+1} = \Pi_{\mathcal{S}}(W_t - \alpha G_t)$$
其中$\mathcal{S}$是允许更新的参数子空间。
近似误差界:
$$\|\nabla \mathcal{L}(W^*_{sparse}) - \nabla \mathcal{L}(W^*_{full})\| \leq \kappa \cdot \sqrt{1 - r}$$
其中$r$是更新参数比例,$\kappa$是问题相关常数。
这解释了为什么即使稀疏更新也能达到接近完整更新的精度。