Paper: SIRNN: A Math Library for Secure RNN Inference
Code: https://github.com/mpc-msri/EzPC.
2PC在整数上运算比在浮点数上运算更高效,在定点运算数中, ⌊ r 2 s ⌋ m o d 2 l \lfloor r2^s\rfloor \mod 2^l ⌊r2s⌋mod2l,其中 l l l就是bitwidth, s s s是scale。
安全参数: λ = 128 \lambda=128 λ=128
基于4种构造块:
(1)Extension(扩展)
Z 2 m → Z 2 n ( m < n ) \mathbb{Z}_{2^m} \rightarrow \mathbb{Z}_{2^n} (m
GC需要的通信开销(重构和重共享)为: λ ( 4 m + 2 n ) \lambda(4m+2n) λ(4m+2n) bits,SIRNN需要的通信开销仅为: λ m \lambda m λm,大约比GC快6x。
(2)Truncation(截断)
常用于乘法之后减小规模,对于 l l l-bit截断了 s s s-bit有四种截断操作:
目前最好的算数右移通信大约是: λ ( l + s ) \lambda(l+s) λ(l+s),本文提出的逻辑/算数右移协议大约是 λ l \lambda l λl,大多数数学函数都只需要截断且减小去减小scale和bitwidth,SIRNN只需要 λ ( s + 1 ) \lambda (s+1) λ(s+1)通信。
(3)Multiplication(乘法)
m m m-bit整数和 n n n-bit整数相乘得到 l = ( m + n ) l=(m+n) l=(m+n)-bit输出, l l l的选择保证了没有溢出。
(4)Digit Decomposition(数位分解)
将 l l l-bit的值分解为 c = l / d c=l/d c=l/d个 d d d-bits,可以用GC实现,通信量为 λ ( 6 l − 2 c − 2 ) \lambda (6l-2c-2) λ(6l−2c−2) bits。本文进一步优化,通信量为 λ ( c − 1 ) ( d + 2 ) \lambda (c-1)(d+2) λ(c−1)(d+2) bits,大约比GC低5x。
ULP是真实数据和函数输出值之间的可表示数值的数量。
符号 | 意义 |
---|---|
x ∈ Z 2 l x\in \mathbb Z_{2^l} x∈Z2l | power-of-2 rings, x x x的环为 Z 2 l \mathbb Z_{2^l} Z2l,即以 2 l 2^l 2l为模 |
B B B | ring Z 2 \mathbb Z_2 Z2,即以2为模 |
λ \lambda λ | 计算安全系数 |
⊕ \oplus ⊕ | 异或门 |
ζ l , ζ l , m ( m > l ) \zeta_l, \zeta_{l,m} (m>l) ζl,ζl,m(m>l) | 无损lifting操作,映射 Z L → Z \mathbb Z_L\rightarrow \mathbb Z ZL→Z,映射 Z L → Z M \mathbb Z_L\rightarrow \mathbb Z_M ZL→ZM |
L , M , N L,M,N L,M,N | 2 l , 2 m , 2 n 2^l, 2^m, 2^n 2l,2m,2n |
[ k ] [k] [k] | 0 , 1 , . . , k − 1 {0, 1, .., k-1} 0,1,..,k−1 |
1 { b } 1\{b\} 1{b} | b = t r u e b=true b=true时为1,反之为0 |
i n t ( x ) int(x) int(x)和 u i n t ( x ) uint(x) uint(x) | 对于 x ∈ Z l x\in \mathbb Z^l x∈Zl,分别代表有符号和无符号值,int(x)=uint(x)−MSB(x)L |
MSB(x) | MSB(x) = 1 { x ≥ 2 l − 1 } =1\{x\geq 2^{l-1}\} =1{x≥2l−1},表示最有效高位 |
F M i l l l ( x , y ) F_{Mill}^l(x, y) FMilll(x,y) | F M i l l l ( x , y ) = ⟨ z ⟩ B = 1 { x < y } F_{Mill}^l(x, y)=\langle z\rangle^B=1\{x |
F w r a p l F_{wrap}^l Fwrapl | F w r a p l = F M i l l l ( L − 1 − x , y ) : w = w r a p ( x , y , L ) = 1 { x + y ≥ L } F_{wrap}^l=F_{Mill}^l(L-1-x, y): w=wrap(x, y, L)=1\{x+y\geq L\} Fwrapl=FMilll(L−1−x,y):w=wrap(x,y,L)=1{x+y≥L} |
e e e | e = 1 { ( x + y m o d L ) = L − 1 } e=1\{(x+y \mod L)=L-1\} e=1{(x+ymodL)=L−1},判断是否全是1 |
F w r a p & a l l 1 s l F_{wrap\&all1s}^l Fwrap&all1sl | F w r a p & a l l 1 s l ( x , y ) = ( ⟨ w ⟩ B ∣ ∣ ⟨ e ⟩ B ) F_{wrap\&all1s}^l(x,y)=(\langle w\rangle^B||\langle e\rangle^B) Fwrap&all1sl(x,y)=(⟨w⟩B∣∣⟨e⟩B),至多一项是1 |
∗ m *_m ∗m | x ∗ m y = x y m o d M x*_m y=xy\mod M x∗my=xymodM,从 Z × Z → Z M \mathbb Z \times \mathbb Z \rightarrow \mathbb Z_M Z×Z→ZM |
l l l | bitwidth |
s s s | scale |
l − s l-s l−s | 整数部分的bitwidth |
F i x ( x , l , s ) Fix(x, l, s) Fix(x,l,s) | F i x ( x , l , s ) = x 2 s m o d L Fix(x, l, s)=x2^s \mod L Fix(x,l,s)=x2smodL,从实数转到定点数表示 |
u r t ( l , s ) ( a ) urt_{(l,s)}(a) urt(l,s)(a) | 对于无符号数, u r t ( l , s ) ( a ) = u i n t ( a ) / 2 s urt_{(l,s)}(a)=uint(a)/2^s urt(l,s)(a)=uint(a)/2s,从定点数转到实数表示 |
s r t ( l , s ) ( a ) srt_{(l,s)}(a) srt(l,s)(a) | 对于有符号数, s r t ( l , s ) ( a ) = i n t ( a ) / 2 s srt_{(l,s)}(a)=int(a)/2^s srt(l,s)(a)=int(a)/2s,从定点数转到实数表示 |
> > L , > > A >>_L, >>_A >>L,>>A | 逻辑右移和算术右移 |
对于 m m m-bit的数 x ∈ Z M x\in \mathbb Z_M x∈ZM,将其转换为 n n n-bit的数( n > m n>m n>m),这个过程就称为扩展(extension)。零扩展和有符号扩展分别用于扩展无符号数和有符号数的位宽。
零扩展(Zero Extension)
P 0 P_0 P0和 P 1 P_1 P1两方输入 ⟨ x ⟩ m \langle x\rangle^m ⟨x⟩m,扩展输出 ⟨ y ⟩ n \langle y\rangle^n ⟨y⟩n,要求满足 u n i t ( x ) = u i n t ( y ) unit(x)=uint(y) unit(x)=uint(y)。对于 x m ∈ Z M x^m\in \mathbb Z_M xm∈ZM,可以得到 【这个等式在后面广泛使用,没太理解怎么来的】:
x m = ⟨ x ⟩ 0 m + ⟨ x ⟩ 1 m − w M x^m = \langle x \rangle_0^m+\langle x \rangle_1^m-wM xm=⟨x⟩0m+⟨x⟩1m−wM
其中, w = w r a p ( ⟨ x ⟩ 0 m , ⟨ x ⟩ 1 m , M ) w=wrap(\langle x \rangle_0^m, \langle x \rangle_1^m, M) w=wrap(⟨x⟩0m,⟨x⟩1m,M),这是个boolean share,需要转换为算术share。这里考虑在 n − m n-m n−m环上转换,原因就是下面的模约减步骤会是通信量大大降低。
F B 2 A n − m ( ⟨ w ⟩ B ) = ⟨ w ⟩ n − m ∈ Z 2 n − m F_{B2A}^{n-m}(\langle w\rangle^B)=\langle w\rangle^{n-m}\in \mathbb Z_{2^{n-m}} FB2An−m(⟨w⟩B)=⟨w⟩n−m∈Z2n−m
w = ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m − w r a p ( ⟨ w ⟩ 0 n − m , ⟨ w ⟩ 1 n − m , Z 2 n − m ) 2 n − m w = \langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m}-wrap(\langle w\rangle_0^{n-m}, \langle w\rangle_1^{n-m}, \mathbb Z_{2^{n-m}})2^{n-m} w=⟨w⟩0n−m+⟨w⟩1n−m−wrap(⟨w⟩0n−m,⟨w⟩1n−m,Z2n−m)2n−m
M ∗ n w = M ∗ n ( ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m − w r a p ( ⟨ w ⟩ 0 n − m , ⟨ w ⟩ 1 n − m , Z 2 n − m ) 2 n − m ) M_{*n}w = M_{*n}(\langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m} - wrap(\langle w\rangle_0^{n-m}, \langle w\rangle_1^{n-m}, \mathbb Z_{2^{n-m}})2^{n-m}) M∗nw=M∗n(⟨w⟩0n−m+⟨w⟩1n−m−wrap(⟨w⟩0n−m,⟨w⟩1n−m,Z2n−m)2n−m)
其中, M ∗ n w r a p ( ⋅ ) 2 n − m = M w r a p ( ⋅ ) 2 n − m m o d N = w r a p ( ⋅ ) 2 n m o d N = 0 M_{*n}wrap(\cdot)2^{n-m}=Mwrap(\cdot)2^{n-m} \mod N=wrap(\cdot)2^{n} \mod N=0 M∗nwrap(⋅)2n−m=Mwrap(⋅)2n−mmodN=wrap(⋅)2nmodN=0(这一步称作“模约减”,modulo-reduce),所以上式子转换为:
M ∗ n w = M ∗ n ( ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m ) M_{*n}w = M_{*n}(\langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m}) M∗nw=M∗n(⟨w⟩0n−m+⟨w⟩1n−m)
于是:
y = ∑ b = 0 1 ( ⟨ x ⟩ b m − M ⟨ w ⟩ b n − m ) m o d N y = \sum_{b=0}^1(\langle x\rangle_b^m-M\langle w\rangle_b^{n-m}) \mod N y=b=0∑1(⟨x⟩bm−M⟨w⟩bn−m)modN
这里是在 P 0 P_0 P0和 P 1 P_1 P1上分别计算,然后求和取模,得到扩展后的结果。其中, x m o d N = y x \mod N=y xmodN=y。
算法如下:
需要 log ( m + 2 ) \log(m+2) log(m+2) rounds和少于 λ ( m + 1 ) + 13 m + n \lambda(m+1)+13m+n λ(m+1)+13m+n bits的通信量。作为对比,用GC实现零扩展和有符号扩展需要 λ ( 4 m + 2 n − 4 ) \lambda(4m+2n-4) λ(4m+2n−4) bits的通信量,大约是SIRNN的6倍。
有符号扩展(Signed Extension)
有符号扩展可以基于以下等式,通过转换无符号扩展得到,在环 Z \mathbb Z Z上:
i n t ( x ) = x ′ − 2 m − 1 , x ′ = x + 2 m − 1 m o d M int(x)=x'-2^{m-1}, x'=x+2^{m-1} \mod M int(x)=x′−2m−1,x′=x+2m−1modM
证明如下:
于是:
S E x t ( x , m , n ) = Z E x t ( x , m , n ) − 2 m − 1 SExt(x, m, n)=ZExt(x, m, n)-2^{m-1} SExt(x,m,n)=ZExt(x,m,n)−2m−1
相比零扩展,没有额外的通信开销。
首先,规定 > > L , > > A >>_L,>>_A >>L,>>A分别表示逻辑右移和算术右移,它们的输入和输出都是在 Z L \mathbb Z_L ZL环上。
T R ( x , s ) TR(x, s) TR(x,s)表示截断且减小(truncate & reduce),将 x ∈ Z L x\in \mathbb Z_L x∈ZL截断且减小 s s s-bits,最终得到的 x x x在更小的 Z 2 l − s \mathbb Z_{2^{l-s}} Z2l−s环上。
逻辑右移
Toy example: x = 101001 x=101001 x=101001逻辑右移3位,则 x ′ = 000101 x'=000101 x′=000101(右侧截掉,左侧补0)。
对于 x ∈ Z L x\in \mathbb Z_L x∈ZL,则 x = ⟨ x ⟩ 0 l + ⟨ x ⟩ 1 l m o d L x=\langle x\rangle_0^l+\langle x\rangle_1^l \mod L x=⟨x⟩0l+⟨x⟩1lmodL,记 ⟨ x ⟩ b l = u b ∣ ∣ v b \langle x\rangle_b^l=u_b||v_b ⟨x⟩bl=ub∣∣vb( u b u_b ub是高位, v b v_b vb是低位),其中 u b ∈ { 0 , 1 } l − s , v b ∈ { 0 , 1 } s u_b\in\{0, 1\}^{l-s}, v_b\in\{0, 1\}^{s} ub∈{0,1}l−s,vb∈{0,1}s。如下图:
根据前面提到的公式:
x m = ⟨ x ⟩ 0 m + ⟨ x ⟩ 1 m − w M x^m = \langle x \rangle_0^m+\langle x \rangle_1^m-wM xm=⟨x⟩0m+⟨x⟩1m−wM
可以得到:
x > > L s = u 0 + u 1 − 2 l − s w r a p ( ⟨ x ⟩ 0 l , ⟨ x ⟩ 1 l , L ) + w r a p ( v 0 , v 1 , 2 s ) x>>_Ls=u_0+u_1-2^{l-s} wrap (\langle x\rangle_0^l, \langle x\rangle_1^l, L) + wrap(v_0, v_1, 2^s) x>>Ls=u0+u1−2l−swrap(⟨x⟩0l,⟨x⟩1l,L)+wrap(v0,v1,2s)
上式中, w r a p ( v 0 , v 1 , 2 s ) wrap(v_0, v_1, 2^s) wrap(v0,v1,2s)这一项是考虑了进位。我们知道,加性秘密共享时, v v v部分可能会存在1位进位的情况,所以 w r a p ( v 0 , v 1 , 2 s ) wrap(v_0, v_1, 2^s) wrap(v0,v1,2s)就是判断 v 0 + v 1 v_0+v_1 v0+v1是否大于 2 s 2^s 2s,如果是,则会进1,如果不是,则为0。
常规做法是计算两个 w r a p ( ⋅ ) wrap(\cdot) wrap(⋅)值即可,但是SIRNN提出了一种优化,避开直接计算位宽是 l l l的那一项。文章中的Lemma 1即是这个引理:
通信开销低于 λ ( l + 3 ) + 15 + s + 20 \lambda(l+3)+15+s+20 λ(l+3)+15+s+20,并需要 log l + 3 \log l+3 logl+3 rounds。
原文证明如下:
算术右移
对于无符号数,直接采用逻辑右移,对于有符号数,则需要采用算术右移。从前面零扩展到有符号扩展可以知道: i n t ( x ) = x ′ − 2 l − 1 , x ′ = x + 2 l − 1 m o d L int(x)=x'-2^{l-1}, x'=x+2^{l-1} \mod L int(x)=x′−2l−1,x′=x+2l−1modL,于是:
x > > A s = x > > L s − 2 l − s − 1 x>>_As = x>>_Ls-2^{l-s-1} x>>As=x>>Ls−2l−s−1
截断且减小
Toy example: x = 101001 x=101001 x=101001截断且减小3位,则 x ′ = 101 x'=101 x′=101。
因为 2 l − s ∗ l w m o d 2 l − s = 0 2^{l-s}{*_l} w \mod 2^{l-s}=0 2l−s∗lwmod2l−s=0(模约减),所以:
⟨ T R ( x , s ) ⟩ l − s = u 0 + u 1 + w r a p ( v 0 , v 1 , 2 s ) \langle TR(x, s)\rangle^{l-s}=u_0+u_1+wrap(v_0, v_1, 2^s) ⟨TR(x,s)⟩l−s=u0+u1+wrap(v0,v1,2s)
除以power-of-2
z < 0 , z = ⌈ i n t ( x ) / 2 s ⌉ m o d L ; z ≥ 0 , z = ⌊ i n t ( x ) / 2 s ⌋ m o d L z<0, z=\lceil int(x)/2^s\rceil \mod L; z\geq0, z=\lfloor int(x)/2^s\rfloor \mod L z<0,z=⌈int(x)/2s⌉modL;z≥0,z=⌊int(x)/2s⌋modL
实际上 i n t ( x ) / 2 s m o d L int(x)/2^s \mod L int(x)/2smodL就是做 > > A >>_A >>A,取整括号即是将值往0靠近。令 m x = 1 { x ≥ 2 l − 1 } m_x=1\{x\geq 2^{l-1}\} mx=1{x≥2l−1}判断 x x x的正负性, c = 1 { x m o d 2 s = 0 } c=1\{x\mod 2^s=0\} c=1{xmod2s=0}
m x = 1 m_x=1 mx=1,则 z < 0 , ⌈ z ⌉ z<0, \lceil z\rceil z<0,⌈z⌉;反之, ⌊ z ⌋ \lfloor z\rfloor ⌊z⌋。所以有:
D i v P o w 2 ( x , s ) = ( x > > A s ) + m x ∧ c DivPow2(x, s)=(x>>_As)+m_x\land c DivPow2(x,s)=(x>>As)+mx∧c
以前做乘法通常是用Beaver Triplet三元组实现,SIRNN中不能用了,因为加法和乘法的数bitwidth不一致。
无符号乘法
输入 ⟨ x ⟩ m , ⟨ y ⟩ n \langle x\rangle^m, \langle y\rangle^n ⟨x⟩m,⟨y⟩n,输出 ⟨ z ⟩ l , z = x ∗ l y , l = n + m \langle z\rangle^l, z=x*_l y, l=n+m ⟨z⟩l,z=x∗ly,l=n+m。
对于 x , y x,y x,y,在 Z \mathbb Z Z上有:
u i n t ( x ) ⋅ u i n t ( y ) = ( x 0 + x 1 − 2 m w x ) ⋅ ( y 0 + y 1 − 2 n w y ) = x 0 y 0 + x 0 y 1 + x 1 y 0 + x 1 y 1 − 2 m w x y − 2 n w y x + 2 l w x w y uint(x)\cdot uint(y)=(x_0+x_1-2^mw_x)\cdot(y_0+y_1-2^nw_y)\\=x_0y_0+x_0y_1+x_1y_0+x_1y_1-2^mw_xy-2^nw_yx+2^lw_xw_y uint(x)⋅uint(y)=(x0+x1−2mwx)⋅(y0+y1−2nwy)=x0y0+x0y1+x1y0+x1y1−2mwxy−2nwyx+2lwxwy
观察上式, x 0 y 0 , x 1 y 1 x_0y_0,x_1y_1 x0y0,x1y1都是可以本地计算的【本地计算为什么不管位宽是否一致?】, 2 l w x w y 2^lw_xw_y 2lwxwy可以在 m o d L \mod L modL时被消掉(模约减), w x y , x y x w_xy, x_yx wxy,xyx是boolean share和算术share的计算,本质上是MUX,可用直接用OT实现。最难的一项是交叉项 x 0 y 1 , x 1 y 0 x_0y_1, x_1y_0 x0y1,x1y0,SIRNN采用COT实现。
巧妙的一点在于:选择比特位短的一方作为receiver,比特位长的一方作为sender,这样在做OT的取数时,round数就会更少。
交叉项算法如下:
无符号乘法算法如下:
SIRNN利用1-out-of-2的COT来实现这个过程,将短的数按位拆解,每一位非0即1,然后做二选一的COT,每一位计算完成后,在本地累加起来。
通信开销大约是: λ ( 3 μ + v ) + μ ( μ + 2 v ) + 16 ( m + n ) \lambda(3\mu + v) + \mu(\mu + 2v) + 16(m + n) λ(3μ+v)+μ(μ+2v)+16(m+n),其中 μ = min ( m , n ) , ν = max ( m , n ) \mu = \min(m, n), ν = \max(m, n) μ=min(m,n),ν=max(m,n)。普通的扩展位数然后相乘的开销是: 3 λ ( μ + v ) + ( m + n ) 2 + 15 ( m + n ) 3\lambda(\mu+v)+(m+n)^2+15(m + n) 3λ(μ+v)+(m+n)2+15(m+n),大约是SIRNN的1.5x。
有符号乘法
布尔分享转换为算术分享:
⟨ x ⟩ A = ⟨ x ⟩ 0 B + ⟨ x ⟩ 1 B − 2 ⟨ x ⟩ 0 B ⟨ x ⟩ 1 B \langle x\rangle^A=\langle x\rangle_0^B+\langle x\rangle_1^B-2\langle x\rangle_0^B\langle x\rangle_1^B ⟨x⟩A=⟨x⟩0B+⟨x⟩1B−2⟨x⟩0B⟨x⟩1B
基于前面无符号数和有符号数的关系,可以得到:无符号数 x ′ = x + 2 m − 1 m o d M , y ′ = y + 2 n − 1 m o d N x'=x+2^{m-1}\mod M, y'=y+2^{n-1}\mod N x′=x+2m−1modM,y′=y+2n−1modN。由秘密共享, x ′ = x 0 ′ + x 1 ′ m o d M , y ′ = y 0 ′ + y 1 ′ m o d N x'=x_0'+x_1' \mod M, y'=y_0'+y_1' \mod N x′=x0′+x1′modM,y′=y0′+y1′modN。有符号数 i n t ( x ) = x ′ − 2 m − 1 , i n t ( y ) = y ′ − 2 n − 1 int(x)=x'-2^{m-1}, int(y)=y'-2^{n-1} int(x)=x′−2m−1,int(y)=y′−2n−1。因此,在 Z \mathbb Z Z环上:
x ′ y ′ x'y' x′y′是无符号数的乘法,可以用algorithm 3计算, 2 m − 1 y b ′ , 2 n − 1 x b ′ 2^{m-1}y_b', 2^{n-1}x_b' 2m−1yb′,2n−1xb′也都可以在本地计算出来。难点是wrap项应该如何计算。
2 m + n − 1 w x ′ = 2 l − 1 w x ′ = 2 l − 1 ( ⟨ w x ′ ⟩ 0 B + ⟨ w x ′ ⟩ 1 B − 2 ⟨ w x ′ ⟩ 0 B ⟨ w x ′ ⟩ 1 B ) 2^{m+n-1}w_{x'}=2^{l-1}w_{x'}=2^{l-1}(\langle w_{x'}\rangle_0^B+\langle w_{x'}\rangle_1^B-2\langle w_{x'}\rangle_0^B\langle w_{x'}\rangle_1^B) 2m+n−1wx′=2l−1wx′=2l−1(⟨wx′⟩0B+⟨wx′⟩1B−2⟨wx′⟩0B⟨wx′⟩1B)
其中, 2 ⟨ w x ′ ⟩ 0 B ⟨ w x ′ ⟩ 1 B 2\langle w_{x'}\rangle_0^B\langle w_{x'}\rangle_1^B 2⟨wx′⟩0B⟨wx′⟩1B与 2 l − 1 2^{l-1} 2l−1相乘再 m o d L \mod L modL后会被消除掉,所以无需计算。因此,上式变为:
2 m + n − 1 w x ′ = 2 l − 1 w x ′ = 2 l − 1 ( ⟨ w x ′ ⟩ 0 B + ⟨ w x ′ ⟩ 1 B ) 2^{m+n-1}w_{x'}=2^{l-1}w_{x'}=2^{l-1}(\langle w_{x'}\rangle_0^B+\langle w_{x'}\rangle_1^B) 2m+n−1wx′=2l−1wx′=2l−1(⟨wx′⟩0B+⟨wx′⟩1B)
有符号的乘法相比无符号的乘法,也没有额外的开销。
矩阵乘法和卷积
矩阵乘法和卷积是很常见的(实际上可以展开为普通乘法做elment-wise乘和加),两个矩阵 A ∈ Z M d 1 × d 2 , A ∈ Z N d 2 × d 3 A\in \mathbb Z_M^{d1\times d2}, A\in \mathbb Z_N^{d2\times d3} A∈ZMd1×d2,A∈ZNd2×d3,输出矩阵乘法结果 A ∈ Z L d 1 × d 3 A\in \mathbb Z_L^{d1\times d3} A∈ZLd1×d3,其中 l = m + n l=m+n l=m+n。做矩阵乘法需要 d 2 d_2 d2次乘以及 d 2 − 1 d_2-1 d2−1次加。
这个时候可能出现的问题是:加法导致溢出。一种解决方式是将element-wise乘后的结果扩展 e = ⌈ log d 2 ⌉ e=\lceil \log d_2\rceil e=⌈logd2⌉-bits后,再做加法。但是,这样扩展开销很大,需要扩展 d 1 d 2 d 3 d_1d_2d_3 d1d2d3次。
于是本文这样做:考虑到前面算交叉项(CrossTerm)时,通信round数取决于较小的bitwidth,所以本文将bitwidth较大的一项拿去扩展 e e e-bits,在不增加开销的情况下,扩大了环。
通信开销大致为 λ ( 3 d 1 d 2 ( m + 2 ) + d 2 d 3 ( n + 2 ) ) + d 1 d 2 d 3 ( ( 2 m + 4 ) ( n + e ) + m 2 + 5 m ) \lambda(3d_1d_2(m+2)+d_2d_3(n+2))+d_1d_2d_3((2m+4)(n+e)+m^2+5m) λ(3d1d2(m+2)+d2d3(n+2))+d1d2d3((2m+4)(n+e)+m2+5m) bits。
算法如下:
乘且截断
首先调用有符号乘法,然后截断。输入 ⟨ x ⟩ m , ⟨ y ⟩ n \langle x\rangle^m, \langle y\rangle^n ⟨x⟩m,⟨y⟩n,输出 ⟨ z ′ ⟩ l − s \langle z'\rangle^{l-s} ⟨z′⟩l−s。 z = i n t ( x ) ∗ l i n t ( y ) , z ′ = T R ( z , s ) z=int(x)*_l int(y), z'=TR(z, s) z=int(x)∗lint(y),z′=TR(z,s)。其中 l = m + n l=m+n l=m+n。