Codeforces Edu 85 G. Substring Search(FFT字符串匹配)

题目连接

题意:

给你一个s和t串,然后用s去匹配t串中的每一个长度与s相等的子串。如果 s i = t j s_i=t_j si=tj或者 p i d x ( s i ) = i d x ( t j ) p_{idx(s_i)}=idx(t_j) pidx(si)=idx(tj)则可以匹配,输出每一个位置可以匹配的情况。

做法:

显然是不能用常规的字符串匹配去解决。这里如果你做过洛谷上的带有通配符的字符串的匹配,可能知道怎么做,就是用FFT 解决。
带通配符的字符串匹配
这里我们构造一个式子, ( s i − t j ) 2 ( p j − t j ) 2 (s_i-t_j)^2(p_j-t_j)^2 (sitj)2(pjtj)2表示i和j这两个位置匹配。
如果长度为m的匹配则 ∑ j = 0 m − 1 ( s j − t i + j ) 2 ( p j − t i + j ) 2 \sum_{j=0}^{m-1}(s_{j}-t_{i+j})^2(p_{j}-t_{i+j})^2 j=0m1(sjti+j)2(pjti+j)2有多少i的位置等于零,这个化简一下显然就是多项式相乘,用FFT就可以了。
我这里写得是NTT写得丑刚刚卡过去,如果WA126注意随机一下 idx,就可以了。

#include "bits/stdc++.h"

using namespace std;
#define VI vector
#define ll long long
#define SZ(x) ((int)x.size())
#define all(x) x.begin(),x.end()
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x; }
const int maxn = 1 << 19;
const ll mod = 998244353;
int Mod(int x) {
    if (x >= mod) x -= mod;
    return x;
}
ll quick(ll a, ll n) {
    ll ans = 1;
    while (n) {
        if (n & 1) ans = ans * a % mod;
        n >>= 1;
        a = a * a % mod;
    }
    return ans;
}
const ll g = 3;
int r[maxn], tot, lim, roots[33];
void ntt(int *a, int inv) {
    for (int i = 0; i < tot; i++) {
        if (i < r[i]) swap(a[i], a[r[i]]);
    }
    for (int l = 2, id = 1; l <= tot; l <<= 1, id++) {
        int tmp = roots[id];
        if (inv == -1) tmp = quick(tmp, mod - 2);
        int m = l / 2;
        for (int j = 0; j < tot; j += l) {
            int w = 1;
            for (int i = 0; i < m; i++) {
                int t = 1ll * a[j + i + m] * w % mod;
                a[j + i + m] = Mod(a[j + i] - t + mod);
                a[j + i] = Mod(a[j + i] + t);
                w = 1ll * w * tmp % mod;
            }
        }
    }
    if (inv == -1) {
        int t = quick(tot, mod - 2);
        for (int i = 0; i < tot; i++) {
            a[i] = 1LL * a[i] * t % mod;
        }
    }
}

void init(int n, int m) {
    tot = 1, lim = 0;
    while (tot < n + m) tot <<= 1, lim++;
    for (int i = 0; i < tot; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lim - 1));
    }
    for (int i = 1; i <= lim; i++) {
        int t = 1 << i;
        roots[i] = quick(g, (mod - 1) / t);
    }
}
int A[maxn], B[maxn], P[maxn];
vector<int> multiply(int *a, int *b, int n, int m) {
    for (int i = 0; i < m; i++) B[i] = b[i];
    for (int i = 0; i < n; i++) A[i] = a[i];
    ntt(A, 1);
    ntt(B, 1);
    for (int i = 0; i < tot; i++) P[i] = 1ll * A[i] * B[i] % mod;
    ntt(P, -1);
    vector<int> ans(tot, 0);
    for (int i = 0; i < tot; i++) {
        ans[i] = P[i];
        P[i] = A[i] = B[i] = 0;
    }
    return ans;
}
int n, p[30];
char s[maxn], t[maxn];
int c[5][maxn], d[5][maxn], res[maxn], val[maxn];

int main() {
//    freopen("1.in", "r", stdin);
//    double ss = clock();
    for (int i = 0; i < 26; i++) val[i] = rnd(100);
    for (int i = 0, x; i < 26; i++) {
        scanf("%d", &x);
        p[x - 1] = i;
    }
    scanf("%s%s", s, t);
    n = strlen(t);
    for (int i = 0; i < n; i++) {
        int id = t[i] - 'a';
        int x = val[id], y = val[p[id]];
        c[4][i] = 1ll;
        c[3][i] = (mod - 2 * x - 2 * y);
        c[2][i] = x * x + y * y + 4 * x * y;
        c[1][i] = mod - 2 * x * y * (x + y);
        c[0][i] = x * x * y * y;
    }
    int m = strlen(s);
    for (int i = 0; i < m; i++) {
        int id = val[s[i] - 'a'];
        d[4][i] = id * id * id * id;
        d[3][i] = id * id * id;
        d[2][i] = id * id;
        d[1][i] = id;
        d[0][i] = 1ll;
    }
    init(n, m);
    for (int i = 0; i < 5; i++) {
        reverse(d[i], d[i] + m);
        vector<int> pp = multiply(c[i], d[i], n, m);
        for (int j = m - 1; j < n; j++) res[j] = (res[j] + pp[j]) % mod;
    }
    for (int i = m - 1; i < n; i++) {
        if (res[i] == 0)printf("1");
        else printf("0");
    }
    puts("");
//    cout << clock() - ss << endl;
    return 0;
}

你可能感兴趣的:(FFT/NTT/FWT)