目录
题目
一.常规做法
二.高精度大数实现——O(N^2)
三.自定义Integer类大整数实现——O(NlogNlogN)
四.Integer类代码
原题:求10000以内阶乘
想法:
最近自己刚好在尝试实现c++的大整数类,刚好拿来实验一下大整数的计算
大整数:实现了加,减,乘,除和一些简单的赋值操作
后面介绍一些注意点!
实现:
最常规做法就是,使用迭代的方法,用1*2*3*4...n。
代码:
#include
using namespace std;
int main()
{
int n;
cin >> n;
long long sum = 1;
for (int i = 1; i <= n; i++)
{
sum *= i;
}
cout << sum << endl;
}
输入:
10
输出:
3628800
但是这里出现问题了,n大于一定数量会爆longlong。所以需要我们实现一个大数操作。
我们这里模仿了正常的乘法的实现过程,使用vector来实现大数的乘法。
#pragma GCC optimize(2)
#include
#include
#include
using namespace std;
vector mul(vector &A, int b)
{
vector C;
int t = 0;
for (int i = 0; i < A.size() || t; i++)
{
if (i < A.size())
t += A[i] * b;
C.push_back(t % 10);
t /= 10;
}
while (C.size() > 1 && C.back() == 0)
C.pop_back();
return C;
}
void fac(int n)
{
vector sum;
sum.push_back(1);
for (int i = 1; i <= n; i++)
{
sum = mul(sum, i);
}
// 输出结果
for (int i = sum.size() - 1; i >= 0; i--)
{
cout << sum[i];
}
cout << endl;
}
int main()
{
clock_t start, end; // 定义clock_t变量
int n;
cin >> n;
start = clock(); // 开始时间
fac(n);
end = clock(); // 结束时间
cout << "running time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
}
输入:
100
输出:
93326215443944152681699238856266700490715968264381621468592963895217599993229915608941463976156518286253697920827223758251185210916864000000000000000000000000
running time = 0.005s
当输入更大的时候:
n=100 running time = 0.005s
n=1000 running time = 0.032s
n=10000 running time = 2.357s
n=50000 running time = 66.046s
分析:
我们发现当n增大的时候,时间增大的也很快。
这种乘法是由一个位数很大的数(vector存储)和一个位数很小的数相乘完成的,每一位相乘就 是一次运算。大数的位数是O()量级。小数位数虽然可以看为常数级但是一共有n个,所 以也是O()量级。总的复杂度就是O()。
从上面时间也可以看出实际上增大会更大。
思路:
时间花费这么多,怎么能忍!!
我们换种方法来计算,让阶乘的计算过程想一颗树一样会怎么样。
看下面n=8的过程:
1 * 2 3 * 4 5 * 6 7 * 8
2 * 12 30 * 56
24 * 1680
40320
让相邻两个数相乘,是不是logN轮就结束啦。
但是现在问题来了,我们变成了两个大数相乘了,从上文知道大数是O()量级。
两个也就是O(),再乘上logN轮,变成了O()。坏了,更慢了?
先不急,我们考虑大数相乘的优化。分治法O(n^1.59),使用NTT算法O(nlogn)!
这不就是O()复杂度了嘛!
代码:
#pragma GCC optimize(2)
#include
#include
#include
#include "Integer.h" //自己实现的类直接导入.h文件就好啦 代码再下面
using namespace std;
int fac(int n)
{
if (n == 0)
{
cout << n << "!=" << 1 << endl;
return 0;
}
Integer *A = new Integer[n + 1];
// 给每个对象赋值
for (int i = 1; i <= n; i++)
{
A[i] = i;
}
// 用分治的方法来计算(可以多线程)
// 每次乘法为O(nlogn) 然后分治为O(logn) 总的复杂度为O(nlogn*logn) 十万以内几秒吧
for (int i = 1; i <= n; i *= 2)
{
for (int j = 1; j <= n; j += 2 * i)
{
if (j + i <= n)
{
A[j] = A[j] * A[j + i];
A[j + i].clear();
}
}
}
// 结果
if (A[1].size() > 10)
{
cout << "the number have " << A[1].size() << " digit" << endl;
cout << "output?[Y/N]\n";
string a;
cin >> a;
if (a[0] == 'Y')
cout << n << "!=" << A[1] << endl;
}
else
{
cout << n << "!=" << A[1] << endl;
}
delete[] A;
return 0;
}
int main()
{
clock_t start, end; // 定义clock_t变量
int n; // cin >> n;
cin >> n;
start = clock(); // 开始时间
fac(n);
end = clock(); // 结束时间
cout << "running time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
return 0;
}
输入:
100
输出:
the number have 158 digit
output?[Y/N]
Y
100!=93326215443944152681699238856266700490715968264381621468592963895217599993229915608941463976156518286253697920827223758251185210916864000000000000000000000000
running time = 0.009s
当输入更大的时候:
使用了NTT的算法 未使用的O(N^2)算法
n=100 running time = 0.002s n=100 running time = 0.005s
n=1000 running time = 0.007s n=1000 running time = 0.032s
n=10000 running time = 0.067s n=10000 running time = 2.357s
n=50000 running time = 0.577s n=50000 running time = 66.046s
n=100000 running time = 1.269s
n=500000 running time = 7.973s
效果还是挺挺显著的,看上来也基本符合O()复杂度。
注意点:
实现的类的代码有点长,主要是因为乘法用了NTT来实现,总的来说效果不错O(nlogn)的算法比分治法O(n^1.59)好不少。
除法是使用了牛顿迭代法也是O(nlogn)。问题是,这个除法内存占得有点多,可能用一万次就要爆内存,大概因为我代码写的有点问题。有时间我在加上普通的除法吧让两种算法取舍一下。
加法和减法正常实现O(n),然后可以用longlong类型来赋值并且cout可以直接输出。
实现了一些比较。
测试代码:
#include
#include "Integer.h" //自己实现的类直接导入.h文件就好啦 代码再下面
using namespace std;
int main()
{
Integer a, b;
a = 23424234;
b = 267678;
cout << a << endl;
cout << "加:" << a - b << endl; // 减
cout << "减:" << a + b << endl; // 加
cout << "乘:" << a * b << endl; // 乘
cout << "除:" << a / b << endl; // 除
cout << "余" << a % b << endl; // 余
// 常见比较
if (a / b * b + a % b == a)
cout << "YES" << endl;
else
cout << "NO" << endl;
// 大于
if (a > b)
cout << "YES" << endl;
else
cout << "NO" << endl;
// 小于
if (a < b)
cout << "YES" << endl;
else
cout << "NO" << endl;
return 0;
}
结果:
加:23156556
减:23691912
乘:6270152108652
除:87
余136248
YES
YES
NO
最后:
在一些平台测试过数据,但不保证完全正确,有bug欢迎大家来跟我交流和指正,谢谢大家观看。
使用:
新建一个Integer.h文件,然后复制粘贴下面代码就可以使用啦。
遇到头文件出问题网上应该都有解决方案。
Integer.h代码 :
#ifndef BigInteger
#define BigInteger
#include
#include
#include
#include
#include
#include
class Poly_div
{
public:
struct complex
{
double a, b;
complex() : a(0), b(0) {}
complex(double x, double y) : a(x), b(y) {}
void operator+=(const complex& z) { a += z.a, b += z.b; }
complex operator-(const complex& z) { return { a - z.a, b - z.b }; }
complex operator*(const complex& z) { return { a * z.a - b * z.b, a * z.b + b * z.a }; }
};
const long long Mod = 998244353;
long long L;
complex* W[2];
int* a, * b, * bi, * x, * y, * rev;
complex* t1, * t2;
long long* t3;
int n, m, d;
// 构造函数
Poly_div(long long num)
{
L = num;
rev = new int[L << 1]();
W[0] = new complex[L]();
W[1] = new complex[L]();
a = new int[L]();
b = new int[L]();
bi = new int[L]();
x = new int[L]();
y = new int[L]();
t1 = new complex[L]();
t2 = new complex[L]();
t3 = new long long[L + 1]();
n = 0;
m = 0;
d = 0;
}
~Poly_div()
{
delete[] rev;
delete[] W[0];
delete[] W[1];
delete[] a;
delete[] b;
delete[] bi;
delete[] x;
delete[] y;
delete[] t1;
delete[] t2;
delete[] t3;
}
// 得到函数
void get(std::vector& A, std::vector& B)
{ // A B都是规范化之后的
n = A.size();
for (int i = 0; i < A.size(); i++)
{
a[A.size() - 1 - i] = A[i];
}
m = B.size();
for (int i = 0; i < B.size(); i++)
{
b[B.size() - 1 - i] = B[i];
}
}
// 除法
void div(std::vector& A)
{
FFT_Init();
d = divide(n, m);
lack(n, m, d);
for (int i = 0; i < n + m; i++)
{
A.push_back(x[i]);
}
return;
}
// 初始化
void FFT_Init()
{
int i, j;
const int l = L >> 1;
const double pi = acos(-1);
for (i = j = 1; j < L; j <<= 1)
for (; i < j << 1; i++)
rev[i << 1] = rev[i], rev[i << 1 | 1] = rev[i] + j;
for (i = 0; i < l; i++)
W[0][i + l] = { cos(pi * i / l), sin(pi * i / l) };
for (i = l - 1; i; i--)
W[0][i] = W[0][i << 1];
memcpy(W[1], W[0], L * sizeof(complex));
for (i = 1; i < L; i++)
W[1][i].b *= -1;
}
// 过程
void FFT(complex* f, int len, int sign)
{
int i = 1, j, k;
complex* p, * q, * w, t;
t.a = 0;
t.b = 0;
sign = (int)(sign < 0);
for (int* r = rev + len + 1; i < len; i++, r++)
if (i < (k = *r))
t = f[i], f[i] = f[k], f[k] = t;
for (i = 1; i < len; i <<= 1)
for (j = 0; j < len; j += i, j += i)
for (q = (p = f + j) + i, w = W[sign] + i, k = 0; k < i; k++, p++, q++)
t = *q * *w++, * q = *p - t, * p += t;
if (sign)
{
double p = 1. / len;
for (i = 0; i < len; i++, f++)
f->a *= p, f->b *= p;
}
}
void majutsu(int len) // 计算b的倒数,有效位数为len,结果存进bi里
{
bool g = 1;
int l = 16, l2 = 32, l4 = 64, i;
long double d = 0, e = 1;
for (i = 0; i < 20; i++)
{
d = d + e * b[i], e *= 0.1;
}
d = 10. / d;
if (d < 10)
{
for (i = 0; i <= l; i++)
bi[i] = d, d = (d - bi[i]) * 10;
bi[l - 1] += (bi[l] > 4);
}
else
bi[0] = 10;
while (l < len)
{
p:
memset(t1, 0, sizeof(complex) * l4);
memset(t2, 0, sizeof(complex) * l4);
for (i = 0; i < l2; i++)
t1[i].a = b[i];
for (i = 0; i < l; i++)
t2[i].a = bi[i];
FFT(t1, l4, 1), FFT(t2, l4, 1);
for (i = 0; i < l4; i++)
{
t1[i] = t1[i] * t2[i];
t1[i].a = 20 - t1[i].a, t1[i].b = -t1[i].b;
t2[i] = t1[i] * t2[i];
}
FFT(t2, l4, -1);
t3[l4] = 0;
for (i = l4 - 1; i >= 0; i--)
{
t3[i] = (long long)(floor(t2[i].a + 0.5)) + t3[i + 1] / 10, t3[i + 1] %= 10;
if (t3[i + 1] < 0)
t3[i + 1] += 10, t3[i]--;
}
if (t3[0] > 9)
{
bi[0] = t3[0] / 10, t3[0] %= 10;
for (i = 1; i < l2; i++)
bi[i] = t3[i - 1];
bi[l2 - 1] += (t3[l2 - 1] > 4);
}
else
{
for (i = 0; i < l2; i++)
bi[i] = t3[i];
bi[l2 - 1] += (t3[l2] > 4);
}
l <<= 1, l2 <<= 1, l4 <<= 1;
}
if (g)
{
g = 0, l >>= 1, l2 >>= 1, l4 >>= 1;
goto p;
} // 迭代过程结束后末几位仍会有偏差,于是再单独迭代一次
}
int divide(int n, int m) // 计算a/b,整数部分存进x里
{
int p = n - m + 16, l, i;
majutsu(p);
for (l = 1; l < n + p; l <<= 1)
;
memset(t1, 0, sizeof(complex) * l);
memset(t2, 0, sizeof(complex) * l);
for (i = 0; i < n; i++)
t1[i].a = a[i];
for (i = 0; i < p; i++)
t2[i].a = bi[i];
FFT(t1, l, 1), FFT(t2, l, 1);
for (i = 0; i < l; i++)
t1[i] = t1[i] * t2[i];
FFT(t1, l, -1);
t3[l] = 0;
for (i = l - 1; i >= 0; i--)
t3[i] = (long long)(t1[i].a + 0.5) + t3[i + 1] / 10, t3[i + 1] %= 10;
if (t3[0] > 9)
{
x[0] = t3[0] / 10, t3[0] %= 10, l = n - m + 1;
for (i = 0; i < n - m; i++)
x[i + 1] = t3[i];
}
else
for (l = n - m, i = 0; i < n - m; i++)
x[i] = t3[i];
return l;
}
void lack(int n, int m, int& d) // 微调
{
int l, i, j;
char t;
long long tl = 0;
for (i = 0, j = n - 1; i < j; i++, j--)
t = a[i], a[i] = a[j], a[j] = t;
for (i = 0, j = m - 1; i < j; i++, j--)
t = b[i], b[i] = b[j], b[j] = t;
for (i = 0, j = d - 1; i < j; i++, j--)
t = x[i], x[i] = x[j], x[j] = t;
for (l = 1; l < n; l <<= 1)
;
memset(t1, 0, sizeof(complex) * l);
memset(t2, 0, sizeof(complex) * l);
for (i = 0; i < m; i++)
t1[i].a = b[i];
for (i = 0; i < d; i++)
t2[i].a = x[i];
t2[0].a += 1.;
FFT(t1, l, 1), FFT(t2, l, 1);
for (i = 0; i < l; i++)
t1[i] = t1[i] * t2[i];
FFT(t1, l, -1);
for (i = 0; i <= n; i++)
tl += (long long)(t1[i].a + 0.5), y[i] = tl % 10, tl /= 10;
for (i = n; i >= 0; i--)
{
if (y[i] > a[i])
return;
else if (y[i] < a[i])
break;
}
for (x[0]++, i = 0; x[i] > 9; i++)
x[i + 1] += x[i] / 10, x[i] %= 10;
if (x[d])
d++;
}
};
class Polynomial
{
public:
long long Mod = 998244353;
int G = 3, iG = 332748118;
int MS;
long long* Inv; // 逆元
long long Sz, * R;
long long InvSz; // NTT
int N, M;
long long* A1, * B1;
Polynomial(int sum)
{
MS = std::max(64, sum * 2);
Inv = new long long[MS];
R = new long long[MS];
A1 = new long long[MS];
B1 = new long long[MS];
for (int i = 0; i < MS; i++)
{
Inv[i] = 0;
R[i] = 0;
A1[i] = 0;
B1[i] = 0;
}
InvSz = 0;
Sz = 0;
N = 0;
M = 0;
Init(MS);
}
void get(std::vector& A, std::vector& B);
std::vector& add();
std::vector& sub();
std::vector& mul();
static int cmp(std::vector& A, std::vector& B);
void Clear();
long long qPow(long long b, int e);
void Init(int N);
void InitFNTT(int N);
void FNTT(long long* A, int Ty);
void PolyMul(long long* A, long long* B, int deg);
void PolyInv(long long* A, int N, long long* B);
void PolyLn(long long* A, int N, long long* B);
void PolyExp(long long* A, int N, long long* B);
};
class Integer
{
protected:
std::vector data;
int flag = 1;
public:
Integer();
Integer(long long a);
Integer(int flag, std::vector& a);
void check();
friend std::istream& operator>>(std::istream& in, Integer& a);
friend std::ostream& operator<<(std::ostream& out, Integer& a);
long long abs(long long a);
Integer& operator=(Integer& a);
Integer& operator=(long long a);
Integer& operator=(std::string& a);
int operator==(long long a);
int operator==(Integer& a);
int operator!=(long long a);
int operator!=(Integer& a);
int operator<(Integer& a);
int operator<=(Integer& a);
int operator>(Integer& a);
int operator>=(Integer& a);
Integer& operator+(Integer& a);
Integer& operator-(Integer& a);
Integer& operator-();
Integer& operator/(Integer& a);
Integer& operator%(Integer& a);
Integer& operator*(Integer& a);
std::vector& Polymul_tool(std::vector a, std::vector b);
int size() { return data.size(); }
void clear()
{
flag = 1;
data.clear();
data.push_back(0);
}
};
long long Integer::abs(long long a)
{
if (a < 0)
return -a;
else
return a;
}
Integer::Integer()
{
flag = 1;
data.push_back(0);
}
Integer::Integer(long long a)
{
if (a < 0)
flag = -1;
else
flag = 1;
a = abs(a);
while (a)
{
data.push_back(a % 10);
a /= 10;
}
}
Integer::Integer(int flag, std::vector& a)
{
data = a;
this->flag = 1;
}
int Integer::operator<(Integer& a)
{
if (flag == 1) {
if (a.flag == 1) {
if (Polynomial::cmp(data, a.data) < 0)
return 1;
else
return 0;
}
else
{
return 0;
}
}
else {
if (a.flag == 1) {
return 1;
}
else {
if (Polynomial::cmp(data, a.data) > 0)
return 1;
else
return 0;
}
}
return 0;
}
int Integer::operator<=(Integer& a)
{
if (flag == 1) {
if (a.flag == 1) {
if (Polynomial::cmp(data, a.data) > 0)
return 0;
else
return 1;
}
else
{
return 0;
}
}
else {
if (a.flag == 1) {
return 1;
}
else {
if (Polynomial::cmp(data, a.data) < 0)
return 0;
else
return 1;
}
}
return 0;
}
int Integer::operator>(Integer& a)
{
if (flag == 1) {
if (a.flag == 1) {
if (Polynomial::cmp(data, a.data) > 0)
return 1;
else
return 0;
}
else
{
return 1;
}
}
else {
if (a.flag == 1) {
return 0;
}
else {
if (Polynomial::cmp(data, a.data) < 0)
return 1;
else
return 0;
}
}
return 0;
}
int Integer::operator>=(Integer& a)
{
if (flag == 1) {
if (a.flag == 1) {
if (Polynomial::cmp(data, a.data) < 0)
return 0;
else
return 1;
}
else
{
return 1;
}
}
else {
if (a.flag == 1) {
return 0;
}
else {
if (Polynomial::cmp(data, a.data) > 0)
return 0;
else
return 1;
}
}
return 0;
}
void Integer::check()
{
for (int i = data.size() - 1; i >= 1; i--)
if (data[i] == 0)
data.pop_back();
else
break;
if (data.size() == 1 && data[0] == 0)
flag = 1;
if (data.size() == 0)
{
data.push_back(0);
flag = 1;
}
}
int Integer::operator==(long long a)
{
Integer tool(a);
return (*this) == tool;
}
int Integer::operator==(Integer& a)
{
if (Polynomial::cmp(data, a.data) == 0)
return 1;
else
return 0;
}
int Integer::operator!=(long long a)
{
if ((*this) == a)
return 0;
else
return 1;
}
int Integer::operator!=(Integer& a)
{
if ((*this) == a)
return 0;
else
return 1;
}
Integer& Integer::operator=(long long a)
{
if (a < 0)
flag = -1;
else
flag = 1;
a = abs(a);
data.clear();
while (a)
{
data.push_back(a % 10);
a /= 10;
}
return (*this);
}
Integer& Integer::operator=(std::string& date)
{
if (date[0] == '-')
{
flag = -1;
for (int i = date.size() - 1; i >= 1; i--)
{
data.push_back(date[i] - 48);
}
}
else
{
flag = 1;
data.clear();
for (int i = date.size() - 1; i >= 0; i--)
{
data.push_back(date[i] - 48);
}
}
check();
return *this;
}
Integer& Integer::operator=(Integer& a)
{
data = a.data;
flag = a.flag;
return *this;
}
Integer& Integer::operator+(Integer& a)
{
if (flag == 1 && a.flag == 1)
{
Integer* p;
p = new Integer;
p->flag = 1;
Polynomial poly(data.size() + a.data.size() - 1);
poly.get(data, a.data);
p->data = poly.add();
p->check();
return (*p);
}
else if (flag == -1 && a.flag == -1)
{
Integer* p;
p = new Integer;
p->flag = -1;
Polynomial poly(data.size() + a.data.size() - 1);
poly.get(data, a.data);
p->data = poly.add();
p->check();
return (*p);
}
else if (flag == 1 && a.flag == -1)
{
Integer b = a;
b.flag = 1;
return (*this) - b;
}
else if (flag == -1 && a.flag == 1)
{
Integer b = (*this);
b.flag = 1;
return a - b;
}
return a;
}
Integer& Integer::operator-(Integer& a)
{
if (flag == 1 && a.flag == 1)
{
Integer* p;
p = new Integer;
Polynomial poly(data.size() + a.data.size() - 1);
p->flag = poly.cmp(this->data, a.data);
if (p->flag == 1)
{
poly.get(data, a.data);
p->data = poly.sub();
return (*p);
}
else
{
poly.get(a.data, data);
p->data = poly.sub();
return (*p);
}
}
else if (flag == -1 && a.flag == -1)
{
Integer* p;
p = new Integer;
Polynomial poly(data.size() + a.data.size() - 1);
p->flag = poly.cmp(this->data, a.data);
if (p->flag == 1)
{
poly.get(data, a.data);
p->data = poly.sub();
p->flag *= -1;
return (*p);
}
else
{
poly.get(a.data, data);
p->data = poly.sub();
p->flag *= -1;
return (*p);
}
}
else
{
Integer b = a;
b.flag *= -1;
return (*this) + b;
}
}
Integer& Integer::operator-()
{
Integer* p;
p = new Integer;
p->data.clear();
p->flag *= -1;
p->data = data;
return *p;
}
std::vector& Integer::Polymul_tool(std::vector a, std::vector b)
{
int deg = a.size() + b.size() - 1;
for (int i = a.size() - 1; i >= 0; i--)
if (a[i] == 0)
a.pop_back();
else
break;
for (int i = b.size() - 1; i >= 0; i--)
if (b[i] == 0)
b.pop_back();
else
break;
Polynomial test(a.size() + b.size());
test.get(a, b);
std::vector* c;
c = &test.mul();
return *c;
}
Integer& Integer::operator*(Integer& a)
{
Integer* p;
p = new Integer;
p->data.clear();
p->flag = this->flag * a.flag;
p->data = Polymul_tool(this->data, a.data);
return *p;
}
Integer& Integer::operator/(Integer& a)
{
Integer* p;
p = new Integer;
p->data.clear();
p->flag = this->flag * a.flag;
int cmp_temp = Polynomial::cmp(data,a.data);
if (cmp_temp == 0) {
(*p)=1;
return (*p);
}
else if (cmp_temp == -1) {
return (*p);
}
int len = 64;
int len_2 = 10 + this->data.size() + a.data.size();
while (len < len_2)
len *= 2;
Poly_div test(len * 2);
test.get(this->data, a.data);
test.div(p->data);
p->check();
return (*p);
}
Integer& Integer::operator%(Integer& a)
{
return (*this) - (*this) / a * a;
}
std::istream& operator>>(std::istream& in, Integer& a)
{
a.data.clear();
std::string date;
std::cin >> date;
if (date[0] == '-')
{
a.flag = -1;
for (int i = date.size() - 1; i >= 1; i--)
{
a.data.push_back(date[i] - 48);
}
}
else
{
a.flag = 1;
a.data.clear();
for (int i = date.size() - 1; i >= 0; i--)
{
a.data.push_back(date[i] - 48);
}
}
a.check();
return in;
}
std::ostream& operator<<(std::ostream& out, Integer& a)
{
a.check();
if (a.data.size() == 0)
out << '0';
else if (a.flag == 1)
{
for (int i = a.data.size() - 1; i >= 0; i--)
out << a.data[i];
}
else
{
out << "-";
for (int i = a.data.size() - 1; i >= 0; i--)
out << a.data[i];
}
return out;
}
// Polynomial的函数
void Polynomial::get(std::vector& A, std::vector& B)
{ // A B都是规范化之后的
N = A.size();
for (int i = A.size() - 1; i >= 0; i--)
{
A1[i] = A[i];
}
M = B.size();
for (int i = B.size() - 1; i >= 0; i--)
{
B1[i] = B[i];
}
}
int Polynomial::cmp(std::vector& A, std::vector& B)
{
if (A.size() > B.size())
return 1;
else if (A.size() < B.size())
return -1;
else
{
for (int i = A.size() - 1; i >= 0; i--)
{
if (A[i] > B[i])
return 1;
else if (A[i] < B[i])
return -1;
}
return 0;
}
return 0;
}
std::vector& Polynomial::mul()
{
std::vector* P;
P = new std::vector;
if (N == 0 || M == 0)
{
P->push_back(0);
return (*P);
}
PolyMul(A1, B1, N + M - 2);
long long tool = 0;
for (int i = 0; i < N + M - 1; i++)
{
long long a = tool + (A1[i] + Mod) % Mod;
P->push_back(a % 10);
tool = a / 10;
}
while (tool)
{
P->push_back(tool % 10);
tool /= 10;
}
for (int i = P->size() - 1; i >= 1; i--)
{
if ((*P)[i] == 0)
P->pop_back();
else
break;
}
Clear();
return (*P);
}
std::vector& Polynomial::sub()
{
std::vector* P;
P = new std::vector;
if (N == 0 && M == 0)
{
P->push_back(0);
return (*P);
}
int tool1 = 0;
int Z = std::max(N, M);
for (int i = 0; i < Z; i++)
{
int tool2 = A1[i] - B1[i] + tool1;
if (tool2 >= 0)
{
tool1 = 0;
P->push_back(tool2);
}
else
{
tool1 = -1;
P->push_back(10 + tool2);
}
}
for (int i = P->size() - 1; i >= 1; i--)
{
if ((*P)[i] == 0)
P->pop_back();
else
break;
}
Clear();
return (*P);
}
std::vector& Polynomial::add()
{
std::vector* P;
P = new std::vector;
if (N == 0 && M == 0)
{
P->push_back(0);
return (*P);
}
int tool1 = 0;
int Z = std::max(N, M);
for (int i = 0; i < Z; i++)
{
int tool2 = A1[i] + B1[i] + tool1;
P->push_back(tool2 % 10);
tool1 = tool2 / 10;
}
if (tool1 > 0)
{
P->push_back(tool1);
}
for (int i = P->size() - 1; i >= 1; i--)
{
if ((*P)[i] == 0)
P->pop_back();
else
break;
}
Clear();
return (*P);
}
void Polynomial::Clear()
{
for (int i = 0; i < MS; i++)
{
R[i] = 0;
A1[i] = 0;
B1[i] = 0;
}
}
long long Polynomial::qPow(long long b, int e)
{ // 快速幂
long long a = 1;
for (; e; e >>= 1, b = b * b % Mod)
if (e & 1)
a = a * b % Mod;
return a;
}
void Polynomial::Init(int N)
{ // 求逆元
Inv[1] = 1;
for (int i = 2; i < N; ++i)
Inv[i] = -(Mod / i) * Inv[Mod % i] % Mod;
return;
}
void Polynomial::InitFNTT(int N)
{
int Bt = 0;
for (; 1 << Bt <= N; ++Bt)
;
if (Sz == (1 << Bt))
return;
Sz = 1 << Bt;
InvSz = -(Mod - 1) / Sz;
for (int i = 1; i < Sz; ++i)
R[i] = R[i >> 1] >> 1 | (i & 1) << (Bt - 1);
}
void Polynomial::FNTT(long long* A, int Ty)
{
for (int i = 0; i < Sz; ++i)
if (R[i] < i)
std::swap(A[R[i]], A[i]);
for (int j = 1, j2 = 2; j < Sz; j <<= 1, j2 <<= 1)
{
long long gn = qPow(~Ty ? G : iG, (Mod - 1) / j2), g, X, Y;
for (int i = 0, k; i < Sz; i += j2)
{
for (k = 0, g = 1; k < j; ++k, g = g * gn % Mod)
{
X = A[i + k], Y = g * A[i + j + k] % Mod;
A[i + k] = (X + Y) % Mod, A[i + j + k] = (X - Y) % Mod;
}
}
}
if (!~Ty)
for (int i = 0; i < Sz; ++i)
A[i] = A[i] * InvSz % Mod;
}
void Polynomial::PolyMul(long long* A, long long* B, int deg)
{ // 多项式乘法 A=A*B
InitFNTT(deg);
FNTT(A, 1);
FNTT(B, 1);
for (int i = 0; i < Sz; i++)
{
A[i] = (A[i] * B[i]) % Mod;
}
FNTT(A, -1);
FNTT(B, -1);
}
void Polynomial::PolyInv(long long* A, int N, long long* B)
{ // 多项式求inv 对x^N取mod A*B=1
// long long tA[MS], tB[MS];
long long* tA = new long long[MS];
long long* tB = new long long[MS];
B[0] = qPow(A[0], Mod - 2);
for (int L = 1; L < N; L <<= 1)
{
int L2 = L << 1, L4 = L << 2;
InitFNTT(L4);
memcpy(tA, A, 8 * L2);
memset(tA + L2, 0, 8 * (Sz - L2));
memcpy(tB, B, 8 * L);
memset(tB + L, 0, 8 * (Sz - L));
FNTT(tA, 1), FNTT(tB, 1);
for (int i = 0; i < Sz; ++i)
tB[i] = (2 - tB[i] * tA[i]) % Mod * tB[i] % Mod;
FNTT(tB, -1);
for (int i = 0; i < L2; ++i)
B[i] = tB[i];
}
delete[] tA;
delete[] tB;
}
void Polynomial::PolyLn(long long* A, int N, long long* B)
{ // 多项式求ln 对x^N取mod A[0]=1 B=ln(A)
long long* tA = new long long[MS];
long long* tB = new long long[MS];
PolyInv(A, N - 1, tB);
InitFNTT(N * 2 - 3);
for (int i = 1; i < N; ++i)
tA[i - 1] = i * A[i] % Mod;
memset(tA + N - 1, 0, 8 * (Sz - N + 1));
memset(tB + N - 1, 0, 8 * (Sz - N + 1));
FNTT(tA, 1), FNTT(tB, 1);
for (int i = 0; i < Sz; ++i)
tA[i] = (long long)tA[i] * tB[i] % Mod;
FNTT(tA, -1);
B[0] = 0;
for (int i = 1; i < N; ++i)
B[i] = (long long)tA[i - 1] * Inv[i] % Mod;
delete[] tA;
delete[] tB;
}
void Polynomial::PolyExp(long long* A, int N, long long* B)
{ // 多项式求exp 对x^N取mod A[0]=0 B=e^A
long long* tA = new long long[MS];
long long* tB = new long long[MS];
B[0] = 1;
for (int L = 1; L < N; L <<= 1)
{
int L2 = L << 1, L4 = L << 2;
memset(B + L, 0, 8 * (L2 - L));
PolyLn(B, L2, tB);
InitFNTT(L4);
memcpy(tA, B, 8 * L);
memset(tA + L, 0, 8 * (Sz - L));
for (int i = 0; i < L2; ++i)
tB[i] = ((!i) - tB[i] + A[i]) % Mod;
memset(tB + L2, 0, 8 * (Sz - L2));
FNTT(tA, 1), FNTT(tB, 1);
for (int i = 0; i < Sz; ++i)
tA[i] = tA[i] * tB[i] % Mod;
FNTT(tA, -1);
for (int i = 0; i < L2; ++i)
B[i] = tA[i];
}
delete[] tA;
delete[] tB;
}
#endif // !BigInteger