我们已经知道了多项式版的二进制 FWT,但是我们似乎不是很好将其扩展到 k k k 进制下,于是我们来考虑另一种方式的 FWT。
规定三个多项式的长度为 2 n 2^n 2n,如果不够则往后面补 0 0 0。
还是先考虑二进制,我们假设存在矩阵使得 T A ∗ T B = T C TA * TB = TC TA∗TB=TC,考虑矩阵长什么样。
我们可以得到如下式子
F W T ( A ) x = ∑ i = 0 2 n − 1 w x , i ∗ A i FWT(A)_{x}=\sum_{i = 0}^{2^n-1}w_{x,i}*A_i FWT(A)x=i=0∑2n−1wx,i∗Ai
因为 F W T ( A ) ∗ F W T ( B ) = F W T ( C ) FWT(A)*FWT(B)=FWT(C) FWT(A)∗FWT(B)=FWT(C),所以 F W T ( A ) x ∗ F W T ( B ) x = F W T ( C ) x FWT(A)_x*FWT(B)_x=FWT(C)_x FWT(A)x∗FWT(B)x=FWT(C)x,可写为
∑ i = 0 2 n − 1 w x , i ∗ A i ∗ ∑ j = 0 2 n − 1 w x , j ∗ B j = ∑ k = 0 2 n − 1 w x , k ∗ C k \sum_{i = 0}^{2^n-1}w_{x,i}*A_i * \sum_{j= 0}^{2^n-1}w_{x,j}*B_j=\sum_{k= 0}^{2^n-1}w_{x,k}*C_k i=0∑2n−1wx,i∗Ai∗j=0∑2n−1wx,j∗Bj=k=0∑2n−1wx,k∗Ck
即
∑ i = 0 2 n − 1 ∑ j = 0 2 n − 1 w x , i ∗ w x , j ∗ A i ∗ B j = ∑ k = 0 2 n − 1 w x , k ∗ C k \sum_{i=0}^{2^n-1}\sum_{j= 0}^{2^n-1}w_{x,i}* w_{x,j}*A_i *B_j=\sum_{k= 0}^{2^n-1}w_{x,k}*C_k i=0∑2n−1j=0∑2n−1wx,i∗wx,j∗Ai∗Bj=k=0∑2n−1wx,k∗Ck
因为我们要使得
∑ i ⊕ j = k A i ∗ B j = C k \sum_{i\oplus j=k} A_i*B_j=C_k i⊕j=k∑Ai∗Bj=Ck
即要使得以下式子成立
∑ i ⊕ j = k w x , i ∗ w x , j ∗ A i ∗ B j = w x , k ∗ C k \sum_{i\oplus j = k}w_{x,i}* w_{x,j}*A_i *B_j=w_{x,k}*C_k i⊕j=k∑wx,i∗wx,j∗Ai∗Bj=wx,k∗Ck
所以我们可以得出
w x , i ∗ w x , j = w x , k ( i ⊕ j = k ) w_{x,i}*w_{x,j} = w_{x,k}\ \ \ \ (i\oplus j = k) wx,i∗wx,j=wx,k (i⊕j=k)
此时的矩阵就是范德蒙德矩阵。(就是我们再莫比乌斯反演里面学的哪个单位根矩阵)
我们考虑如何快速求出 F W T ( A ) x FWT(A)_x FWT(A)x。
我们定义 A 0 A_0 A0 表示前半段, A 1 A_1 A1 表示后半段, x 1 x_1 x1 表示 x x x 的最高位(如果最高位为 1 1 1 则 x 1 x_1 x1 为 1 1 1,否则为 0 0 0), x 0 x_0 x0 表示 x x x 除了最高位其他的位, i 0 i_0 i0 也表示 i i i 除了最高位其它的位。
则我们可以推出
F W T ( A ) x = ∑ i = 0 2 n − 1 w x , i ∗ A i = ∑ i = 0 2 n − 1 − 1 w x , i ∗ A i + ∑ i = 2 n 2 n − 1 w x , i ∗ A i = ∑ i = 0 2 n − 1 − 1 w x 1 , 0 ∗ w x 0 , i 0 ∗ A i + ∑ i = 2 n 2 n − 1 w x 1 , 1 ∗ w x 0 , i 0 ∗ A i = w x 1 , 0 ∗ F W T ( A 0 ) x 0 + w x 1 , 1 ∗ F W T ( A 1 ) x 0 \begin{aligned} FWT(A)_{x}&=\sum_{i = 0}^{2^n-1}w_{x,i}*A_i\\ &=\sum_{i = 0}^{2^{n - 1} - 1}w_{x,i}*A_i+\sum_{i = 2^n}^{2^n - 1}w_{x,i}*A_i\\ &=\sum_{i = 0}^{2^{n - 1} - 1}w_{x_1,0}*w_{x_0,i_0}*A_i+\sum_{i = 2^n}^{2^n - 1}w_{x_1,1}*w_{x_0,i_0}*A_i\\ &=w_{x_1,0}*FWT(A_0)_{x_0}+w_{x_1,1}*FWT(A_1)_{x_0} \end{aligned} FWT(A)x=i=0∑2n−1wx,i∗Ai=i=0∑2n−1−1wx,i∗Ai+i=2n∑2n−1wx,i∗Ai=i=0∑2n−1−1wx1,0∗wx0,i0∗Ai+i=2n∑2n−1wx1,1∗wx0,i0∗Ai=wx1,0∗FWT(A0)x0+wx1,1∗FWT(A1)x0
于是就跟多项式下的 FWT 一样可以递归求了。
于是我们进入正题,来看看 k k k 进制下的 FWT。(其实跟二进制下很像,只不过有一些细节上的区别)
同样规定多项式长度为 k n k^n kn。
还是要构造矩阵使得 T A ∗ T B = T C TA * TB = TC TA∗TB=TC。
同样有
F W T ( A ) x = ∑ i = 0 k n − 1 w x , i ∗ A i = ∑ i = 0 k − 1 w x 1 , i ∗ F W T ( A i ) x 0 \begin{aligned} FWT(A)_{x}&=\sum_{i = 0}^{k^n-1}w_{x,i}*A_i\\ &=\sum_{i = 0}^{k - 1}w_{x_1,i}*FWT(A_i)_{x_0} \end{aligned} FWT(A)x=i=0∑kn−1wx,i∗Ai=i=0∑k−1wx1,i∗FWT(Ai)x0
上面的定义和二进制类似,证明也类似,就直接写出来算了。(绝对不是因为懒)
此时一样递归求就可以了。
矩阵就是下面的矩阵,其中 w k w_k wk 代表 k k k 下的单位根。(跟 FFT 里的矩阵一样)
[ 1 1 1 ⋯ 1 1 ω n 1 ω n 2 ⋯ ω n n − 1 1 ω n 2 ω n 4 ⋯ ω n 2 ( n − 1 ) ⋮ ⋮ ⋮ ⋱ ⋮ 1 ω n n − 1 ω n 2 ( n − 1 ) ⋯ ω n ( n − 1 ) ( n − 1 ) ] \left[ \begin{matrix} \ 1 & 1 & 1 & \cdots & 1 \ \\ \ 1 & \omega_n^1 & \omega_n^2 & \cdots & \omega_n^{n-1} \ \\ \ 1 & \omega_n^2 & \omega_n^4 & \cdots & \omega_n^{2(n-1)} \ \\ \ \vdots & \vdots & \vdots & \ddots & \vdots \ \\ \ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \cdots & \omega_n^{(n-1)(n-1)} \ \end{matrix} \right] 1 1 1 ⋮ 11ωn1ωn2⋮ωnn−11ωn2ωn4⋮ωn2(n−1)⋯⋯⋯⋱⋯1 ωnn−1 ωn2(n−1) ⋮ ωn(n−1)(n−1)
可以看看代码加深理解!!!
我这份代码是这道题的,三进制下的。
#include
using namespace std;
typedef long long LL;
int n, k, len = 1;
LL ans;
complex<double> a[1000005];
const complex<double> w = { -0.5, 0.5 * sqrt(3) }, w2 = { -0.5, -0.5 * sqrt(3) };
int in() {
char ch = getchar();
int s = 0;
while (ch < '0' || ch > '9') ch = getchar();
while (ch <= '9' && ch >= '0') s = s * 3 + ch - '1', ch = getchar();
return s;
}
void FWT(complex<double> *f, int flag) {
for (int mid = 1; mid < len; mid = mid * 3) {
for (int i = 0; i < len; i = i + mid * 3) {
for (int j = i; j < i + mid; j++) {
complex<double> t0 = f[j], t1 = f[j + mid], t2 = f[j + mid * 2];
if (flag == 1) {
f[j] = t0 + t1 + t2;
f[j + mid] = t0 + t1 * w + t2 * w2;
f[j + mid * 2] = t0 + t1 * w2 + t2 * w;
} else {
f[j] = t0 + t1 + t2;
f[j + mid] = t0 + t1 * w2 + t2 * w;
f[j + mid * 2] = t0 + t1 * w + t2 * w2;
double t = f[j].real();
f[j].real(t / 3);
t = f[j + mid].real();
f[j + mid].real(t / 3);
t = f[j + mid * 2].real();
f[j + mid * 2].real(t / 3);
t = f[j].imag();
f[j].imag(t / 3);
t = f[j + mid].imag();
f[j + mid].imag(t / 3);
t = f[j + mid * 2].imag();
f[j + mid * 2].imag(t / 3);
}
}
}
}
}
int main() {
scanf("%d%d", &n, &k);
for (int t = 0; t < k; t++) len = len * 3;
for (int i = 0; i < n; i++) a[in()].real(1);
FWT(a, 1);
for (int i = 0; i < len; i++) a[i] = a[i] * a[i] * a[i];
FWT(a, -1);
ans = a[0].real() + 0.5;
printf("%lld\n", (ans - n) / 6);
return 0;
}