k进制快速沃尔什变换(k进制FWT)

引入

我们已经知道了多项式版的二进制 FWT,但是我们似乎不是很好将其扩展到 k k k 进制下,于是我们来考虑另一种方式的 FWT。

原理

二进制

规定三个多项式的长度为 2 n 2^n 2n,如果不够则往后面补 0 0 0

还是先考虑二进制,我们假设存在矩阵使得 T A ∗ T B = T C TA * TB = TC TATB=TC,考虑矩阵长什么样。

我们可以得到如下式子

F W T ( A ) x = ∑ i = 0 2 n − 1 w x , i ∗ A i FWT(A)_{x}=\sum_{i = 0}^{2^n-1}w_{x,i}*A_i FWT(A)x=i=02n1wx,iAi

因为 F W T ( A ) ∗ F W T ( B ) = F W T ( C ) FWT(A)*FWT(B)=FWT(C) FWT(A)FWT(B)=FWT(C),所以 F W T ( A ) x ∗ F W T ( B ) x = F W T ( C ) x FWT(A)_x*FWT(B)_x=FWT(C)_x FWT(A)xFWT(B)x=FWT(C)x,可写为

∑ i = 0 2 n − 1 w x , i ∗ A i ∗ ∑ j = 0 2 n − 1 w x , j ∗ B j = ∑ k = 0 2 n − 1 w x , k ∗ C k \sum_{i = 0}^{2^n-1}w_{x,i}*A_i * \sum_{j= 0}^{2^n-1}w_{x,j}*B_j=\sum_{k= 0}^{2^n-1}w_{x,k}*C_k i=02n1wx,iAij=02n1wx,jBj=k=02n1wx,kCk

∑ i = 0 2 n − 1 ∑ j = 0 2 n − 1 w x , i ∗ w x , j ∗ A i ∗ B j = ∑ k = 0 2 n − 1 w x , k ∗ C k \sum_{i=0}^{2^n-1}\sum_{j= 0}^{2^n-1}w_{x,i}* w_{x,j}*A_i *B_j=\sum_{k= 0}^{2^n-1}w_{x,k}*C_k i=02n1j=02n1wx,iwx,jAiBj=k=02n1wx,kCk

因为我们要使得

∑ i ⊕ j = k A i ∗ B j = C k \sum_{i\oplus j=k} A_i*B_j=C_k ij=kAiBj=Ck

即要使得以下式子成立

∑ i ⊕ j = k w x , i ∗ w x , j ∗ A i ∗ B j = w x , k ∗ C k \sum_{i\oplus j = k}w_{x,i}* w_{x,j}*A_i *B_j=w_{x,k}*C_k ij=kwx,iwx,jAiBj=wx,kCk

所以我们可以得出

w x , i ∗ w x , j = w x , k      ( i ⊕ j = k ) w_{x,i}*w_{x,j} = w_{x,k}\ \ \ \ (i\oplus j = k) wx,iwx,j=wx,k    (ij=k)

此时的矩阵就是范德蒙德矩阵。(就是我们再莫比乌斯反演里面学的哪个单位根矩阵)

我们考虑如何快速求出 F W T ( A ) x FWT(A)_x FWT(A)x

我们定义 A 0 A_0 A0 表示前半段, A 1 A_1 A1 表示后半段, x 1 x_1 x1 表示 x x x 的最高位(如果最高位为 1 1 1 x 1 x_1 x1 1 1 1,否则为 0 0 0), x 0 x_0 x0 表示 x x x 除了最高位其他的位, i 0 i_0 i0 也表示 i i i 除了最高位其它的位。

则我们可以推出

F W T ( A ) x = ∑ i = 0 2 n − 1 w x , i ∗ A i = ∑ i = 0 2 n − 1 − 1 w x , i ∗ A i + ∑ i = 2 n 2 n − 1 w x , i ∗ A i = ∑ i = 0 2 n − 1 − 1 w x 1 , 0 ∗ w x 0 , i 0 ∗ A i + ∑ i = 2 n 2 n − 1 w x 1 , 1 ∗ w x 0 , i 0 ∗ A i = w x 1 , 0 ∗ F W T ( A 0 ) x 0 + w x 1 , 1 ∗ F W T ( A 1 ) x 0 \begin{aligned} FWT(A)_{x}&=\sum_{i = 0}^{2^n-1}w_{x,i}*A_i\\ &=\sum_{i = 0}^{2^{n - 1} - 1}w_{x,i}*A_i+\sum_{i = 2^n}^{2^n - 1}w_{x,i}*A_i\\ &=\sum_{i = 0}^{2^{n - 1} - 1}w_{x_1,0}*w_{x_0,i_0}*A_i+\sum_{i = 2^n}^{2^n - 1}w_{x_1,1}*w_{x_0,i_0}*A_i\\ &=w_{x_1,0}*FWT(A_0)_{x_0}+w_{x_1,1}*FWT(A_1)_{x_0} \end{aligned} FWT(A)x=i=02n1wx,iAi=i=02n11wx,iAi+i=2n2n1wx,iAi=i=02n11wx1,0wx0,i0Ai+i=2n2n1wx1,1wx0,i0Ai=wx1,0FWT(A0)x0+wx1,1FWT(A1)x0

于是就跟多项式下的 FWT 一样可以递归求了。

k进制

于是我们进入正题,来看看 k k k 进制下的 FWT。(其实跟二进制下很像,只不过有一些细节上的区别)

同样规定多项式长度为 k n k^n kn

还是要构造矩阵使得 T A ∗ T B = T C TA * TB = TC TATB=TC

同样有

F W T ( A ) x = ∑ i = 0 k n − 1 w x , i ∗ A i = ∑ i = 0 k − 1 w x 1 , i ∗ F W T ( A i ) x 0 \begin{aligned} FWT(A)_{x}&=\sum_{i = 0}^{k^n-1}w_{x,i}*A_i\\ &=\sum_{i = 0}^{k - 1}w_{x_1,i}*FWT(A_i)_{x_0} \end{aligned} FWT(A)x=i=0kn1wx,iAi=i=0k1wx1,iFWT(Ai)x0

上面的定义和二进制类似,证明也类似,就直接写出来算了。(绝对不是因为懒

此时一样递归求就可以了。

矩阵就是下面的矩阵,其中 w k w_k wk 代表 k k k 下的单位根。(跟 FFT 里的矩阵一样)

[   1 1 1 ⋯ 1     1 ω n 1 ω n 2 ⋯ ω n n − 1     1 ω n 2 ω n 4 ⋯ ω n 2 ( n − 1 )     ⋮ ⋮ ⋮ ⋱ ⋮     1 ω n n − 1 ω n 2 ( n − 1 ) ⋯ ω n ( n − 1 ) ( n − 1 )   ] \left[ \begin{matrix} \ 1 & 1 & 1 & \cdots & 1 \ \\ \ 1 & \omega_n^1 & \omega_n^2 & \cdots & \omega_n^{n-1} \ \\ \ 1 & \omega_n^2 & \omega_n^4 & \cdots & \omega_n^{2(n-1)} \ \\ \ \vdots & \vdots & \vdots & \ddots & \vdots \ \\ \ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \cdots & \omega_n^{(n-1)(n-1)} \ \end{matrix} \right]  1 1 1  11ωn1ωn2ωnn11ωn2ωn4ωn2(n1)1 ωnn1 ωn2(n1)  ωn(n1)(n1) 

可以看看代码加深理解!!!

代码

我这份代码是这道题的,三进制下的。

#include 
using namespace std;
typedef long long LL;
int n, k, len = 1;
LL ans;
complex<double> a[1000005];
const complex<double> w = { -0.5, 0.5 * sqrt(3) }, w2 = { -0.5, -0.5 * sqrt(3) };
int in() {
    char ch = getchar();
    int s = 0;
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch <= '9' && ch >= '0') s = s * 3 + ch - '1', ch = getchar();
    return s;
}
void FWT(complex<double> *f, int flag) {
    for (int mid = 1; mid < len; mid = mid * 3) {
        for (int i = 0; i < len; i = i + mid * 3) {
            for (int j = i; j < i + mid; j++) {
                complex<double> t0 = f[j], t1 = f[j + mid], t2 = f[j + mid * 2];
                if (flag == 1) {
                    f[j] = t0 + t1 + t2;
                    f[j + mid] = t0 + t1 * w + t2 * w2;
                    f[j + mid * 2] = t0 + t1 * w2 + t2 * w;
                } else {
                    f[j] = t0 + t1 + t2;
                    f[j + mid] = t0 + t1 * w2 + t2 * w;
                    f[j + mid * 2] = t0 + t1 * w + t2 * w2;
                    double t = f[j].real();
                    f[j].real(t / 3);

                    t = f[j + mid].real();
                    f[j + mid].real(t / 3);

                    t = f[j + mid * 2].real();
                    f[j + mid * 2].real(t / 3);

                    t = f[j].imag();
                    f[j].imag(t / 3);

                    t = f[j + mid].imag();
                    f[j + mid].imag(t / 3);

                    t = f[j + mid * 2].imag();
                    f[j + mid * 2].imag(t / 3);
                }
            }
        }
    }
}
int main() {
    scanf("%d%d", &n, &k);
    for (int t = 0; t < k; t++) len = len * 3;
    for (int i = 0; i < n; i++) a[in()].real(1);
    FWT(a, 1);
    for (int i = 0; i < len; i++) a[i] = a[i] * a[i] * a[i];
    FWT(a, -1);
    ans = a[0].real() + 0.5;
    printf("%lld\n", (ans - n) / 6);
    return 0;
}

你可能感兴趣的:(数论,算法,c++)