在之前的(1)——(4)中,一步步地实现并优化了RSA及其大数运算库,之前说,RSA的效率取决于除法,是因为计算模幂,需要使用取模,取模使用除法,最后归根结底到了除法上。
然而,有另一种思路,就是在计算模幂时,使用蒙哥马利算法。蒙哥马利算法能够将取模时的除法,转化为相对廉价的乘加和移位操作。
话说,网上的相关中文资料说得简明的可真不容易找,绿盟和FreeBuf的文档说的挺清晰的,之前还有一篇CSDN的也非常不错,搜索引擎搜到的第一篇应该就是。如果说最简明易懂的,可能还是Wikipedia上的蒙哥马利算法,不过是英文的。
蒙哥马利方法进行模幂,需要先实现蒙哥马利约简和蒙哥马利模乘,配合反复平方法,将原先的取模替换为使用蒙哥马利方法,实现模幂。
设对N取模,R是一个恰好比N大的数,且R是2的m次幂,R=2m。
R·R’=1(mod N),(-N)·N’=1(mod R),-N=R-N(mod R)
考虑求
z=x·y(mod N) ①
利用蒙哥马利约减来求解,蒙哥马利约简简称REDC。REDC(T)约简结果为T·R’(mod N)。只要将z按照某种“形式”输入REDC(),就有望得到最终结果。
这种“形式”,称为蒙哥马利形式。在蒙哥马利形式中:
x表示为x·R(mod N)
y表示为y·R(mod N)
把z也表示为蒙哥马利形式,可以是z·R=x·y·R=REDC(x·R·R)·REDC(y·R·R)·R(mod N),把这个形式输入REDC(z·R)进行蒙哥马利约减,得到的结果转换为正常形式,就是z(mod N)。
在这个过程中,需要防止y·R·R这样的连续乘法发生溢出,边乘边约简,计算R·R(mod N)再相乘。
计算z=x·y(mod N)
x’=REDC(x·(R·R mod N))
y’=REDC(y·(R·R mod N))
z’=REDC(x’·y’)
z=REDC(z’)
return(z)
在蒙哥马利模乘中用到了蒙哥马利约简蒙哥马利约简REDC(T)的结果为T·R’(mod N),如果输入REDC(T·R),就能够得到T(mod N)的结果。
m = ( ( T % R ) * N’ ) % R;
t = ( T + m * N ) / R;
if ( t >= N )
return( t - N );
else
return( t );
仔细观察这个算法,原本是要对N取模的,在这个算法里面没有出现除以N的操作,而转化为了对R进行取模和除R的操作。因为R是精心挑选的,是2的幂次方,对R进行取模和除法操作,都可以使用移位或者直接选取的方法,相对于除法是廉价的。在这个算法中,耗时的操作是加减法和乘法。
仔细观察,里面涉及到了R、N、N‘、R·R(mod N),其中N’和R·R(mod N)是可以预先计算的,在一次模幂中,只需要一开始计算一次,运行时间取决于乘法。
//蒙哥马利约简,result=DT*R' (mod N )
void mont_redc(BN DT, BN N, BN Np, BN R,BN & result)
{
BN temp1 = { 0 };
BND temp2 = { 0 };
BN m = { 0 }, t = { 0 };
unsigned int R_bits = getbits_b(R);//模除R就是保留这么多位
int Bits = (R_bits + 31) / 32;//换算成大的位数
int remain = R_bits - 32 * (Bits-1);//最高的“位”的第[remain-1]位为1,原封不动保留[remain-2]..[0]位就可以
//if(remain>1) 保留R[0];否则R[0]-1,保留所有R[0]-1"位"
//对于R[Bits],保留R[Bits][remain-2]..R[Bits][0]
//if (Bits>1) 保留R[Bits-1]..R[1]
//进行temp=T%R
if (getbits_b(DT) >= R_bits)//只有多的时候才要取余,R是0100,T是1101也要取余,不如构造一个,100异或101就可以了,但是不能有最高位
{
for (unsigned int i = 1; i <R[0]; i++)//全保留,都是0嘛
{
temp1[i] = DT[i];
}
if (remain >1)//最高位至少是10,remain=2,可以保留1位
{
temp1[R[0]] = (DT[R[0]] & (uint32_t)(1U<<(remain-1))-1U );
temp1[0] = R[0];
}
else {
temp1[0] = R[0]-1;
}
}
else {
cpy_b(temp1, DT);
}
//(T%R)*N',进行乘N'的操作,可能会超过1024位
mul(temp1, Np, temp2);
//m=temp1%R
if (getbits_b(temp2) >= R_bits)//只有多的时候才要取余,R是0100,T是1101也要取余,不如构造一个,100异或101就可以了,但是不能有最高位
{
for (unsigned int i = 1; i < R[0]; i++)//全保留,都是0嘛
{
m[i] = temp2[i];
}
if (remain > 1)//最高位至少是10,remain=2,可以保留1位
{
m[R[0]] = (temp2[R[0]] & (uint32_t)(1U << (remain - 1)) - 1U);
m[0] = R[0];
}
else {
m[0] = R[0] - 1;
}
}
else {
cpy_b(m, temp2);
}
memset(temp2, 0, sizeof(temp2));
mul(m, N, temp2);
add(DT, temp2, temp2);
//cout << "t*R= " << bn2str(temp2) <
for (int i = getbits_b(R)-1;i>0;i--)
{
shr_b(temp2);
}
//cout << "t= " << bn2str(temp2) << endl << endl;
cpy_b(t, temp2);
if(cmp_b(t, N) >= 0)
{
sub(t, N, t);//t=t-N
}
cpy_b(result,t);
}
void mont_modmul(BN x, BN y, BN R,BN N,BN Np,BN RRN, BN & result)//用于模幂的模乘
{
BND temp1 = { 0 }, temp2 = { 0 }, temp3 = { 0 };//可能会超过1024位!肯定不能只有1024位
BN Xp = { 0 }, Yp = { 0 }, Zp = { 0 };
mul(x, RRN, temp1);
mul(y, RRN, temp2);
mont_redc(temp1, N, Np, R, Xp);
mont_redc(temp2, N, Np, R, Yp);
mul(Xp, Yp, temp3);
mont_redc(temp3, N, Np, R, Zp);
mont_redc(Zp, N, Np, R, result);
}
void mont_modmul(BN x, BN y, BN N ,BN & result)
{
BND temp1 = { 0 }, temp2 = { 0 }, temp3 = { 0 };//可能会超过1024位!肯定不能只有1024位
BN Xp = { 0 }, Yp = { 0 }, Zp = { 0 };
int m = getbits_b(N);//模幂里基本上不会出现要模偶数,尤其加解密不太可能出现,认为是+1,R=2^m一定大于n
BN RRN = { 0 };//R*R (mod n)
BN R = { 0 }, Rp = { 0 }, Np = { 0 };
int Bits = (m + 31) / 32;//换算成大的位数
int remain = m - 32 * (Bits - 1);//最高的“位”的第[remain-1]位为1,原封不动保留[remain-2]..[0]位就可以
R[0] = Bits;
R[Bits] = (uint32_t)(1U <<remain);
//inv_b(R, N, Rp);
sub(R, N, temp1);//temp1=-N
inv_b(temp1, R, Np);
modmul(R, R, N, RRN);
//cout << "R= " << bn2str(R) << endl;
//cout << "N= " << bn2str(N) << endl;
//cout << "temp1= " << bn2str(temp1)<< endl;
//cout << "Np= " << bn2str(Np) << endl<
mul(x, RRN, temp1);
mul(y, RRN, temp2);
mont_redc(temp1, N, Np, R, Xp);
//cout << "Xp= " << bn2str(Xp) << endl << endl;
mont_redc(temp2, N, Np, R, Yp);
//cout << "Yp= " << bn2str(Yp) << endl << endl;
mul(Xp, Yp, temp3);
mont_redc(temp3, N, Np, R, Zp);
//cout << "Zp= " << bn2str(Zp) << endl << endl;
mont_redc(Zp, N, Np, R, result);
}
void mont_modexp(BN a, BN b, BN N, BN & result)//蒙哥马利模幂 a^b mod N
{
int m = getbits_b(N);//模幂里基本上不会出现要模偶数,尤其加解密不太可能出现,认为是+1,R=2^m一定大于n
BN RRN = { 0 };//R*R (mod n)
BN R = { 0 }, Rp = { 0 }, Np = { 0 };
BN a_t = { 1,1 }, b_t;//a=1;n做了二进制展开
BN temp1 = { 0 }, temp2 = { 0 };//计算作为result有个清零操作
BN Xp = { 0 }, Yp = { 0 };
int Bits = (m + 31) / 32;//换算成大的位数
int remain = m - 32 * (Bits - 1);//最高的“位”的第[remain-1]位为1,原封不动保留[remain-2]..[0]位就可以
R[0] = Bits;
R[Bits] = (uint32_t)(1U << (remain - 1));
sub(R, N, temp1);//temp1=-N
inv_b(temp1, R, Np);
modmul(R, R, N, RRN);
//b^n (mod m) --->a^b mod N
memset(result, 0, sizeof(result));
cpy_b(b_t, a);//b_t=b,初始化
uint32_t *nptr, *mnptr;
nptr = LSDPTR_B(b);
mnptr = MSDPTR_B(b);//!!!!!!!
char binform[33];//每个32bit的uint32转化为二进制即可,一次次取出来
int i = 0;
while (nptr <= mnptr)//没越界就都来做
{
memset(binform, 0, sizeof(binform));
_ultoa(*nptr, binform, 2);
i = strlen(binform) - 1;//到达最后一位
for (int j = 31; j >= 0; j--)//开始模平方
{
if (i >= 0)//正事儿,否则只是平方b取模
{
if (binform[i] == '1')
{
mont_modmul(a_t,b_t, R, N, Np, RRN, a_t);
}
i--;
}
mont_modmul(b_t, b_t, R, N, Np, RRN, b_t);
}
nptr++;
}
cpy_b(result, a_t);
RMLDZRS_B(result);
}
运行速度比使用除法的慢不少,应该是代码写得比较low的原因,怀疑移位函数shr_b()写得不够巧妙,乘法函数应该是没有问题,乘法使用快速乘法相比原先的乘法,速度上的贡献并不是特别大,而且数位只有1024位体现不出来。
鉴于速度暂时比原先的慢(或者真的是慢),这个版本的暂不同步到github。希望大神们能指出这其中的不足之处,非常感谢。