隐语小课 | 基于秘密分享的混合比特数学运算库-SIRNN介绍

简介: 隐语小课 | 基于秘密分享的混合比特数学运算库-SIRNN介绍


本次分享内容为S&P 2021 收录的一篇文章:《SIRNN: A Math Library for Secure RNN Inference》。这篇论文主要设计了基于秘密分享的不同大小(比特)的环转换协议,以及混合比特的数学计算,并应用到了机器学习预测场景中,通过大量低比特的计算,减小了通信开销,提高了模型预测的效率。

围绕这篇论文的分享将分为以下的四个部分:

  • 背景以及动机介绍
  • 基础计算原语的设计
  • 数学运算的协议设计
  • 模型预测实验对比

1. 背景以及动机介绍

1.1 背景这个工作的场景为隐私保护机器学习,基于秘密分享协议(安全多方计算的一个分支)实现模型预测,确保多方预测时的模型参数以及预测输入的安全性。通常来说,机器学习的计算是基于浮点数运算,然而使用秘密分享实现安全浮点数运算会导致较高的开销,效率难以满足实际的计算需求,因此通常的做法是使用定点数来近似浮点数,在一定精度损失的前提下,取得较大的效率优化。现有的秘密分享协议通常定义在环或者域上,各有优劣。环上的计算由于其取模操作能够隐式地由硬件负责,相较于域上计算需要手动取模,计算效率更高。1.2 动机此工作同样采用这一范式,将计算统一编码到环上,同时使用定点数表示。举例来说,定点数可以理解为一个INT64类型的整数,其低位的20-bit表示小数部分,也即精度是固定的,而不像浮点数。通常,MPC下的密文空间是uniform bitwidth,即在一个统一大小的环上执行。然而考虑到模型预测中的部分中间计算实际上使用较小的bitwidth也能完成,并且使用较小的bitwidth能够减少计算和通信开销,因而SIRNN提出使用non-uniform bitwidth(混合比特)的计算,对不同计算使用不同大小的环,进而提高计算的效率。以下述代码为例,这段代码执行了混合比特运算。各个变量的比特位数和小数位数在下述表格中展示。简单的加减乘除只需要12-bit的精度,但是为了避免乘法的溢出,因此使用32-bit来存储中间计算。而对于指数exp计算,则是使用了30-bit的精度,相应的比特位数也需要提高到32,而不是16。通过这一方式,只有最后的exp需要高位数、高精度,加减乘除则是可以在低位数上进行,减小了开销。

2. 基础计算原语的设计

该图片素材来自https://www.youtube.com/watch?v=Wz_FBVObL7U如有不妥,立即删除考虑在𝑚上的2PC additive Secret Sharing,输入𝑥被拆分成(𝑥0,𝑥1),满足(𝑥0+𝑥1) mod 2𝑚=𝑥,两个计算方𝑃0,𝑃1分别持有𝑥0,𝑥1对于混合比特位数计算,考虑两个不同大小的环ℤ𝑚和ℤ𝑛,分别对应的环大小为2𝑚和2𝑛,即分别使用m-bit 整数和n-bit整数。

2.1 理论复杂度

2.2 Share Reduction

即ℤ𝑛→ℤ𝑚。输入为𝑥0,𝑥1,满足𝑥=(𝑥0+𝑥1) mod 2𝑛这一计算可以由两个参与方本地完成,这是因为𝑥 mod 2𝑚=(𝑥0+𝑥1-𝑤*2𝑛) mod 2𝑚=(𝑥0 mod 2𝑚)+(𝑥1 mod 2𝑚)=(𝑥′0+𝑥′1) mod 2𝑚其中𝑤*2𝑛表示加和溢出的部分,因为2𝑚<2𝑛,所以可以被隐式的取模舍去。𝑃0,𝑃1𝑥0𝑥1分别本地对2𝑚取模,即得到𝑥′0,𝑥′1,满足𝑥=(𝑥′0+𝑥′1) mod 2𝑚

2.3 Share Extension

即ℤ𝑚→ℤ𝑛。输入为𝑥0,𝑥1,满足𝑥=(𝑥0+𝑥1) mod 2𝑚这一计算不能本地完成,举例来说:假设原始输入𝑥=3,在大小为16的环上的两个分享为(10,9),本地对𝑥0,𝑥1模8可以得到𝑥在大小为8的环上的两个分享为(2,1)。然而从16到32的环上则正确性不对,因为(10+9) mod 32=19≠3。这是因为存在溢出部分,即𝑤=1,此时由于是从小环到大环,这个溢出项不能被忽视,必须显示地计算得到𝑤。以下给出文章提出的extension协议,针对有符号和无符号的表示分别处理。

2.3.1 Zero Extension

  1. 第一步中显示地计算𝑤=wrap(𝑥0,𝑥1,2𝑚),得到𝑤的boolean sharing。(wrap协议在CrypTFlow2 中提出)
  2. 第二步使用B2A协议,得到𝑤在ℤ𝑛-𝑚上的arithmetic sharing。注意这一步没有直接得到𝑤在ℤ𝑛上的arithmetic sharing是为了减少通信量。而这么做的合理性在于𝑤𝑛-𝑚*2𝑚=𝑤𝑛
  3. 第三步显示地将这个潜在的溢出项减去,即𝑥=(𝑥0+𝑥1-𝑤*2𝑚) mod 2𝑛

2.3.2 Signed Extension

上述协议针对无符号整数,当输入为有符号的整数时,稍有不同。观察到int(𝑥)=𝑥-2𝑚-1, for 𝑥′=(𝑥+2𝑚-1) mod 2𝑚因此有SExt(𝑥,𝑚,𝑛)=ZExt(𝑥′,𝑚,𝑛)-2𝑚-1

2.4 Truncation

考虑输入⟨𝑥⟩𝑙,计算𝑥/2𝑠。同样分为无符号和有符号整数处理。

2.4.1 Logical Right Shift


构造⟨𝑥⟩𝑙𝑏=𝑢𝑏||𝑣𝑏,其中𝑏∈{0,1},𝑢𝑏∈{0,1}𝑙-𝑠,𝑣𝑏∈{0,1}𝑠满足𝑥>>𝐿  𝑠=𝑢0+𝑢1-2𝑙-𝑠·wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙)+wrap(𝑣0,𝑣1,2𝑠)这里减去2𝑙-𝑠·wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙)是因为类似share extension,需要将潜在的溢出去掉,由于2𝑙-𝑠<2𝑙因此不能隐式完成。加上wrap(𝑣0,𝑣1,2𝑠)则是因为可能低位𝑠-bit的𝑣0+𝑣1>2𝑠,若直接取高位的𝑢0+𝑢1,则缺了潜在的进位1。注意,这里可以直接调用wrap协议分别计算wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙),wrap(𝑣0,𝑣1,2𝑠)。这篇文章对这里的计算进行了优化,根据如下定理,避免直接计算wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙)(因为长度为𝑙比特,远大于𝑠)解释如下,其中𝐿=2𝑙𝑥0+𝑥1=(𝑣0+𝑣1)+(𝑢0+𝑢1)*2𝑠=(𝑣0+𝑣1-𝑐*2𝑠)+(𝑢0+𝑢1-𝑑*2𝑙-𝑠)*2𝑠+𝑐*2𝑠+𝑑*𝐿=𝑣′+(𝑢′+𝑐)*2𝑠+𝑑*𝐿令𝑤′=1{𝑢′+𝑐>2𝑙-𝑠-1},满足𝑤′=𝑐∧𝑒有𝑥0+𝑥1=𝑣′+(𝑢′+𝑐-𝑤′*2𝑙-𝑠)+(𝑑+𝑤′)*𝐿因此有𝑤=𝑑+𝑤′=𝑑+𝑐∧𝑒=𝑑⊕𝑐∧𝑒

2.4.2 Arithmetic Right Shift

对于有符号整数ARS(𝑥,𝑠)=LRS(𝑥′,𝑠)-2𝑙-1-𝑠其中𝑥′=𝑥+2𝑙-1 mod 2𝑙

2.4.3 Division by power-of-2

这里直接放上原文的描述,相较right-shift,这里计算的C-style division,即rounding-to-zero。也就是说如果输入是负数,需要考虑rounding-up,即加1的情况。

2.4.4 Truncate-and-Reduce

输入⟨𝑥⟩𝑙,输出⟨𝑥/2𝑠𝑙-𝑠可以结合truncation 和 share reduction,有TR(𝑥,𝑠)=𝑢0+𝑢1+wrap(𝑣0,𝑣1,2𝑠)可以避免计算wrap(⟨𝑥⟩𝑙0,⟨𝑥⟩𝑙1,2𝑙)

2.5 Multiplication with non-uniform bitwidths

2.5.1Unsigned Multiplication

输入⟨𝑥⟩𝑚,⟨𝑦⟩𝑛,输出⟨𝑧⟩𝑚+𝑛,其中𝑧=𝑥*𝑦这里实现乘法的思路是标准的2PC multiplication,即𝑥0𝑦0,𝑥1𝑦1可以本地计算得到,𝑥0𝑦1,𝑥1𝑦0则需要使用Cross-term Multiplication,本质上是基于OT的乘法(可以参考ABY)

2.5.2 Signed Multiplication

类似对于Share Extension和Truncation的处理

2.5.3 Matrix Multiplication

注意,这里不同于基础乘法,矩阵乘法由于需要进行𝑑2-1次加法,因此加法有可能出现overflow,因此需要将加和的结果使用𝑙=𝑚+𝑛+𝑒比特表示,其中𝑒=「log𝑑2」。为了减少开销,在第一行首先将𝑛-bit的输入𝑌 extend到𝑛+𝑒-bit,接着和𝑋执行non-uniform Multiplication。然而,这里的通信复杂度为𝑂(𝑑1𝑑2𝑑3)。

2.6 Digit Decomposition

此协议的目的是得到𝑙-bit 输入 𝑥,按照𝑑-bit分段,各段的分享。这个协议主要用于数学运算的协议实现。这个协议的关键步骤在于第5~6行,计算𝑢𝑖+1=𝑤𝑖⊕(𝑢𝑖∧𝑒𝑖),其中𝑤𝑖表示第𝑖段的进位,𝑒𝑖表示第𝑖段对应的𝑑-bit 数各位上是否为全1,𝑢𝑖表示低位对第𝑖段的进位。使用如下的示意图理解:

  • 初始𝑢0=0因为没有来自低位的进位。
  • 计算𝑒0∧𝑢0,如果𝑦0全1,且有来自低位的进位,则说明𝑦1有来自低位的进位
  • 此外,如果直接𝑤0=1,那么同样说明𝑦1有来自低位的进位
  • 因此𝑢1=𝑤0⊕(𝑢0∧𝑒0)
  • 最后𝑧1=𝑦1+𝑢1
  • 迭代计算到最高位的一段即可

2.7 MSNZB

此协议的目的是得到最高位非0比特的位置,例如,0000|0101→2基本原理为:𝑀𝑆𝑁𝑍𝐵(𝑦)=𝑀𝑆𝑁𝑍𝐵(𝑦𝑖)+𝑖·𝑑,满足𝑦𝑖!=0,𝑦𝑖=0,𝑗>𝑖具体的实现可以参考原文章附录,需要注意的是,这个协议里面使用到的 子协议 MSNZB-P,Zeros and One-Hot 是使用LUT(Look Up Table,[3])实现的。

3. 数学运算的协议设计

以下数学计算的实现思路类似,核心是结合LUT得到initial guess,然后利用Iterative method 迭代近似正确值。

3.1 Exponential

注意这里第2行是预先构造出一个LUT,满足输入某一段的值,查表对应的结果即为对数计算的输出。

3.2 Sigmoid & Tanh

Sigmoid的计算为而𝑇𝑎𝑛𝘩(𝑧)=2𝑆𝑖𝑔𝑚𝑜𝑖𝑑(2𝑧)-1因此计算这俩个函数只需要用到exponential和reciprocal计算,exponential如上已经构造好,而reciprocal的协议如下:Reciprocal的计算首先根据LUT得到初始近似,再迭代得到最终结果。

3.3 Reciprocal of Square Root

和Reciprocal一样,平方根求逆计算协议同样是遵循了这个范式:首先使用LUT得到初始近似值,再利用Goldschmidt 迭代得到近似计算结果。

4. 模型预测实验对比

实验部分没有比较primitive,如乘法的耗时,只比较了math library以及nn inference的效率

4.1 Math Library

针对基础的数学计算,选择MiniONN、MP-SPDZ作为对比,效率上有显著提升,同时精度显著高于这两个工作。这主要得益于LUT的初始猜测精度更高。

4.2 DNN

针对2PC安全预测,和MiniONN、DeepSecure进行了实验对比,最快效率是DeepSecure的87倍,MiniONN的2.2倍

4.3 RNN

针对RNN的安全预测,对比对象为ABY,效率有显著的提升,这主要得益于通信开销的显著优化

5. 总结

SIRNN基于CrypTFLow2以及LUT,设计实现了混合比特的数学计算,并应用在DNN、RNN的预测上,通过减少通信开销,显著提高了效率。参考文献[1]: SIRNN: A Math Library for Secure RNN Inference, S&P 21'[2]: ABY – A Framework for Efficient Mixed-Protocol Secure Two-Party Computation, NDSS 15'[3]: Pushing the Communication Barrier in Secure Computation using Lookup Tables, NDSS 17'[4]: CrypTFlow2: Practical 2-Party Secure Inference, CCS 20'

相关文章
|
机器学习/深度学习 安全 算法
技术焦点篇|Cheetah猎豹及其在隐语中的实现
技术焦点篇|Cheetah猎豹及其在隐语中的实现
1620 1
|
存储 物联网 测试技术
改变LoRA的初始化方式,北大新方法PiSSA显著提升微调效果
【4月更文挑战第23天】北京大学团队提出的新方法PiSSA,基于SVD进行参数高效微调,降低计算成本。PiSSA通过聚焦低秩矩阵训练,实现与全参数微调相当甚至更好的性能,快于LoRA收敛且在五个基准测试中胜出。PiSSA继承LoRA的参数效率,初始化仅需几秒,适合快速适应不同下游任务。尽管有潜力,但其在更大模型和任务上的效果,以及与LoRA结合的可能优化,仍是未来研究课题。[链接](https://arxiv.org/pdf/2404.02948.pdf)
583 7
|
11月前
|
传感器 人工智能 物联网
《跨越架构鸿沟:分布式软总线实现设备通信大一统》
随着设备多样性增加,不同芯片架构(如X86、ARM、RISC-V)在通信中面临诸多障碍。分布式软总线技术应运而生,通过融合底层通信技术、协议货架适配和中间适配层,屏蔽硬件、操作系统及协议差异,实现高效统一通信。该技术已在智能家居与办公场景中展现价值,未来结合AI与新一代通信技术,将助力万物互联愿景的实现。
463 6
|
6月前
|
存储 JSON 对象存储
零门槛玩转向量引擎!阿里云 Milvus 无代码全流程实操指南
阿里云Milvus版是企业级向量引擎,支持非结构化数据语义检索。全托管架构、开源兼容,助力智能驾驶、电商推荐、智能客服等场景实现毫秒级精准匹配,无代码操作让AI落地更高效。
831 0
|
8月前
|
机器学习/深度学习 自动驾驶 算法
基于深度学习的YOLO框架的7种交通场景识别项目系统【附完整源码+数据集】
在智慧交通和智能驾驶日益普及的今天,准确识别复杂交通场景中的关键元素已成为自动驾驶系统的核心能力之一。传统的图像处理技术难以适应高动态、复杂天气、多目标密集的交通环境,而基于深度学习的目标检测算法,尤其是YOLO(You Only Look Once)系列,因其检测速度快、精度高、可部署性强等特点,在交通场景识别中占据了重要地位。
989 0
基于深度学习的YOLO框架的7种交通场景识别项目系统【附完整源码+数据集】
|
9月前
|
机器学习/深度学习 自然语言处理 测试技术
Qwen3技术报告首次全公开!“混合推理模型”是这样炼成的
近日,通义千问Qwen3系列模型已开源,其技术报告也正式发布。Qwen3系列包含密集模型和混合专家(MoE)模型,参数规模从0.6B到235B不等。该模型引入了“思考模式”与“非思考模式”的动态切换机制,并采用思考预算机制优化推理性能。Qwen3支持119种语言及方言,较前代显著提升多语言能力,在多个基准测试中表现领先。此外,通过强到弱蒸馏技术,轻量级模型性能优异,且计算资源需求更低。所有Qwen3模型均采用Apache 2.0协议开源,便于社区开发与应用。
6642 30
|
机器学习/深度学习 编解码 BI
RT-DETR改进策略【Conv和Transformer】| CVPR-2023 BiFormer 稀疏自注意力,减少内存占用
RT-DETR改进策略【Conv和Transformer】| CVPR-2023 BiFormer 稀疏自注意力,减少内存占用
383 0
RT-DETR改进策略【Conv和Transformer】| CVPR-2023 BiFormer 稀疏自注意力,减少内存占用
|
机器学习/深度学习 人工智能 安全
阿里云先知安全沙龙(武汉站) ——AI赋能软件漏洞检测,机遇, 挑战与展望
本文介绍了漏洞检测的发展历程、现状及未来展望。2023年全球披露的漏洞数量达26447个,同比增长5.2%,其中超过7000个具有利用代码,115个已被广泛利用,涉及多个知名软件和系统。文章探讨了从人工审计到AI技术的应用,强调了数据集质量对模型性能的重要性,并展示了不同检测模型的工作原理与实现方法。此外,还讨论了对抗攻击对模型的影响及提高模型可解释性的多种方法,展望了未来通过任务大模型实现自动化漏洞检测与修复的趋势。
|
弹性计算 关系型数据库 MySQL
CentOS 7.x操作系统的ECS云服务器上搭建WordPress网站
CentOS 7.x操作系统的ECS云服务器上搭建WordPress网站
|
Web App开发 JSON 测试技术
精通Postman接口测试:关联技术与自动化实践指南
这篇文章详细介绍了如何使用Postman进行接口测试,包括关联技术、自动化实践,以及如何通过环境变量和全局变量解决接口之间的关联性问题。
593 0
精通Postman接口测试:关联技术与自动化实践指南