#include
#include
#include
#include
#include //大整数
#include //浮点数
#include //向量
#pragma comment(lib, "NTL")
using namespace std;
using namespace NTL;
typedef long long int64;
typedef unsigned int uint;
typedef unsigned long long ulong;
/*
Barrett Reduction,大数取模
计算:a mod n
令:s = floor(a/n)
则:a mod n = a - s*n
用ZZ计算RR:
1/n = m/2^k = m>>k
这里,m = floor(2^k/n)
错误大小为:
e = 1/n - m/2^k
结果正确,要保证:a*e < 1
k越大,e越小,支持的a的范围越广
*/
//#######################################################################
// 计时器
#define Clock() std::chrono::high_resolution_clock::now() //高分辨率时钟,auto t = Clock();
#define Time(t_start,t_end) std::chrono::duration(t_end - t_start).count() //double, ms
static std::chrono::time_point TM_start, TM_end;
#define Timer(code) TM_start = Clock(); code; TM_end = Clock(); std::cout << Time(TM_start,TM_end) << "ms";
//产生随机数
void Random(char* arr, uint size)
{
std::random_device rd; //真随机数
std::mt19937 gen(rd()); //PRG
std::uniform_int_distribution dist(0, 256); //均匀分布
for (uint i = 0; i < size; i++)
arr[i] = dist(gen);
}
//将ZZ转换为可打印字符串
string ZZ_to_bits(ZZ num)
{
long len = NumBits(num);
string str;
str.resize(len);
for (long i = len - 1; i >= 0; i--)
{
str[i] = to_int(num % 2) + '0';
num >>= 1;
}
return "0b" + str;
}
inline void print(ZZ& num)
{
cout << num << " - " << ZZ_to_bits(num) << endl;
}
//#######################################################################
class Barrett
{
private:
ZZ n;
ZZ m;
uint k;
uint acc;
public:
//初始化,acc控制精度,acc越大,e越小,支持的计算范围越广
void init(ZZ& n, uint acc);
//支持的a的范围大约是:0 ~ 2^{log(n)+acc}
void Mod(ZZ& a, ZZ& res);
//展示参数
void show();
};
void Barrett::init(ZZ& n, uint acc)
{
this->n = n;
this->acc = acc;
this->k = NumBits(n) + acc; //k的大小,必须比n的比特数要大
m = power2_ZZ(k)/n; //m = floor(2^k / n)
}
void Barrett::Mod(ZZ& a, ZZ& res)
{
res = (m*a) >> k; //q = floor(a/n) = a * m/2^k
res = a - res * n; //a - q*n
if (res >= n)
res -= n;
}
void Barrett::show()
{
printf(" * Parameter\n");
printf(" n: "); cout << n << endl;
printf(" m: "); cout << m << endl;
printf(" k: %d\n", k);
printf(" acc:%d\n\n", acc);
}
int main()
{
Barrett B;
ZZ n, a, res;
uint n_size = 1024, a_size = 2048;
RandomBits(n, n_size);
RandomBits(a, a_size);
//初始化,预计算
uint acc = 1024;
B.init(n, acc);
B.show();
//取模运算
B.Mod(a, res);
cout << "a = " << a << endl;
cout << "n = " << n << endl;
cout << "res = " << res << endl;
printf("\n\n测试速度:");
printf("\n比特数:%d", a_size);
int loop = 10;
printf("\nBarrett Mod %d次 - ", loop);
Timer(
for (int i = 0; i < loop; i++)
B.Mod(a, res);
);
printf("\nNTL Mod %d次 - ", loop);
Timer(
for (int i = 0; i < loop; i++)
res = a % n;
);
printf("\n\n");
return 0;
}
Montgomery multiplication,一种计算大数模乘的算法。
映射到蒙哥马利域上计算模乘,使用位运算 (移位、按位与) 代替取模运算。
当乘数达到16384比特,运算速度是 o p e n s s l . m u l ( ) openssl.mul() openssl.mul()的1.5倍。
#include
#include
#include
#include
#include //大整数
#include //向量
#pragma comment(lib, "NTL")
using namespace std;
using namespace NTL;
typedef long long int64;
typedef unsigned int uint;
typedef unsigned long long ulong;
/*
Montgomery Multiplication,蒙哥马利模乘算法
计算:a*b mod N
寻找:
gcd(R,N) = 1
R_inv * R = 1 mod N
做双射:
a_bar <- a*R mod N
b_bar <- b*R mod N
在Montgomery域上,计算:
a_bar * b_bar * R_inv = ab_bar mod N
若R是2^s形式,那么:
a mod R,就是和 R-1 按位与
a*R,就是a右移s位
a*R_inv,若有整除关系:2^s | a,就是a左移s位
构造整除关系:
T * R_inv mod N = (T + m*N) * R_inv
推导:
T+m*N = 0 mod R
m*N = -T mod R
m = -T * N_inv mod R
转换结果:
ab = ab_bar * R_inv mod N
*/
//#######################################################################
// 计时器
#define Clock() std::chrono::high_resolution_clock::now() //高分辨率时钟,auto t = Clock();
#define Time(t_start,t_end) std::chrono::duration(t_end - t_start).count() //double, ms
static std::chrono::time_point TM_start, TM_end;
#define Timer(code) TM_start = Clock(); code; TM_end = Clock(); std::cout << Time(TM_start,TM_end) << "ms";
//产生随机数
void Random(char* arr, uint size)
{
std::random_device rd; //真随机数
std::mt19937 gen(rd()); //PRG
std::uniform_int_distribution dist(0, 256); //均匀分布
for (uint i = 0; i < size; i++)
arr[i] = dist(gen);
}
//将ZZ转换为可打印字符串
string ZZ_to_bits(ZZ num)
{
long len = NumBits(num);
string str;
str.resize(len);
for (long i = len - 1; i >= 0; i--)
{
str[i] = to_int(num % 2) + '0';
num >>= 1;
}
return "0b" + str;
}
inline void print(ZZ& num)
{
cout << num << " - " << ZZ_to_bits(num) << endl;
}
//#######################################################################
//蒙哥马利模乘算法
class Montgomery
{
private:
uint R_pow;//幂次
ZZ R;//格式为2的幂次
ZZ N;
ZZ N_inv;//N_inv * N = -1 mod R
ZZ Z;//零
ZZ L;//R的掩码
ZZ m;
public:
//初始化模型
void init(ZZ& N);
//映射到 Montgomery域 上
void Map(ZZ& src, ZZ& dst);
void Map(Vec& src, Vec& dst);
//从 Montgomery域 映射回去
void InvMap(ZZ& src, ZZ& dst);
void InvMap(Vec& src, Vec& dst);
//乘法
void Mul(ZZ& a, ZZ& b, ZZ& ab);
//展示参数
void show();
};
void Montgomery::init(ZZ& N)
{
if ((N & ZZ(1)) != 1) //N应当是奇数,因为R取的2的幂次
return;
this->N = N;
//寻找与N互素的R,R = 0b1000...000
R_pow = NumBits(N);
R = 1;
R <<= R_pow; //R = 2^R_pow
Z = 0; //零
L = R - 1; //掩码
ZZ d, s, t;
XGCD(d, s, t, N, R); //d=1, s*N = 1 mod R
N_inv = R - s; //N_inv * N = -1 mod R
}
// a*R mod N
void Montgomery::Map(ZZ& src, ZZ& dst)
{
dst = src << R_pow; //a*R
dst %= N;
}
// a*R mod N
void Montgomery::Map(Vec& src, Vec& dst)
{
uint len = src.length();
dst.SetLength(len);
for (uint i = 0; i < len; i++)
{
dst[i] = src[i] << R_pow; //a*R
dst[i] %= N;
}
}
// a_bar * R_inv mod N
void Montgomery::InvMap(ZZ& src, ZZ& dst)
{
dst = src; //T
m = dst * N_inv; //m = T * (-N_inv)
m &= L; //m mod R
dst += m * N; //T + m*N
dst >>= R_pow; //(T + m*N) * R_inv
if (dst > N)
dst -= N;
}
// a_bar * R_inv mod N
void Montgomery::InvMap(Vec& src, Vec& dst)
{
uint len = src.length();
dst.SetLength(len);
for (uint i = 0; i < len; i++)
{
dst[i] = src[i]; //T
m = dst[i] * N_inv; //m = T * (-N_inv)
m &= L; //m mod R
dst[i] += m * N; //T + m*N
dst[i] >>= R_pow; //(T + m*N) * R_inv
if (dst[i] > N)
dst[i] -= N;
}
}
//Montgomery乘法
void Montgomery::Mul(ZZ& a, ZZ& b, ZZ& ab)
{
ab = a*b; //T = a*R * b*R
m = ab * N_inv; //m = T * (-N_inv)
m &= L; //m mod R
ab += m * N; //(T + m*N)
ab >>= R_pow; //(T + m*N) * R_inv
if (ab > N)
ab -= N;
}
//展示参数
void Montgomery::show()
{
printf(" * Parameter\n");
printf(" N: "); print(N);
printf(" N_inv: "); print(N_inv);
printf(" R: "); print(R);
printf(" L: "); print(L);
printf(" Z: "); print(Z);
printf(" R_pow: %d\n\n", R_pow);
}
//#######################################################################
int main()
{
Montgomery m;
ZZ a, a_bar;
ZZ b, b_bar;
ZZ N;
ZZ T, T_bar;
//产生随机数
uint len = 2048;
RandomLen(a, len);
RandomLen(b, len);
RandomLen(N, len);
if ((N & ZZ(1)) != 1) //让N是奇数
N += 1;
m.init(N);
m.show();
m.Map(a, a_bar);
m.Map(b, b_bar);
m.Mul(a_bar, b_bar, T_bar);
m.InvMap(T_bar, T);
cout << "a = " << a << endl;
cout << "b = " << b << endl;
cout << "N = " << N << endl;
cout << "T = " << T << endl;
printf("\n\n测试速度:");
printf("\n比特数:%d", len);
int loop = 10000;
printf("\nMontgomery Mul %d次 - ", loop);
Timer(
for (int i = 0; i < loop; i++)
m.Mul(a_bar, b_bar, T_bar);
);
printf("\nNTL Mul %d次 - ", loop);
Timer(
for (int i = 0; i < loop; i++)
T = MulMod(a, b, N);
);
printf("\n\n");
return 0;
}