算法学习FFT系列(2):快速数论变换NTT &&bzoj3992: [SDOI2015]序列统计例题详解

bzoj3992: [SDOI2015]序列统计

Description

小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。

Input

一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。第二行,|S|个整数,表示集合S中的所有元素。

Output

一行,一个整数,表示你求出的种类数mod 1004535809的值。

Sample Input

4 3 1 2
1 2

Sample Output

8

HINT

【样例说明】
可以生成的满足要求的不同的数列有(1,1,1,1)、(1,1,2,2)、(1,2,1,2)、(1,2,2,1)、(2,1,1,2)、(2,1,2,1)、(2,2,1,1)、(2,2,2,2)。
【数据规模和约定】
对于10%的数据,1<=N<=1000;
对于30%的数据,3<=M<=100;
对于60%的数据,3<=M<=800;
对于全部的数据,1<=N<=109,3<=M<=8000,M为质数,1<=x<=M-1,输入数据保证集合S中元素不重复

知识点:快速数论变换

其实学了FFT之后再学NTT已经只差临门一脚了。
首先复习一下FFT

复习FFT

离散型快速傅立叶变换如下
Xk=n=0N1xne2πiNnkk=0,1N1 X k = ∑ n = 0 N − 1 x n e − 2 π i N n k k = 0 , 1 ⋯ N − 1
离散型快速傅立叶逆变换如下
xn=1Nn=0N1Xke2πiNnkk=0,1N1 x n = 1 N ∑ n = 0 N − 1 X k e 2 π i N n k k = 0 , 1 ⋯ N − 1

引入NTT

如今快速数论变换是在 Zp Z p 上进行的。
快速傅立叶变换中,有一个神奇的单位复根叫做
ω=e2πiN ω = e − 2 π i N
这玩意儿有一个神奇的性质,就是 ωN=1 ω N = 1 并且 ωk ω k k=0,1N1 k = 0 , 1 ⋯ N − 1 有N个取值。
那么相应地类比一下,在快速数论变换中,也有一个神奇的东西叫做g人家是正儿八经的原根。它的性质就是
e2πiNgP1N(modP) e − 2 π i N ≡ g P − 1 N ( mod P )
e这玩意儿不是double的吗,怎么跟一个int扔过去同余了捏,其实上面那个公式的意思,就是g再modp意义下和单位复根是同一个作用。。。。

离散型NTT公式

这样子的话,我们就可以得到我们的离散型快速数论变换的公式。
Xk=n=0N1xngP1Nnk(modP)k=0,1N1 X k = ∑ n = 0 N − 1 x n g P − 1 N n k ( mod P ) k = 0 , 1 ⋯ N − 1
类比一下就是离散型快速数论逆变换公式
Xk=1Nn=0N1xngP1Nnk(modP)k=0,1N1 X k = 1 N ∑ n = 0 N − 1 x n g − P − 1 N n k ( mod P ) k = 0 , 1 ⋯ N − 1
这个时候你肯定会觉得很草率,因为g是什么你都不知道。
所以说,这篇文章的重点是类比,是类比,是类比!!
你瞧,咱们的 ω ω 小盆友,可以它如果玩起次方来可以玩遍所有数,那么g同学也一样,它玩起次方来,也可以玩遍所有数。

原根的定义

标准的定义就是
设m是正整数,a是整数,若a模m的阶等于 ϕ(m) ϕ ( m ) ,则称a为模m的一个原根。
可是这个神奇的原根是不一定存在的,所以题目在大多数条件下给的模数是素数。有一个神奇的定理是,素数的原根一定存在。
我们发现我们的 P1N P − 1 N 也不一定是整数。而我们知道,NTT中 N=2k N = 2 k 所以说,咱们的素数一般就是 k2n+1 k 2 n + 1 也就是费马素数。
这个时候NTT又变了一个名字,叫做费马数数论变换(什么?英文?英文是不存在的)
这个时候就可记几个常见的模数了
1004535809=479×221+1 1004535809 = 479 × 2 21 + 1 注意这个数的原根是3
998244353=119×223+1 998244353 = 119 × 2 23 + 1 这个数的原根仍然是3

寻根之旅

这个时候好奇的同学们就又会问了,如果题目不给你原根怎么办?为了考虑这些同学们的心情,就再扔一个定理来压压惊。
原根不会超过素数的 14 1 4 次方(证明?证明是不存在的)
由原根的定义知
0<i,j<P,ijgi/gj(modP) ∀ 0 < i , j < P , i ≠ j 有 g i ≢ g j ( mod P ) 并且 gp11(modP) g p − 1 ≡ 1 ( mod P )
转化一下可得当且仅当 i=P1gi1(modP) i = P − 1 时 , g i ≡ 1 ( mod P ) (如果存在两个同余的,指数相减一下就有一个同余1的了)
再化简一下,如果 iP1s.t.ai1(modP) ∃ i ≠ P − 1 s . t . a i ≡ 1 ( mod P )
i|(P1)s.t.ai1(modP) ∃ i | ( P − 1 ) s . t . a i ≡ 1 ( mod P )
证明的方法就是辗转相处一下得到 agcd(P1,i)1(modP) a g c d ( P − 1 , i ) ≡ 1 ( mod P )
所以先筛因数,然后暴搜即可,复杂度 O(P34) O ( P 3 4 )

代码

int getG(int n) {
    int top = 0;
    for(int i = 2;i < n - 1; ++i) if(!((n - 1) % i)) q[++top] = i;
    for(int i = 2, j; ; ++i) {
        for(j = 1;j <= top; ++j) if(pow(i, q[j], n) == 1) break;
        if(j == top + 1) return i;
    }
}

例题分析

很抱歉,这道并不是裸题。
这是神题。
没错这是一道dp题
设f[i][j]为第i数个乘积为j
那么有
F[i,j]=F[i1][k] F [ i , j ] = ∑ F [ i − 1 ] [ k ]
其中k满足 iSs.t.kij(modP) ∃ i ∈ S s . t . k ∗ i ≡ j ( mod P )
N太大了,于是我们用矩阵优化。
可是怎么转移?我们换一种写法
F[i,jkmodm]=F[i][j]C[k] F [ i , j ∗ k mod m ] = ∑ F [ i ] [ j ] C [ k ]
其中 C[k]=[kS] C [ k ] = [ k ∈ S ]
仍然很棘手。(所以是神题嘛)
注意到m是素数于是乎有一个玄学的变换
我们找到m的原根g,然后对m以内的所有数进行映射。
ind[i]=j表示 gji(modm) g j ≡ i ( mod m )
有什么用,我们突然发现,本来是以乘积形式存在的 jkmodm j ∗ k mod m Zm Z m 意义下突然通过原根变成了加和形式
jkmodm=ind[j]+ind[k]mod(m1) j ∗ k mod m = i n d [ j ] + i n d [ k ] mod ( m − 1 )
很神奇有木有!
这个时候重新观察式子
F[i,j+kmod(m1)]=F[i][j]C[k] F [ i , j + k mod ( m − 1 ) ] = ∑ F [ i ] [ j ] C [ k ]
这不是NTT么?
NTT变换加速卷积即可
所以说我们发现了原根的另一个用途,就是可以把 Zm Z m 意义下的乘积运算通过原根幂的带换变成乘积的形式。总复杂度 O(mlogmlogn) O ( m l o g m l o g n )
呼呼,终于搞完啦!

代码

/**************************************************************
    Problem: 3992
    User: 2014lvzelong
    Language: C++
    Result: Accepted
    Time:3256 ms
    Memory:1676 kb
****************************************************************/

#include
#include
#include
#include
#include
#include
using namespace std;
const double pi = acos(-1);
const int N = 16384, P = 1004535809, K = 13;
int R[N], a[N], b[N], g[K + 1], ng[K + 1], ind[N], q[N], inv[N + 1];
int read() {
    char ch = getchar(); int x = 0;
    while(ch < '0' || ch > '9') ch = getchar();
    for(;ch >= '0' && ch <= '9'; ch = getchar()) x = (x << 1) + (x << 3) - '0' + ch;
    return x;
}
int pow(int a, int k, int P) {
    int b = 1; 
    for(; k; a = (1LL * a * a) % P, k >>= 1) 
        if(k & 1) b = (1LL * b * a) % P;
    return b;
}
void NTT(int *F, int n, int f) {
    for(int i = 0;i < n; ++i) if(i < R[i]) swap(F[i], F[R[i]]);
    for(int d = 0;(1 << d) < n; ++d) {
        int wn = ~f ? g[d] : ng[d], m = 1 << d, m2 = m << 1; 
        for(int j = 0;j < n; j += m2) {
            for(int w = 1, l = 0;l < m; ++l , w = 1LL * w * wn % P) {
                int &x = F[j + l], &y = F[j + l + m], t = 1LL * w * y % P;
                y = x - t; if(y < 0) y += P;
                x = x + t; if(x >= P) x -= P;
            }
        }
    }
    if(!(~f)) for(int i = 0;i < n; ++i) F[i] = 1LL * F[i] * inv[n] % P;
}

void Mul(int *A, int *B, int n, int m) {
    for(int i = 0;i < n; ++i) A[i] = 1LL * A[i] * B[i] % P;
    NTT(A, n, -1);
    for(int i = m;i < n; ++i) A[i % m] = (A[i % m] + A[i]) % P, A[i] = 0;
}

void mpower(int n, int k, int m) {
    for(int i = 0;i < n; ++i) b[i] = a[i];
    for(--k; k; k >>= 1) {
        NTT(a, n, 1);
        if(k & 1) NTT(b, n, 1), Mul(b, a, n, m);
        Mul(a, a, n, m);
    }
}

int getG(int n) {
    int top = 0;
    for(int i = 2;i < n - 1; ++i) if(!((n - 1) % i)) q[++top] = i;
    for(int i = 2, j; ; ++i) {
        for(j = 1;j <= top; ++j) if(pow(i, q[j], n) == 1) break;
        if(j == top + 1) return i;
    }
}

int main() {
    int i, j, G, len, L;
    for(G = 3, g[K] = pow(G, (P - 1) / N, P), ng[K] = pow(g[K], P - 2, P), i = K - 1; ~i; --i) 
        g[i] = 1LL * g[i + 1] * g[i + 1] % P, ng[i] = 1LL * ng[i + 1] * ng[i + 1] % P;
    for(inv[1] = 1, i = 2; i <= N; ++i) inv[i] = 1LL * (P - inv[P % i]) * (P / i) % P;
    int n = read(), m = read(), x = read(), S = read();
    for(G = getG(m), i = 0, j = 1; i < m - 1; ++i, j = (j * G) % m) ind[j] = i; //r^i=j(mod m)
    while(S--) {
        i = read(); 
        if(i) a[ind[i]] = 1;
    }
    for(len = 1, L = 0, --m; len < m << 1; len <<= 1, ++L) ;
    for(int i = 0;i < len; ++i) R[i] = (R[i >> 1] >> 1) | ((i & 1) << L - 1);
    mpower(len, n, m);
    printf("%d\n", b[ind[x]]);
    return 0;
}

你可能感兴趣的:(数学相关-FFT与NTT)