本次分享内容为S&P 2021 收录的一篇文章:《SIRNN: A Math Library for Secure RNN Inference》。这篇论文主要设计了基于秘密分享的不同大小(比特)的环转换协议,以及混合比特的数学计算,并应用到了机器学习预测场景中,通过大量低比特的计算,减小了通信开销,提高了模型预测的效率。
围绕这篇论文的分享将分为以下的四个部分:
- 背景以及动机介绍
- 基础计算原语的设计
- 数学运算的协议设计
- 模型预测实验对比
1. 背景以及动机介绍
1.1 背景这个工作的场景为隐私保护机器学习,基于秘密分享协议(安全多方计算的一个分支)实现模型预测,确保多方预测时的模型参数以及预测输入的安全性。通常来说,机器学习的计算是基于浮点数运算,然而使用秘密分享实现安全浮点数运算会导致较高的开销,效率难以满足实际的计算需求,因此通常的做法是使用定点数来近似浮点数,在一定精度损失的前提下,取得较大的效率优化。现有的秘密分享协议通常定义在环或者域上,各有优劣。环上的计算由于其取模操作能够隐式地由硬件负责,相较于域上计算需要手动取模,计算效率更高。1.2 动机此工作同样采用这一范式,将计算统一编码到环上,同时使用定点数表示。举例来说,定点数可以理解为一个INT64类型的整数,其低位的20-bit表示小数部分,也即精度是固定的,而不像浮点数。通常,MPC下的密文空间是uniform bitwidth,即在一个统一大小的环上执行。然而考虑到模型预测中的部分中间计算实际上使用较小的bitwidth也能完成,并且使用较小的bitwidth能够减少计算和通信开销,因而SIRNN提出使用non-uniform bitwidth(混合比特)的计算,对不同计算使用不同大小的环,进而提高计算的效率。以下述代码为例,这段代码执行了混合比特运算。各个变量的比特位数和小数位数在下述表格中展示。简单的加减乘除只需要12-bit的精度,但是为了避免乘法的溢出,因此使用32-bit来存储中间计算。而对于指数exp计算,则是使用了30-bit的精度,相应的比特位数也需要提高到32,而不是16。通过这一方式,只有最后的exp需要高位数、高精度,加减乘除则是可以在低位数上进行,减小了开销。
2. 基础计算原语的设计
该图片素材来自https://www.youtube.com/watch?v=Wz_FBVObL7U,如有不妥,立即删除考虑在ℤ𝑚上的2PC additive Secret Sharing,输入𝑥被拆分成(𝑥0,𝑥1),满足(𝑥0+𝑥1) mod 2𝑚=𝑥,两个计算方𝑃0,𝑃1分别持有𝑥0,𝑥1。对于混合比特位数计算,考虑两个不同大小的环ℤ𝑚和ℤ𝑛,分别对应的环大小为2𝑚和2𝑛,即分别使用m-bit 整数和n-bit整数。
2.1 理论复杂度
2.2 Share Reduction
即ℤ𝑛→ℤ𝑚。输入为𝑥0,𝑥1,满足𝑥=(𝑥0+𝑥1) mod 2𝑛这一计算可以由两个参与方本地完成,这是因为𝑥 mod 2𝑚=(𝑥0+𝑥1-𝑤*2𝑛) mod 2𝑚=(𝑥0 mod 2𝑚)+(𝑥1 mod 2𝑚)=(𝑥′0+𝑥′1) mod 2𝑚其中𝑤*2𝑛表示加和溢出的部分,因为2𝑚<2𝑛,所以可以被隐式的取模舍去。𝑃0,𝑃1对𝑥0和𝑥1分别本地对2𝑚取模,即得到𝑥′0,𝑥′1,满足𝑥=(𝑥′0+𝑥′1) mod 2𝑚
2.3 Share Extension
即ℤ𝑚→ℤ𝑛。输入为𝑥0,𝑥1,满足𝑥=(𝑥0+𝑥1) mod 2𝑚这一计算不能本地完成,举例来说:假设原始输入𝑥=3,在大小为16的环上的两个分享为(10,9),本地对𝑥0,𝑥1模8可以得到𝑥在大小为8的环上的两个分享为(2,1)。然而从16到32的环上则正确性不对,因为(10+9) mod 32=19≠3。这是因为存在溢出部分,即𝑤=1,此时由于是从小环到大环,这个溢出项不能被忽视,必须显示地计算得到𝑤。以下给出文章提出的extension协议,针对有符号和无符号的表示分别处理。
2.3.1 Zero Extension
- 第一步中显示地计算𝑤=wrap(𝑥0,𝑥1,2𝑚),得到𝑤的boolean sharing。(wrap协议在CrypTFlow2 中提出)
- 第二步使用B2A协议,得到𝑤在ℤ𝑛-𝑚上的arithmetic sharing。注意这一步没有直接得到𝑤在ℤ𝑛上的arithmetic sharing是为了减少通信量。而这么做的合理性在于⟨𝑤⟩𝑛-𝑚*2𝑚=⟨𝑤⟩𝑛
- 第三步显示地将这个潜在的溢出项减去,即𝑥=(𝑥0+𝑥1-𝑤*2𝑚) mod 2𝑛
2.3.2 Signed Extension
上述协议针对无符号整数,当输入为有符号的整数时,稍有不同。观察到int(𝑥)=𝑥′-2𝑚-1, for 𝑥′=(𝑥+2𝑚-1) mod 2𝑚因此有SExt(𝑥,𝑚,𝑛)=ZExt(𝑥′,𝑚,𝑛)-2𝑚-1
2.4 Truncation
考虑输入⟨𝑥⟩𝑙,计算𝑥/2𝑠。同样分为无符号和有符号整数处理。
2.4.1 Logical Right Shift
构造⟨𝑥⟩𝑙𝑏=𝑢𝑏||𝑣𝑏,其中𝑏∈{0,1},𝑢𝑏∈{0,1}𝑙-𝑠,𝑣𝑏∈{0,1}𝑠满足𝑥>>𝐿 𝑠=𝑢0+𝑢1-2𝑙-𝑠·wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙)+wrap(𝑣0,𝑣1,2𝑠)这里减去2𝑙-𝑠·wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙)是因为类似share extension,需要将潜在的溢出去掉,由于2𝑙-𝑠<2𝑙因此不能隐式完成。加上wrap(𝑣0,𝑣1,2𝑠)则是因为可能低位𝑠-bit的𝑣0+𝑣1>2𝑠,若直接取高位的𝑢0+𝑢1,则缺了潜在的进位1。注意,这里可以直接调用wrap协议分别计算wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙),wrap(𝑣0,𝑣1,2𝑠)。这篇文章对这里的计算进行了优化,根据如下定理,避免直接计算wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙)(因为长度为𝑙比特,远大于𝑠)解释如下,其中𝐿=2𝑙:𝑥0+𝑥1=(𝑣0+𝑣1)+(𝑢0+𝑢1)*2𝑠=(𝑣0+𝑣1-𝑐*2𝑠)+(𝑢0+𝑢1-𝑑*2𝑙-𝑠)*2𝑠+𝑐*2𝑠+𝑑*𝐿=𝑣′+(𝑢′+𝑐)*2𝑠+𝑑*𝐿令𝑤′=1{𝑢′+𝑐>2𝑙-𝑠-1},满足𝑤′=𝑐∧𝑒有𝑥0+𝑥1=𝑣′+(𝑢′+𝑐-𝑤′*2𝑙-𝑠)+(𝑑+𝑤′)*𝐿因此有𝑤=𝑑+𝑤′=𝑑+𝑐∧𝑒=𝑑⊕𝑐∧𝑒
2.4.2 Arithmetic Right Shift
对于有符号整数ARS(𝑥,𝑠)=LRS(𝑥′,𝑠)-2𝑙-1-𝑠其中𝑥′=𝑥+2𝑙-1 mod 2𝑙
2.4.3 Division by power-of-2
这里直接放上原文的描述,相较right-shift,这里计算的C-style division,即rounding-to-zero。也就是说如果输入是负数,需要考虑rounding-up,即加1的情况。
2.4.4 Truncate-and-Reduce
输入⟨𝑥⟩𝑙,输出⟨𝑥/2𝑠⟩𝑙-𝑠可以结合truncation 和 share reduction,有TR(𝑥,𝑠)=𝑢0+𝑢1+wrap(𝑣0,𝑣1,2𝑠)可以避免计算wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙)
2.5 Multiplication with non-uniform bitwidths
2.5.1Unsigned Multiplication
输入⟨𝑥⟩𝑚,⟨𝑦⟩𝑛,输出⟨𝑧⟩𝑚+𝑛,其中𝑧=𝑥*𝑦这里实现乘法的思路是标准的2PC multiplication,即𝑥0𝑦0,𝑥1𝑦1可以本地计算得到,𝑥0𝑦1,𝑥1𝑦0则需要使用Cross-term Multiplication,本质上是基于OT的乘法(可以参考ABY)
2.5.2 Signed Multiplication
类似对于Share Extension和Truncation的处理
2.5.3 Matrix Multiplication
注意,这里不同于基础乘法,矩阵乘法由于需要进行𝑑2-1次加法,因此加法有可能出现overflow,因此需要将加和的结果使用𝑙=𝑚+𝑛+𝑒比特表示,其中𝑒=「log𝑑2」。为了减少开销,在第一行首先将𝑛-bit的输入𝑌 extend到𝑛+𝑒-bit,接着和𝑋执行non-uniform Multiplication。然而,这里的通信复杂度为𝑂(𝑑1𝑑2𝑑3)。
2.6 Digit Decomposition
此协议的目的是得到𝑙-bit 输入 𝑥,按照𝑑-bit分段,各段的分享。这个协议主要用于数学运算的协议实现。这个协议的关键步骤在于第5~6行,计算𝑢𝑖+1=𝑤𝑖⊕(𝑢𝑖∧𝑒𝑖),其中𝑤𝑖表示第𝑖段的进位,𝑒𝑖表示第𝑖段对应的𝑑-bit 数各位上是否为全1,𝑢𝑖表示低位对第𝑖段的进位。使用如下的示意图理解:
- 初始𝑢0=0因为没有来自低位的进位。
- 计算𝑒0∧𝑢0,如果𝑦0全1,且有来自低位的进位,则说明𝑦1有来自低位的进位
- 此外,如果直接𝑤0=1,那么同样说明𝑦1有来自低位的进位
- 因此𝑢1=𝑤0⊕(𝑢0∧𝑒0)
- 最后𝑧1=𝑦1+𝑢1
- 迭代计算到最高位的一段即可
2.7 MSNZB
此协议的目的是得到最高位非0比特的位置,例如,0000|0101→2基本原理为:𝑀𝑆𝑁𝑍𝐵(𝑦)=𝑀𝑆𝑁𝑍𝐵(𝑦𝑖)+𝑖·𝑑,满足𝑦𝑖!=0,𝑦𝑖=0,𝑗>𝑖。具体的实现可以参考原文章附录,需要注意的是,这个协议里面使用到的 子协议 MSNZB-P,Zeros and One-Hot 是使用LUT(Look Up Table,[3])实现的。
3. 数学运算的协议设计
以下数学计算的实现思路类似,核心是结合LUT得到initial guess,然后利用Iterative method 迭代近似正确值。
3.1 Exponential
注意这里第2行是预先构造出一个LUT,满足输入某一段的值,查表对应的结果即为对数计算的输出。
3.2 Sigmoid & Tanh
Sigmoid的计算为而𝑇𝑎𝑛𝘩(𝑧)=2𝑆𝑖𝑔𝑚𝑜𝑖𝑑(2𝑧)-1因此计算这俩个函数只需要用到exponential和reciprocal计算,exponential如上已经构造好,而reciprocal的协议如下:Reciprocal的计算首先根据LUT得到初始近似,再迭代得到最终结果。
3.3 Reciprocal of Square Root
和Reciprocal一样,平方根求逆计算协议同样是遵循了这个范式:首先使用LUT得到初始近似值,再利用Goldschmidt 迭代得到近似计算结果。
4. 模型预测实验对比
实验部分没有比较primitive,如乘法的耗时,只比较了math library以及nn inference的效率
4.1 Math Library
针对基础的数学计算,选择MiniONN、MP-SPDZ作为对比,效率上有显著提升,同时精度显著高于这两个工作。这主要得益于LUT的初始猜测精度更高。
4.2 DNN
针对2PC安全预测,和MiniONN、DeepSecure进行了实验对比,最快效率是DeepSecure的87倍,MiniONN的2.2倍
4.3 RNN
针对RNN的安全预测,对比对象为ABY,效率有显著的提升,这主要得益于通信开销的显著优化
5. 总结
SIRNN基于CrypTFLow2以及LUT,设计实现了混合比特的数学计算,并应用在DNN、RNN的预测上,通过减少通信开销,显著提高了效率。参考文献[1]: SIRNN: A Math Library for Secure RNN Inference, S&P 21'[2]: ABY – A Framework for Efficient Mixed-Protocol Secure Two-Party Computation, NDSS 15'[3]: Pushing the Communication Barrier in Secure Computation using Lookup Tables, NDSS 17'[4]: CrypTFlow2: Practical 2-Party Secure Inference, CCS 20'