[uoj 34 多项式乘法] FFT&NTT 模板

[uoj 34 多项式乘法] FFT&NTT 模板

分类:模板 FFT NTT

1. 题目链接

[uoj 34 多项式乘法]

2. 题意描述

给你两个多项式,请输出乘起来后的多项式。
第一行两个整数 n m ,分别表示两个多项式的次数。
第二行 n+1 个整数,分别表示第一个多项式的 0 到 n 次项前的系数。
第三行 m+1 个整数,分别表示第二个多项式的 0 到 m 次项前的系数。

3. 解题思路

模板测试题。给出FFT和NTT的板子。
可以直接去[uoj statistics] 查看更好的板子。

4. 实现代码

// Explicit standard headers for everything this program uses
// (stdio I/O, cos/sin/acos/floor, swap, pair, vector).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

// Common contest aliases/constants; several are not used by this particular program.
typedef long long LL;
typedef long double LB;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
typedef vector<int> VII;

const int INF = 0x3f3f3f3f;
const LL INFL = 0x3f3f3f3f3f3f3f3fLL;
const double eps = 1e-8;
const double PI = acos(-1.0);

/// Reads the next decimal integer (optionally preceded by '-') from stdin into `ret`.
/// Any characters other than '-' and digits before the number are skipped.
/// @return true if a number was read; false if EOF was reached before one
///         (in that case `ret` is left unchanged).
template <typename T>
inline bool scan_d (T &ret) {
    // int, not char: getchar() returns EOF as an int that a plain char cannot
    // reliably hold (and never equals when char is unsigned).
    int c = getchar();
    // Check EOF while skipping separators too; otherwise trailing whitespace
    // before end of input makes this loop spin forever.
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return false;
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return true;
}
/// Writes integer `x` to stdout in decimal, with no trailing separator.
/// Digits are peeled off without negating `x`, so the most negative value of T
/// (e.g. INT_MIN) prints correctly instead of overflowing on `-x`.
template<typename T>
void print(T x) {
    char digits[40];            // enough for any built-in integer width
    int len = 0;
    const bool negative = x < 0;
    do {
        // Since C++11, % truncates toward zero, so for negative x the remainder is in [-9, 0].
        int d = static_cast<int>(x % 10);
        if (d < 0) d = -d;
        digits[len++] = static_cast<char>('0' + d);
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (len > 0) putchar(digits[--len]);
}
/// Prints `x` followed by a newline.
template<typename T> void println(T x) { print(x); putchar('\n'); }


const int MAXN = 262144 + 5;    /// Array capacity: must cover the largest transform length (a power of two, 2^18 here).
// Alternative: typedef complex<double> CP;
/// Minimal complex number used by the FFT.
struct CP {
    double x, y;   // real and imaginary parts
    CP() : x(0), y(0) {}
    CP(double x, double y) : x(x), y(y) {}
    inline double real() const { return x; }
    inline double imag() const { return y; }
    inline CP operator * (const CP& r) const { return CP(x * r.x - y * r.y, x * r.y + y * r.x); }
    inline CP operator - (const CP& r) const { return CP(x - r.x, y - r.y); }
    inline CP operator + (const CP& r) const { return CP(x + r.x, y + r.y); }
};
CP a[MAXN], b[MAXN];      // FFT work buffers for the two operands
int r[MAXN], res[MAXN];   // r: bit-reversal permutation; res: rounded product coefficients

/// Fills r[0..nm) with the bit-reversal permutation for a transform of length nm = 2^k:
/// r[i] is i with its lowest k bits reversed. fft() uses r to reorder its input.
void fft_init(int nm, int k) {
    if (nm > 0) r[0] = 0;   // explicit base case; no reliance on r[] being pre-zeroed
    // Starting at 1 also avoids the undefined shift by (k - 1) = -1 when nm == 1, k == 0.
    for(int i = 1; i < nm; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
}

/// In-place iterative radix-2 FFT of ax[0..nm).
/// @param ax  input coefficients, overwritten with the transform
/// @param nm  transform length (a power of two); r[] must already hold the
///            matching bit-reversal table from fft_init(nm, k)
/// @param op  1 for the forward transform, -1 for the inverse (scaled by 1/nm)
void fft(CP ax[], int nm, int op) {
    for(int i = 0; i < nm; ++i) if(i < r[i]) swap(ax[i], ax[r[i]]);   /// bit-reversal reorder
    for(int h = 2, m = 1; h <= nm; h <<= 1, m <<= 1) {  /// h: current block length, m = h / 2
        CP wn = CP(cos(op * 2 * PI / h), sin(op * 2 * PI / h));   /// h-th root of unity (conjugated when op == -1)
        for(int i = 0; i < nm; i += h) {    /// every block of length h
            CP w(1, 0);                     /// twiddle factor wn^(j - i)
            for(int j = i; j < i + m; ++j, w = w * wn) {
                CP t = w * ax[j + m];       /// butterfly
                ax[j + m] = ax[j] - t;
                ax[j] = ax[j] + t;
            }
        }
    }
    // Scale both parts so the inverse is also correct for complex-valued data,
    // not only for the real part.
    if(op == -1) for(int i = 0; i < nm; ++i) { ax[i].x /= nm; ax[i].y /= nm; }
}

/// Multiplies polynomial ax (n coefficients) by bx (m coefficients) with the FFT and
/// prints the n + m - 1 product coefficients, space-separated, on one line.
/// Each coefficient is rounded to the nearest integer and also stored in res[].
void trans(int ax[], int bx[], int n, int m) {

    // Smallest power of two >= n + m. That is enough room for all n + m - 1 product
    // terms, so the cyclic convolution equals the linear one. It is never larger than
    // the previous 2 * max(n, m) bound, and is >= 2 whenever n, m >= 1.
    int nm = 1, k = 0;
    while(nm < n + m) nm <<= 1, ++k;

    for(int i = 0; i < n; ++i) a[i] = CP(ax[i], 0);
    for(int i = 0; i < m; ++i) b[i] = CP(bx[i], 0);
    for(int i = n; i < nm; ++i) a[i] = CP(0, 0);   // zero padding
    for(int i = m; i < nm; ++i) b[i] = CP(0, 0);

    fft_init(nm, k);
    fft(a, nm, 1); fft(b, nm, 1);
    for(int i = 0; i < nm; ++i) a[i] = a[i] * b[i];   // pointwise product == convolution
    fft(a, nm, -1);
    const int len = n + m - 1;
    for(int i = 0; i < len; ++i) {
        // floor(x + 0.5) rounds negative values correctly; (int)(x + 0.5) truncates toward zero.
        res[i] = static_cast<int>(floor(a[i].real() + 0.5));
        print(res[i]);
        putchar(" \n"[i == len - 1]);
    }
}

int main() {
#ifdef ___LOCAL_WONZY___
    freopen("input.txt", "r", stdin);
#endif // ___LOCAL_WONZY___
    // Coefficient buffers; static storage keeps them off the stack.
    static int ax[MAXN], bx[MAXN];
    static int n, m;

    // The input gives degrees; convert them to coefficient counts.
    scan_d(n);
    scan_d(m);
    n += 1;
    m += 1;
    for (int idx = 0; idx < n; ++idx) scan_d(ax[idx]);
    for (int idx = 0; idx < m; ++idx) scan_d(bx[idx]);

    trans(ax, bx, n, m);   // multiply and print the product polynomial
    return 0;
}
/** NTT **/
// Explicit standard headers for everything this program uses
// (stdio I/O, acos, swap, pair, vector).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

// Common contest aliases/constants; several are not used by this particular program.
typedef long long LL;
typedef long double LB;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
typedef vector<int> VII;

const int INF = 0x3f3f3f3f;
const LL INFL = 0x3f3f3f3f3f3f3f3fLL;
const double eps = 1e-8;
const double PI = acos(-1.0);

/// Reads the next decimal integer (optionally preceded by '-') from stdin into `ret`.
/// Any characters other than '-' and digits before the number are skipped.
/// @return true if a number was read; false if EOF was reached before one
///         (in that case `ret` is left unchanged).
template <typename T>
inline bool scan_d (T &ret) {
    // int, not char: getchar() returns EOF as an int that a plain char cannot
    // reliably hold (and never equals when char is unsigned).
    int c = getchar();
    // Check EOF while skipping separators too; otherwise trailing whitespace
    // before end of input makes this loop spin forever.
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return false;
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return true;
}
/// Writes integer `x` to stdout in decimal, with no trailing separator.
/// Digits are peeled off without negating `x`, so the most negative value of T
/// (e.g. INT_MIN) prints correctly instead of overflowing on `-x`.
template<typename T>
void print(T x) {
    char digits[40];            // enough for any built-in integer width
    int len = 0;
    const bool negative = x < 0;
    do {
        // Since C++11, % truncates toward zero, so for negative x the remainder is in [-9, 0].
        int d = static_cast<int>(x % 10);
        if (d < 0) d = -d;
        digits[len++] = static_cast<char>('0' + d);
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (len > 0) putchar(digits[--len]);
}
/// Prints `x` followed by a newline.
template<typename T> void println(T x) { print(x); putchar('\n'); }

const int MAXN = 262144 + 5;    /// Array capacity: must cover the largest transform length (a power of two, 2^18 here).
/// NTT prime 998244353 = 119 * 2^23 + 1 with primitive root G = 3,
/// so power-of-two transform lengths up to 2^23 are supported.
const int G = 3, MOD = 998244353;

int a[MAXN], b[MAXN], r[MAXN], res[MAXN];   // a, b: NTT work buffers; r: bit-reversal table; res: product coefficients

/// Computes a^b mod MOD by binary exponentiation.
/// @param a base; reduced into [0, MOD) first, so negative or >= MOD bases are handled
/// @param b exponent; values <= 0 yield 1 (instead of looping forever on negatives)
template<typename T>
T quick_pow(T a, T b) {
    a %= MOD;
    if (a < 0) a += MOD;
    T ret = 1;
    while(b > 0) {
        if(b & 1) ret = static_cast<long long>(ret) * a % MOD;
        a = static_cast<long long>(a) * a % MOD;
        b >>= 1;
    }
    return ret;
}

/// Fills r[0..nm) with the bit-reversal permutation for a transform of length nm = 2^k:
/// r[i] is i with its lowest k bits reversed. ntt() uses r to reorder its input.
void ntt_init(int nm, int k) {
    if (nm > 0) r[0] = 0;   // explicit base case; no reliance on r[] being pre-zeroed
    // Starting at 1 also avoids the undefined shift by (k - 1) = -1 when nm == 1, k == 0.
    for(int i = 1; i < nm; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
}

template<typename T>
void ntt(T ax[], int nm, int op) {
    for(int i = 0; i < nm; ++i) if(i < r[i]) swap(ax[i], ax[r[i]]);
    for(int h = 2, m = 1; h <= nm; h <<= 1, m <<= 1) {  /// 枚举长度
        T wn = quick_pow(G, (MOD - 1) / h);
        for(int i = 0; i < nm; i += h) {    /// 枚举所有长度为h的区间
            T w = 1;                                  /// 旋转因子
            for(int j = i; j < i + m; ++j, w = (LL)w * wn % MOD) { /// 枚举角度
                T t = (LL)w * ax[j + m] % MOD;       /// 蝴蝶操作
                ax[j + m] = ax[j] - t + MOD;
                if(ax[j + m] >= MOD) ax[j + m] -= MOD;
                ax[j] = ax[j] + t;
                if(ax[j] >= MOD) ax[j] -= MOD;
            }
        }
    }
    if(op == -1) {
        for(int i = 1; i < nm / 2; i++) swap(ax[i], ax[nm - i]); /// Caution Here!
        T inv = quick_pow(nm, MOD - 2);
        for(int i = 0; i < nm; ++i) ax[i] = (LL)ax[i] * inv % MOD;
    }
}

template<typename T>
void trans(T ax[], T bx[], int n, int m) {

    int nm = 1, k = 0;
    while(nm < 2 * n || nm < 2 * m) nm <<= 1, ++k;

    for(int i = 0; i < n; ++i) a[i] = ax[i];
    for(int i = 0; i < m; ++i) b[i] = bx[i];
    for(int i = n; i < nm; ++i) a[i] = 0;
    for(int i = m; i < nm; ++i) b[i] = 0;

    ntt_init(nm, k);
    ntt(a, nm, 1); ntt(b, nm, 1);
    for(int i = 0; i < nm; ++i) a[i] = (LL)a[i] * b[i] % MOD;
    ntt(a, nm, -1);
    nm = n + m - 1;
    for(int i = 0; i < nm; ++i) res[i] = a[i], print(res[i]), putchar(" \n"[i == nm - 1]);
}

int main() {
#ifdef ___LOCAL_WONZY___
    freopen("input.txt", "r", stdin);
#endif // ___LOCAL_WONZY___
    static int n, m;

    // The input gives degrees; convert them to coefficient counts.
    scan_d(n);
    scan_d(m);
    n += 1;
    m += 1;
    // Coefficients are read straight into the global NTT work buffers a and b.
    for (int idx = 0; idx < n; ++idx) scan_d(a[idx]);
    for (int idx = 0; idx < m; ++idx) scan_d(b[idx]);

    trans(a, b, n, m);   // multiply and print the product polynomial
    return 0;
}

你可能感兴趣的:(ACM____FFT&NTT)