大数运算:Barrett And Montgomery

Barrett reduction

  • Barrett Reduction,一种计算大数取模的算法。
  • 使用ZZ上的移位来代替RR上的浮点除法。
  • 当模数达到65536比特,运算速度是 o p e n s s l . m o d ( ) openssl.mod() openssl.mod()的2倍。
#include 
#include 
#include 
#include 

#include  //大整数
#include  //浮点数
#include  //向量
#pragma comment(lib, "NTL")

using namespace std;
using namespace NTL;

typedef long long int64;
typedef unsigned int uint;
typedef unsigned long long ulong;

/*
	Barrett Reduction,大数取模

	计算:a mod n
	令:s = floor(a/n)
	则:a mod n = a - s*n

	用ZZ计算RR:
	1/n = m/2^k = m>>k
	这里,m = floor(2^k/n)

	错误大小为:
	e = 1/n - m/2^k

	结果正确,要保证:a*e < 1
	k越大,e越小,支持的a的范围越广
*/


//#######################################################################

// 计时器
#define Clock() std::chrono::high_resolution_clock::now() //高分辨率时钟,auto t = Clock();
#define Time(t_start,t_end) std::chrono::duration(t_end - t_start).count() //double, ms
static std::chrono::time_point TM_start, TM_end;
#define Timer(code) TM_start = Clock(); code; TM_end = Clock(); std::cout << Time(TM_start,TM_end) << "ms";


//产生随机数
void Random(char* arr, uint size)
{
	std::random_device rd; //真随机数
	std::mt19937 gen(rd()); //PRG
	std::uniform_int_distribution dist(0, 256); //均匀分布

	for (uint i = 0; i < size; i++)
		arr[i] = dist(gen);
}

//将ZZ转换为可打印字符串
string ZZ_to_bits(ZZ num)
{
	long len = NumBits(num);
	string str;
	str.resize(len);
	for (long i = len - 1; i >= 0; i--)
	{
		str[i] = to_int(num % 2) + '0';
		num >>= 1;
	}
	return "0b" + str;
}

inline void print(ZZ& num)
{
	cout << num << " - " << ZZ_to_bits(num) << endl;
}

//#######################################################################

class Barrett
{
private:

	ZZ n;
	ZZ m;
	uint k;
	uint acc;


public:

	//初始化,acc控制精度,acc越大,e越小,支持的计算范围越广
	void init(ZZ& n, uint acc);

	//支持的a的范围大约是:0 ~ 2^{log(n)+acc}
	void Mod(ZZ& a, ZZ& res);

	//展示参数
	void show();

};


void Barrett::init(ZZ& n, uint acc)
{
	this->n = n;
	this->acc = acc;
	this->k = NumBits(n) + acc; //k的大小,必须比n的比特数要大

	m = power2_ZZ(k)/n; //m = floor(2^k / n)
}

void Barrett::Mod(ZZ& a, ZZ& res)
{
	res = (m*a) >> k; //q = floor(a/n) = a * m/2^k
	res = a - res * n; //a - q*n

	if (res >= n)
		res -= n;
}

void Barrett::show()
{
	printf(" * Parameter\n");
	printf("   n:	"); cout << n << endl;
	printf("   m:	"); cout << m << endl;
	printf("   k:	%d\n", k);
	printf("   acc:%d\n\n", acc);
}




int main()
{
	Barrett B;

	ZZ n, a, res;
	uint n_size = 1024, a_size = 2048;
	RandomBits(n, n_size);
	RandomBits(a, a_size);

	//初始化,预计算
	uint acc = 1024;
	B.init(n, acc);
	B.show();

	//取模运算
	B.Mod(a, res);

	cout << "a = " << a << endl;
	cout << "n = " << n << endl;
	cout << "res = " << res << endl;


	printf("\n\n测试速度:");
	printf("\n比特数:%d", a_size);
	int loop = 10;

	printf("\nBarrett Mod %d次 - ", loop);
	Timer(
		for (int i = 0; i < loop; i++)
			B.Mod(a, res);
	);

	printf("\nNTL Mod %d次 - ", loop);
	Timer(
		for (int i = 0; i < loop; i++)
			res = a % n;
	);

	printf("\n\n");


	return 0;
}

Montgomery multiplication

  • Montgomery multiplication,一种计算大数模乘的算法。

  • 映射到蒙哥马利域上计算模乘,使用位运算 (移位、按位与) 代替取模运算。

  • 当乘数达到16384比特,运算速度是 o p e n s s l . m u l ( ) openssl.mul() openssl.mul()的1.5倍。

#include 
#include 
#include 
#include 

#include  //大整数
#include  //向量
#pragma comment(lib, "NTL")

using namespace std;
using namespace NTL;

typedef long long int64;
typedef unsigned int uint;
typedef unsigned long long ulong;

/*
	Montgomery Multiplication,蒙哥马利模乘算法
	计算:a*b mod N

	寻找:
	gcd(R,N) = 1
	R_inv * R = 1 mod N

	做双射:
	a_bar <- a*R mod N
	b_bar <- b*R mod N

	在Montgomery域上,计算:
	a_bar * b_bar * R_inv = ab_bar mod N

	若R是2^s形式,那么:
	a mod R,就是和 R-1 按位与
	a*R,就是a右移s位
	a*R_inv,若有整除关系:2^s | a,就是a左移s位

	构造整除关系:
	T * R_inv mod N = (T + m*N) * R_inv

	推导:
	T+m*N = 0 mod R
	m*N = -T mod R
	m = -T * N_inv mod R

	转换结果:
	ab = ab_bar * R_inv mod N
*/


//#######################################################################

// 计时器
#define Clock() std::chrono::high_resolution_clock::now() //高分辨率时钟,auto t = Clock();
#define Time(t_start,t_end) std::chrono::duration(t_end - t_start).count() //double, ms
static std::chrono::time_point TM_start, TM_end;
#define Timer(code) TM_start = Clock(); code; TM_end = Clock(); std::cout << Time(TM_start,TM_end) << "ms";


//产生随机数
void Random(char* arr, uint size)
{
	std::random_device rd; //真随机数
	std::mt19937 gen(rd()); //PRG
	std::uniform_int_distribution dist(0, 256); //均匀分布

	for (uint i = 0; i < size; i++)
		arr[i] = dist(gen);
}


//将ZZ转换为可打印字符串
string ZZ_to_bits(ZZ num)
{
	long len = NumBits(num);
	string str;
	str.resize(len);
	for (long i = len - 1; i >= 0; i--)
	{
		str[i] = to_int(num % 2) + '0';
		num >>= 1;
	}
	return "0b" + str;
}

inline void print(ZZ& num)
{
	cout << num << " - " << ZZ_to_bits(num) << endl;
}


//#######################################################################

//蒙哥马利模乘算法
class Montgomery
{
private:

	uint R_pow;//幂次
	ZZ R;//格式为2的幂次
	ZZ N;
	ZZ N_inv;//N_inv * N = -1 mod R
	ZZ Z;//零
	ZZ L;//R的掩码
	ZZ m;

public:

	//初始化模型
	void init(ZZ& N);

	//映射到 Montgomery域 上
	void Map(ZZ& src, ZZ& dst);
	void Map(Vec& src, Vec& dst);

	//从 Montgomery域 映射回去
	void InvMap(ZZ& src, ZZ& dst);
	void InvMap(Vec& src, Vec& dst);

	//乘法
	void Mul(ZZ& a, ZZ& b, ZZ& ab);

	//展示参数
	void show();

};


void Montgomery::init(ZZ& N)
{
	if ((N & ZZ(1)) != 1) //N应当是奇数,因为R取的2的幂次
		return;

	this->N = N;

	//寻找与N互素的R,R = 0b1000...000
	R_pow = NumBits(N);

	R = 1;
	R <<= R_pow; //R = 2^R_pow
	Z = 0; //零
	L = R - 1; //掩码

	ZZ d, s, t;
	XGCD(d, s, t, N, R); //d=1, s*N = 1 mod R

	N_inv = R - s; //N_inv * N = -1 mod R
}

// a*R mod N
void Montgomery::Map(ZZ& src, ZZ& dst)
{
	dst = src << R_pow; //a*R
	dst %= N;
}

// a*R mod N
void Montgomery::Map(Vec& src, Vec& dst)
{
	uint len = src.length();
	dst.SetLength(len);

	for (uint i = 0; i < len; i++)
	{
		dst[i] = src[i] << R_pow; //a*R
		dst[i] %= N;
	}
}

// a_bar * R_inv mod N
void Montgomery::InvMap(ZZ& src, ZZ& dst)
{
	dst = src; //T

	m = dst * N_inv; //m = T * (-N_inv)
	m &= L; //m mod R

	dst += m * N; //T + m*N

	dst >>= R_pow; //(T + m*N) * R_inv
	if (dst > N)
		dst -= N;
}

// a_bar * R_inv mod N
void Montgomery::InvMap(Vec& src, Vec& dst)
{
	uint len = src.length();
	dst.SetLength(len);

	for (uint i = 0; i < len; i++)
	{
		dst[i] = src[i]; //T

		m = dst[i] * N_inv; //m = T * (-N_inv)
		m &= L; //m mod R

		dst[i] += m * N; //T + m*N

		dst[i] >>= R_pow; //(T + m*N) * R_inv
		if (dst[i] > N)
			dst[i] -= N;
	}
}


//Montgomery乘法
void Montgomery::Mul(ZZ& a, ZZ& b, ZZ& ab)
{
	ab = a*b; //T = a*R * b*R

	m = ab * N_inv; //m = T * (-N_inv)
	m &= L; //m mod R

	ab += m * N; //(T + m*N)

	ab >>= R_pow; //(T + m*N) * R_inv
	if (ab > N)
		ab -= N;
}

//展示参数
void Montgomery::show()
{
	printf(" * Parameter\n");
	printf("   N:		"); print(N);
	printf("   N_inv:	"); print(N_inv);
	printf("   R:		"); print(R);
	printf("   L:		"); print(L);
	printf("   Z:		"); print(Z);
	printf("   R_pow:	%d\n\n", R_pow);
}


//#######################################################################

int main()
{
	Montgomery m;

	ZZ a, a_bar;
	ZZ b, b_bar;
	ZZ N;
	ZZ T, T_bar;

	//产生随机数
	uint len = 2048;
	RandomLen(a, len); 
	RandomLen(b, len);
	RandomLen(N, len);
	if ((N & ZZ(1)) != 1) //让N是奇数
		N += 1;

	m.init(N);
	m.show();

	m.Map(a, a_bar);
	m.Map(b, b_bar);

	m.Mul(a_bar, b_bar, T_bar);

	m.InvMap(T_bar, T);

	cout << "a = " << a << endl;
	cout << "b = " << b << endl;
	cout << "N = " << N << endl;
	cout << "T = " << T << endl;

	printf("\n\n测试速度:");
	printf("\n比特数:%d", len);
	int loop = 10000;
	
	printf("\nMontgomery Mul %d次 - ", loop);
	Timer(
		for (int i = 0; i < loop; i++)
			m.Mul(a_bar, b_bar, T_bar);
	);

	printf("\nNTL Mul %d次 - ", loop);
	Timer(
		for (int i = 0; i < loop; i++)
			T = MulMod(a, b, N);
	);

	printf("\n\n");

	return 0;
}

你可能感兴趣的:(计算机,代码,算法,数学,抽象代数,c++)