考虑多项式 A(x)=∑i=0naixi A ( x ) = ∑ i = 0 n a i x i ,其中 {a0,a1,…,an} { a 0 , a 1 , … , a n } 被称为多项式 A(x) A ( x ) 的系数向量。每个多项式都有唯一的系数向量,每个系数向量都对应唯一的多项式。
我们可以把多项式 A(x) A ( x ) 看做是一个 n n 次函数,我们可以取 n+1 n + 1 个不同的值 b0,b1,⋯,bn b 0 , b 1 , ⋯ , b n 带入分别求出 n+1 n + 1 个多项式的值 c0,c1,⋯,cn c 0 , c 1 , ⋯ , c n 。可以看出,从系数表示法到点值表示法是唯一的,而点值表示法在系数未知的时候可以看做是一个 n+1 n + 1 元一次方程组,可以解出唯一系数。因此点值表示与多项式也一一对应。
令 C(x)=A(x)B(x)=∑i=0n∑j=0maibjxi+j C ( x ) = A ( x ) B ( x ) = ∑ i = 0 n ∑ j = 0 m a i b j x i + j ,其中 A(x),B(x) A ( x ) , B ( x ) 分别是 n,m n , m 次多项式, A(x),B(x) A ( x ) , B ( x ) 的系数向量是 a→,b→ a → , b → 。
容易发现,用系数表示法使两个向量相乘是 O(n2) O ( n 2 ) 的复杂度,那如何才能优化呢?
考虑两个点值表示的多项式相乘,易发现此时只需要把两个多项式的对应点值相乘即可,复杂度为 O(n) O ( n ) 。但是如何把系数表示转化为点值表示,再转化回来呢?。
如果我们选取 n+1 n + 1 个值暴力代入,复杂度仍然为 O(n2) O ( n 2 ) ,甚至转化回来的时候会用到 O(n3) O ( n 3 ) 的高斯消元,难道点值表示就没有任何可取之处了吗?
因此,一种算法叫做“快速傅里叶变换”诞生了,它可以在 O(nlogn) O ( n l o g n ) 的时间内完成上述两部转化。
若 xn=1 x n = 1 ,则 x x 被称为 n n 次单位根。 n n 次单位根共有 n n 个,分别形如 e2kπin,0≤k<n,k∈Z e 2 k π i n , 0 ≤ k < n , k ∈ Z ,注意这里的 i i 是虚数单位。为什么呢?
在讨论性质时,均假定 n n 为偶数。
单位根具有对称性,即 ωkn=−ωk+n2n ω n k = − ω n k + n 2 。这个定理是比较好证明的,因为有
上面我们说了那么多,究竟是要干什么呢?没错!把单位根当做数值带入多项式,求出多项式的点值表示。但是到此为止,我们的复杂度还是 O(n2) O ( n 2 ) 的,甚至由于涉及到复数运算,常数只会比原来更大。于是我们要好好利用单位根的性质进行简化。接下来假设 n n 是2的整数次幂。
考虑关于单位根的 n−1 n − 1 次多项式 A(ωkn) A ( ω n k ) ,先暴力计算(注意这里的 i i 不是虚数啦):
我们已经可以在 O(nlogn) O ( n l o g n ) 的时间内把多项式的系数表示转化为点值表示,但是如何把点值表示转化为系数表示呢?
也就是我们需要解出一个 n n 元一次方程组,考虑把它化为矩阵形式:
于是我们得到了一个结论:两个矩阵的乘积除主对角线为 n n ,其它位置全部为0.这可以看做是 n n 倍的单位矩阵,也就是说,我们把 Q Q 矩阵和右边的点值向量相乘,就可以得到系数向量。但是这样的复杂度仍然是 O(n2) O ( n 2 ) 的。
注意到 ω−kn ω n − k 实际上仍然是 n n 次单位根!证明:
我们来观察一下最后一步时所有数字的顺序。考虑把所有二进制串反过来,比如1000变为0001,我们会发现最后一步时fft的顺序就是从0到 n−1 n − 1 !(最后两个似乎画反了……)也就是说,原串中第 i i 个数到fft的最后一步时就变成了第 rev(i) r e v ( i ) 个数,其中 rev r e v 函数表示翻转一个数的二进制表示。只要我们按照这个排好序,一步一步合并上去就行了!
于是我们从头到尾扫一遍数组,假设当前扫到第 i i 个数,只要 rev(i)>i r e v ( i ) > i ,我们就可以交换 rev(i)和i r e v ( i ) 和 i 的值,这样最后得到的数组就是fft最后一步的数组!然后就可以很方便地迭代实现fft了!
const int maxn = 1 << 18;
const long double PI = (long double)3.14159265358979323846;
struct Complex{
long double r, i;
Complex(){r = i = 0;}
Complex(long double a, long double b){r = a, i = b;}
Complex operator+(const Complex &c) const
{return Complex(r + c.r, i + c.i);}
Complex operator-(const Complex &c) const
{return Complex(r - c.r, i - c.i);}
Complex operator*(const Complex &c) const
{return Complex(r * c.r - i * c.i, i * c.r + r * c.i);}
} A[maxn];
void rader(Complex *a, int n){//倒位序
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);//j=rev(i)
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;//反向二进制加法
}
}
void fft(Complex *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1;
Complex wn = Complex(cosl(PI / hh), rev * sinl(PI / hh));
for(int i = 0; i < n; i += h){
Complex *ta = a + i, *tb = a + i + hh, w = Complex(1, 0);
for(int j = 0; j < hh; ++j, ++ta, ++tb){
Complex x = *ta, y = w * *tb;
*ta = x + y, *tb = x - y, w = w * wn;
}
}
}
if(rev == -1) for(int i = 0; i < n; i++)
a[i] = a[i] * Complex(1.0 / n, 0);
}
一个小优化:正常来说我们都是进行两边DFT,然后点值乘法,再IDFT,但实际上在用FFT算卷积的时候可以去掉一个DFT。比如计算a和b的卷积,我们把需要进行FFT的复数数组的实数部分设置为a,虚数部分设置为b,然后DFT一次,计算自己的平方卷积,再IDFT出来,虚数部分结果除以2就是原来的答案。
FFT已经可以在 O(nlogn) O ( n l o g n ) 的时间内完成多项式点值和系数表示之间的转换,但是在OI中,我们经常要求的是对于某个数求模的结果,这样FFT的精度显然不够了。
考虑在模运算下定义单位根。设模数为质数 p p ,那么它的原根 gp−1n g p − 1 n 实际上和 ωn ω n 等价。为什么呢?考虑单位根的几个性质:
1. n n 个单位根互不相等,根据原根定义,原根的0次幂到 p−1 p − 1 次幂都不相等,上面那 n n 个值自然不相等。
2.单位根的 n n 次幂等于1,这个根据费马小定理,任意与 p p 互质正整数的 p−1 p − 1 次幂都为1,因此对于上面的也成立。
3.单位根的对称性。证明:
#include
using namespace std;
typedef long long ll;
ll modmul(ll a, ll b, ll mod){
ll res = 0;
for(; b; b >>= 1){
if(b & 1) res = (res + a) % mod;
a = (a + a) % mod;
}
return res;
}
ll modpow(ll a, ll b, ll mod){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = modmul(res, a, mod) % mod;
a = modmul(a, a, mod) % mod;
}
return res;
}
vector vec;
int main(){
ll mod;
while(~scanf("%lld", &mod)){
vec.clear();
ll p = mod - 1;
for(ll i = 2; i * i <= p; i++) if(p % i == 0) {
vec.push_back(i);
while(p % i == 0) p /= i;
}
if(p > 1) vec.push_back(p);
int sz = vec.size();
for(int i = 2;; i++){
int flag = 1;
for(int j = 0; j < sz; j++)
if(modpow(i, (mod - 1) / vec[j], mod) == 1){flag = 0; break;}
if(flag == 1){printf("%d\n", i); break;}
}
}
return 0;
}
再附上NTT的板子~
typedef long long ll;
const int mod = 998244353, G = 3;
ll modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(ll *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void NTT(ll *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
此时我们就可以愉快的做题啦!
原题链接
题意:求如下函数的值:
#include
using namespace std;
typedef long long ll;
const int maxn = 100005, mod = 998244353, G = 3;
int modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(ll *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void NTT(ll *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
ll fact[maxn], revf[maxn], rev[maxn], A[1 << 18], B[1 << 18];
int main(){
int n; scanf("%d", &n);
rev[1] = fact[0] = revf[0] = 1;
for(int i = 1; i <= n; i++){
if(i > 1) rev[i] = mod - (ll)mod / i * rev[mod % i] % mod;
revf[i] = revf[i - 1] * rev[i] % mod;
fact[i] = fact[i - 1] * i % mod;
}
int tn = 1;
while(tn < 2 * n + 1) tn <<= 1;
for(int i = 0; i <= n; i++){
A[i] = i & 1 ? mod - revf[i] : revf[i];
if(i > 0) B[i] = (i > 1 ? (modpow(i, n + 1) - 1) * rev[i - 1] % mod : n + 1) * revf[i] % mod;
else B[i] = 1;
}
NTT(A, tn, 0), NTT(B, tn, 0);
for(int i = 0; i < tn; i++) A[i] = A[i] * B[i] % mod;
NTT(A, tn, 1);
ll res = 0;
for(int i = 1, j = 0; j <= n; j++){
res = (fact[j] * i % mod * A[j] + res) % mod;
i = i * 2 % mod;
}
printf("%lld\n", res);
return 0;
}
原题链接
像这种区别于i,j的卷积可以使用CDQ分治+NTT处理。先考虑把一整段剖成左右两块,分别递归处理,然后再使用NTT计算左边的x对右边的y的贡献。但是这道题右边x对左边y的贡献是减法卷积,我们可以把左边的多项式翻转,再求卷积,理论复杂度为 O(nlog2n) O ( n l o g 2 n ) ,但似乎常数超级大,而且明明在本机跑得比别人快2倍,在BZOJ上却T掉……
不管了,假装自己过了
还是放上我自己常数巨大的代码吧……
#include
using namespace std;
typedef long long ll;
const int maxn = 1 << 17;
const double PI = 3.1415926535898;
int A[maxn], B[maxn], rev[maxn], n, m, Q, T;
ll C[maxn];
struct Complex{
double r, i;
Complex(){r = i = 0.0;}
Complex(double a, double b){r = a, i = b;}
Complex operator+(const Complex &c) const {
return Complex(r + c.r, i + c.i);
}
Complex operator-(const Complex &c) const {
return Complex(r - c.r, i - c.i);
}
Complex operator*(const Complex &c) const {
return Complex(r * c.r - i * c.i, r * c.i + i * c.r);
}
} AA[maxn], BB[maxn], R[maxn];
void FFT(Complex *a, int n, int r){
for(int i = 0; i < n; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1;
Complex wn = Complex(cos(PI / hh), sin(PI / hh));
if(r) wn.i = -wn.i;
Complex *ta = a, *tb = a + hh;
for(int i = 0; i < n; i += h){
Complex w = Complex(1, 0);
for(int j = 0; j < hh; ++j, ++ta, ++tb){
Complex x = *ta, y = w * *tb;
*ta = x + y, *tb = x - y;
w = w * wn;
}
ta += hh, tb += hh;
}
}
if(r) for(int i = 0; i < n; i++) a[i].r = a[i].r / n;
}
void mul(Complex *a, Complex *b, int n){
if(n <= 32){
for(int i = 0; i < n; i++) R[i] = Complex(0, 0);
for(int i = 0; i < n >> 1; i++)
for(int j = 0; j < n >> 1; j++)
R[i + j] = R[i + j] + a[i] * b[j];
for(int i = 0; i < n; i++) a[i] = R[i];
} else {
FFT(a, n, 0), FFT(b, n, 0);
for(int i = 0; i < n; i++) a[i] = a[i] * b[i];
FFT(a, n, 1);
}
}
void cdq(int l, int r){
if(l == r - 1){C[0] += A[l] * B[l]; return;}
int mid = (l + r) >> 1, hlen = r - l, len = hlen << 1;
int t = __builtin_ctz(hlen);
for(int i = 0; i < len; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << t;
for(int i = 0; i < len; i++) AA[i] = BB[i] = Complex(0, 0);
for(int i = l; i < mid; i++) AA[i - l].r = A[i];
for(int i = mid; i < r; i++) BB[i - mid].r = B[i];
mul(AA, BB, len);
for(int i = 0; i < len; i++) C[i + l + mid] += (int)(AA[i].r + 0.1);
for(int i = 0; i < len; i++) AA[i] = BB[i] = Complex(0, 0);
for(int i = mid; i < r; i++) AA[i - mid].r = A[i];
for(int i = l; i < mid; i++) BB[i - l].r = B[mid - i + l - 1];
mul(AA, BB, len);
for(int i = 0; i < len; i++) C[i + 1] += (int)(AA[i].r + 0.1);
cdq(l, mid), cdq(mid, r);
}
const int maxr = 10000000;
char str[maxr], prt[maxr]; int rpos, ppos;
char readc(){
if(!rpos) fread(str, 1, maxr, stdin);
char c = str[rpos++];
if(rpos == maxr) rpos = 0;
return c;
}
int read(){
int x; char c;
while((c = readc()) < '0' || c > '9');
x = c - '0';
while((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
return x;
}
void print(ll x){
if(x){
static char sta[20];
int tp = 0;
for(; x; x /= 10) sta[tp++] = x % 10 + '0';
while(tp > 0) prt[ppos++] = sta[--tp];
} else prt[ppos++] = '0';
prt[ppos++] = '\n';
}
int main(){
for(T = read(); T--;){
n = read(), m = read(), Q = read();
int mx = 0, N = 1;
memset(A, 0, sizeof(A));
memset(B, 0, sizeof(B));
memset(C, 0, sizeof(C));
for(int i = 0; i < n; i++){
int t = read();
mx = max(mx, t);
++A[t];
}
for(int i = 0; i < m; i++){
int t = read();
mx = max(mx, t);
++B[t];
}
while(N <= mx) N <<= 1;
cdq(0, N);
while(Q--) print(C[read()]);
}
fwrite(prt, 1, ppos, stdout);
return 0;
}
原题链接
这道题似乎比较水啊,考虑如何算出确定两个装饰的亮度时不同旋转位置的差异值。我们把 ∑ni=1(xi−yi)2 ∑ i = 1 n ( x i − y i ) 2 变成 ∑ni=1xi+∑ni=1yi−2∑ni=1xiyi ∑ i = 1 n x i + ∑ i = 1 n y i − 2 ∑ i = 1 n x i y i ,会发现前面两个是常数,后面一个是乘法。考虑把乘法化成卷积的形式,把一个手环上所有装饰的信息复制一遍接在后面(对没错就跟普通处理环的方法一样),然后翻转另一个手环的亮度信息,这样做一个卷积,就可以得到旋转不同的角度得到的差异值了。
再考虑如何计算最小差异值。如果我们确定一个亮度时算出了不同位置的差异值,并且找出了最小值所在的位置,那么无论我把一个手环的亮度整体加上多少,卷积后取最小值的位置必然不会改变,因为整体加 k k 可以看做是所有角度上的值都加上了 k∑ni=1yi k ∑ i = 1 n y i ,然后前面两个常数的变化其实可以暴力算,复杂度 O(nlogn+nm) O ( n l o g n + n m ) 就解决了。
其实再深入一点,差异值关于亮度整体的变化值是一个凹函数,因此可以三分优化到 O(n(logn+logm)) O ( n ( l o g n + l o g m ) ) 。
#include
using namespace std;
typedef long long ll;
const int maxn = 1 << 18;
const double PI = 3.14159265358979323846;
int A[maxn], B[maxn], n, m;
struct Complex{
double r, i;
Complex(){r = i = 0.0;}
Complex(double a, double b){r = a, i = b;}
Complex operator+(const Complex &c) const {
return Complex(r + c.r, i + c.i);
}
Complex operator-(const Complex &c) const {
return Complex(r - c.r, i - c.i);
}
Complex operator*(const Complex &c) const {
return Complex(r * c.r - i * c.i, i * c.r + r * c.i);
}
} AA[maxn];
void rader(Complex *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void fft(Complex *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1;
Complex wn = Complex(cos(PI / hh), rev * sin(PI / hh));
for(int i = 0; i < n; i += h){
Complex w = Complex(1, 0);
for(int j = i; j < i + hh; j++){
Complex x = a[j], y = w * a[j + hh];
a[j] = x + y, a[j + hh] = x - y;
w = w * wn;
}
}
}
if(rev == -1) for(int i = 0; i < n; i++)
a[i] = a[i] * Complex(1.0 / n, 0);
}
int calc(int i, int mn, int ss){
int sum = 0;
for(int j = 0; j < n; j++) sum += (A[j] + i) * (A[j] + i);
return sum - 2 * (mn + i * ss);
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 0; i < n; i++){
scanf("%d", A + i);
A[i] = A[i + n] = A[i];
AA[i].r = AA[i + n].r = A[i];
}
for(int i = 0; i < n; i++){
scanf("%d", B + i);
AA[n - i - 1].i = B[i];
}
int len = 1;
while(len < 3 * n) len <<= 1;
fft(AA, len, 1);
for(int i = 0; i < len; i++) AA[i] = AA[i] * AA[i];
fft(AA, len, -1);
int mn = INT_MIN, res = INT_MAX, ini = 0, ss = 0;
for(int i = n - 1; i < 2 * n - 1; i++){
mn = max(mn, int(AA[i].i / 2 + 0.5));
//printf("%d %d\n", i, int(AA[i].i / 2 - 0.5));
}
for(int i = 0; i < n; i++) ini += B[i] * B[i], ss += B[i];
int l = -m, r = m;
while(l + 3 <= r){
int len = (r - l + 1) / 3;
int m1 = l + len, m2 = r - len;
if(calc(m1, mn, ss) < calc(m2, mn, ss)) r = m2;
else l = m1;
}
for(int i = l; i <= r; i++)
res = min(res, ini + calc(i, mn, ss));
printf("%d\n", res);
return 0;
}
上述代码便使用了以前说的FFT计算时的小优化,两次FFT即可计算出卷积。
原题链接
其实这道题并不难,主要是推公式的时候一定要细心细心再细心!!
首先我们考虑枚举有多少种颜色恰好出现了s次,再令 g(i,j) g ( i , j ) 表示 i i 种颜色填入 j j 个格子且没有颜色出现s次的方案数。然后对于原题,我们会发现最多只会有 ⌊ns⌋ ⌊ n s ⌋ 种颜色出现恰好s次,于是令 N=min(⌊ns⌋,m) N = m i n ( ⌊ n s ⌋ , m ) ,可以得到:
#include
using namespace std;
typedef long long ll;
const int maxn = 100005, maxt = 1 << 18, mod = 1004535809, G = 3;
const int maxmn = 10000005;
ll fact[maxmn], rev[maxmn], A[maxt], B[maxt], val[maxn], n, m, N, S;
ll modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(ll *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void NTT(ll *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
int main(){
scanf("%d%d%d", &n, &m, &S);
for(int i = 0; i <= m; i++) scanf("%lld", val + i);
N = min(n / S, m);
fact[0] = rev[0] = 1;
int mn = max(n, m);
for(int i = 1; i <= mn; i++)
fact[i] = fact[i - 1] * i % mod;
rev[mn] = modpow(fact[mn], mod - 2);
for(int i = mn - 1; i > 0; i--)
rev[i] = rev[i + 1] * (i + 1) % mod;
for(int i = 0; i <= N; i++){
A[i] = val[i] * rev[i] % mod;
B[i] = i & 1 ? mod - rev[i] : rev[i];
}
int tn = 1;
while(tn < 2 * (N + 1)) tn <<= 1;
NTT(A, tn, 0), NTT(B, tn, 0);
for(int i = 0; i < tn; i++) A[i] = A[i] * B[i] % mod;
NTT(A, tn, 1);
ll sf = 1, res = 0;
for(int i = 1; i <= S; i++) sf = sf * i % mod;
for(int i = 0; i <= N; i++){
ll t = fact[m] * fact[n] % mod * modpow(m - i, n - i * S) % mod;
t = t * rev[m - i] % mod * rev[n - i * S] % mod * modpow(sf, mod - 1 - i) % mod;
res = (res + t * A[i]) % mod;
}
printf("%lld\n", res);
return 0;
}
原题链接
观察到n很大,如果用矩阵乘法的话m又过大了,考虑用倍增(快速幂)解决。
考虑dp。设 f[i][j] f [ i ] [ j ] 表示当前已经选了i个数,得到的乘积模m为j的选取方案数。如果我们能够快速合并 f[a],f[b] f [ a ] , f [ b ] 得到 f[a+b] f [ a + b ] 的值,这道题就解决了。由于
#include
using namespace std;
typedef long long ll;
const int maxn = 1 << 14, mod = 1004535809, G = 3;
ll modpow(ll a, ll b, ll p = mod){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
void rader(ll *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void ntt(ll *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, (mod - 1 + rev * (mod - 1) / h) % (mod - 1));
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
ll x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev == -1){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
int n, m, x, S, g, id[maxn], num[maxn];
ll A[maxn], B[maxn], C[maxn];
void init(){
int p = m - 1;
vector<int> fac;
for(int i = 2; i * i <= p; i++) if(p % i == 0){
while(p % i == 0) p /= i;
fac.push_back(i);
}
if(p > 1) fac.push_back(p);
int s = fac.size();
for(g = 2;; g++){
int flag = 1;
for(int i = 0; i < s; i++)
if(modpow(g, (m - 1) / fac[i], m) == 1){flag = 0; break;}
if(flag) break;
}
for(int i = 0, pw = 1; i < m - 1; i++, pw = pw * g % m)
num[id[pw] = i] = pw;
}
int main(){
scanf("%d%d%d%d", &n, &m, &x, &S);
init();
for(int i = 0; i < S; i++){
int t; scanf("%d", &t);
if(t % m > 0) ++A[id[t % m]];
}
int len = 1;
while(len < 2 * m - 1) len <<= 1;
for(int flag = 1; n; n >>= 1){
if(n & 1){
if(!flag){
memcpy(C, A, sizeof(C));
ntt(C, len, 1), ntt(B, len, 1);
for(int i = 0; i < len; i++) B[i] = B[i] * C[i] % mod;
ntt(B, len, -1);
for(int i = m - 1; i < len; i++)
(B[i % (m - 1)] += B[i]) %= mod, B[i] = 0;
} else memcpy(B, A, sizeof(B)), flag = 0;
}
ntt(A, len, 1);
for(int i = 0; i < len; i++) A[i] = A[i] * A[i] % mod;
ntt(A, len, -1);
for(int i = m - 1; i < len; i++)
(A[i % (m - 1)] += A[i]) %= mod, A[i] = 0;
}
printf("%lld\n", B[id[x]]);
return 0;
}
原题链接
考虑定义字符匹配函数 match(x,y)=(x−y)2 m a t c h ( x , y ) = ( x − y ) 2 ,这样只有当两个字符相等时函数值才为0。于是可以考虑类似的定义字符串匹配函数:
#include
using namespace std;
typedef long long ll;
const int maxn = 1 << 20, mod = 998244353, G = 3;
int id[128], A[maxn], B[maxn], ans[maxn], n, m, ppos;
char sa[maxn], sb[maxn], prt[10000000];
void print(int x, char c){
if(x){
static char sta[10];
int tp = 0;
for(; x; x /= 10) sta[tp++] = '0' + x % 10;
while(tp > 0) prt[ppos++] = sta[--tp];
} else prt[ppos++] = '0';
prt[ppos++] = c;
}
inline int modpow(int a, int b){
int res = 1;
for(; b; b >>= 1){
if(b & 1) res = (ll)res * a % mod;
a = (ll)a * a % mod;
}
return res;
}
inline void rader(int *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void ntt(int *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod, a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = (ll)a[i] * inv % mod;
}
}
int main(){
for(char i = 'a'; i <= 'z'; i++) id[i] = i - 'a' + 1;
scanf("%d%d%s%s", &n, &m, sa, sb);
int len = 1;
while(len < n + m - 1) len <<= 1;
for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]] * id[sa[i]] * id[sa[i]];
for(int i = 0; i < m; i++) B[i] = id[sb[i]];
ntt(A, len, 0), ntt(B, len, 0);
for(int i = 0; i < len; i++) ans[i] = (ll)A[i] * B[i] % mod;
memset(A, 0, sizeof(A)), memset(B, 0, sizeof(B));
for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]];
for(int i = 0; i < m; i++) B[i] = id[sb[i]] * id[sb[i]] * id[sb[i]];
ntt(A, len, 0), ntt(B, len, 0);
for(int i = 0; i < len; i++) ans[i] = ((ll)A[i] * B[i] + ans[i]) % mod;
memset(A, 0, sizeof(A)), memset(B, 0, sizeof(B));
for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]] * id[sa[i]];
for(int i = 0; i < m; i++) B[i] = id[sb[i]] * id[sb[i]];
ntt(A, len, 0), ntt(B, len, 0);
for(int i = 0; i < len; i++) ans[i] = (ans[i] - 2LL * A[i] * B[i] % mod + mod) % mod;
ntt(ans, len, 1);
int tp = 0;
for(int i = n - 1; i < m; i++) if(!ans[i]) A[tp++] = i - n + 2;
print(tp, '\n');
for(int i = 0; i < tp; i++) print(A[i], ' ');
fwrite(prt, 1, ppos, stdout);
return 0;
}