Codeforces 452E Three strings 后缀数组 + 并查集

题目大意:

给出三个总长度不超过 3*10^5 的字符串 s1, s2, s3, 每个字符串只包含字母 'a' ~ 'z'. 对于每一个 L (1 <= L <= min(|s1|, |s2|, |s3|), 即 L 不超过三个串中最短的长度), 求出有多少个三元组 (i, j, k) 使得 s1[i ~ i + L - 1] == s2[j ~ j + L - 1] == s3[k ~ k + L - 1], 结果对 10^9 + 7 取模后输出


大致思路:

首先不难想到用后缀数组处理三个串拼接起来的总串, 记录每一个字符的来源, 也就是记录每个后缀来自哪个串, 然后按照 height 数组从大到小的顺序, 用并查集合并相邻的区间并进行计算. 注意两个区间合并之后, 只有 (i, j, k) 三者不全来自同一个原区间的三元组才是新增的, 所以稍微容斥一下即可 (先减去两个原区间各自的贡献, 再加上合并后区间的贡献)

由于事先对height数组排序了, 所以也不需要树状数组之类的来辅助更新答案, 直接利用排序好的height数组的单调性即可


之前想过一个从height由小到大切割区间进行分治dfs的方法, 然后利用树状数组辅助更新答案, 但是复杂度还是太高了...果然还是需要用并查集


代码如下:

Result  :  Accepted     Memory  :  27500 KB     Time  :  202 ms

/*
 * Author: Gatevin
 * Created Time:  2015/3/18 16:10:35
 * File Name: Kotori_Itsuka.cpp
 */
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;

// Capacity of every array in this program: the three input strings total at
// most 3*10^5 characters, plus one separator/terminator after each string.
#define maxn 300010

// Scratch buffers used by da(): wa/wb hold the two rank arrays it swaps between
// rounds, wv holds the keys being counting-sorted, Ws holds the bucket counts.
int wa[maxn], wb[maxn], wv[maxn], Ws[maxn];

// Returns 1 when positions a and b carry the same pair of keys
// (r[a], r[a + l]) and (r[b], r[b + l]), otherwise 0.
// The second key is only read when the first keys already match.
int cmp(int *r, int a, int b, int l)
{
    if(r[a] != r[b])
        return 0;
    return r[a + l] == r[b + l] ? 1 : 0;
}

// Builds the suffix array of r[0..n-1] by prefix doubling, using a counting
// sort on the current ranks in every round.
//   r  : input keys; every r[i] is used as an index into Ws, so it must lie in [0, m)
//   sa : output; sa[k] receives the start index of the k-th smallest suffix
//   n  : number of keys to sort (main() in this program passes N + 1, so the
//        trailing 0 terminator is included)
//   m  : exclusive upper bound on the key values; replaced by the number of
//        distinct ranks after each round
// NOTE(review): cmp() reads index a + l / b + l without a bounds check; this
// presumably relies on the last key being a unique smallest terminator — confirm.
void da(int *r, int *sa, int n, int m)
{
    int *x = wa, *y = wb, *t, i, j, p;
    // Round 0: counting sort of the suffixes by their first key.
    for(i = 0; i < m; i++) Ws[i] = 0;
    for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;
    for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
    for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;
    // Each pass doubles the compared prefix length j; p counts the distinct
    // ranks produced, and the loop stops once all n suffixes have distinct ranks.
    for(j = 1, p = 1; p < n; j *= 2, m = p)
    {
        // y = suffix start indices ordered by their second half (offset j):
        // the last j suffixes have no second half, so they are placed first...
        for(p = 0, i = n - j; i < n; i++) y[p++] = i;
        // ...followed by the remaining suffixes in the order given by the current sa.
        for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
        // Counting sort of y by first-half rank; filling from the back keeps
        // the second-half order among equal keys.
        for(i = 0; i < n; i++) wv[i] = x[y[i]];
        for(i = 0; i < m; i++) Ws[i] = 0;
        for(i = 0; i < n; i++) Ws[wv[i]]++;
        for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
        for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];
        // Swap the rank buffers, then assign new ranks in sa order: neighbours
        // that cmp() reports as equal share a rank, otherwise the rank advances.
        for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
    return;
}

int rank[maxn], height[maxn];
void calheight(int *r, int *sa, int n)
{
    int i, j, k = 0;
    for(i = 1; i <= n; i++) rank[sa[i]] = i;
    for(i = 0; i < n; height[rank[i++]] = k)
        for(k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k++);
    return;
}

// Parent links of the union-find structure; f[x] == x marks a root.
int f[maxn];
// Returns the root of the set containing x and re-points every node on the
// walked path directly at that root.
int find(int x)
{
    int root = x;
    while(f[root] != root)
        root = f[root];
    // Second pass: compress the path from x up to the root.
    while(f[x] != root)
    {
        int parent = f[x];
        f[x] = root;
        x = parent;
    }
    return root;
}

// Ordering predicate for sort(): rank index a goes before rank index b when
// its height value is strictly larger, i.e. heights end up in descending order.
bool cmp2(int a, int b)
{
    return height[b] < height[a];
}

// in: input buffer, holds one of the three strings at a time.
char in[maxn];
// s      : the concatenated text as integer keys (letters 1..26, separators, 0 terminator)
// sa     : suffix array of s
// p      : rank indices 1..N, sorted by descending height in main()
// belong : for each position of s, the index (0..2) of the string it came from, -1 for separators
// N      : after main() has built s, the index of the 0 terminator
int s[maxn], sa[maxn], p[maxn], belong[maxn], N;
// cnt[x][k]: for a union-find root x, how many suffixes of string k its set contains.
// ans[L]   : answer recorded for length L.
lint cnt[maxn][3], ans[maxn];
const lint mod = 1e9 + 7;

/*
 * Reads three lowercase strings, joins them into one text with distinct
 * separators, builds the suffix array and height array, and then merges
 * adjacent suffixes in order of decreasing height with a union-find.
 * `result` is kept equal to the sum over all current sets of
 * cnt[0] * cnt[1] * cnt[2]; whenever the height drops, that value is stored as
 * the answer for every length in between. Prints ans[1..mlen], where mlen is
 * the length of the shortest input string.
 */
int main()
{
    int mlen = 1e9;
    N = 0;
    for(int i = 0; i < 3; i++)
    {
        // Stop on truncated input instead of silently reusing the old buffer.
        if(scanf("%s", in) != 1) return 0;
        int len = strlen(in);
        mlen = min(len, mlen);
        for(int j = 0; j < len; j++)
        {
            belong[N] = i;
            s[N++] = in[j] - 'a' + 1;   // letters map to 1..26
        }
        // A distinct separator (27, 28, 29) follows each string so that no
        // common prefix can extend across a string boundary.
        belong[N] = -1;
        s[N++] = 27 + i;
    }
    // The last separator is turned into the 0 terminator; N is its index.
    N--;
    s[N] = 0;
    da(s, sa, N + 1, 30);
    calheight(s, sa, N);
    for(int i = 0; i <= N; i++) p[i] = f[i] = i;
    // Every text position starts as its own set containing one suffix.
    for(int i = 0; i <= N; i++)
        if(belong[i] != -1) cnt[i][belong[i]]++;
    // Visit the adjacent-rank pairs from the largest height to the smallest.
    sort(p + 1, p + N + 1, cmp2);
    lint result = 0;
    for(int i = 1; i <= N; i++)
    {
        // Height is about to drop: all merges for the lengths in
        // (height[p[i]], height[p[i - 1]]] are done, so record their answer.
        if(i > 1 && height[p[i]] != height[p[i - 1]])
            for(int j = height[p[i]] + 1; j <= height[p[i - 1]]; j++)
                ans[j] = result;
        int bl = find(sa[p[i]]), br = find(sa[p[i] - 1]);
        // Remove the two old sets' contributions, merge, add the new one.
        result = (result - cnt[bl][0]*cnt[bl][1]*cnt[bl][2] % mod + mod) % mod;
        result = (result - cnt[br][0]*cnt[br][1]*cnt[br][2] % mod + mod) % mod;
        for(int j = 0; j < 3; j++)
            cnt[bl][j] = (cnt[bl][j] + cnt[br][j]) % mod;
        f[br] = bl;
        result = (result + cnt[bl][0]*cnt[bl][1]*cnt[bl][2]) % mod;
    }
    // lint is long long, so %lld is the matching conversion (%I64d is an
    // MSVCRT-only extension).
    for(int i = 1; i <= mlen; i++)
        printf("%lld ", ans[i]);
    return 0;
}



顺带祭奠一下以前写的TLE了的方法...

Result  :  Time Limit Exceeded on test 42

/*
 * Author: Gatevin
 * Created Time:  2015/3/12 22:32:12
 * File Name: Kotori_Itsuka.cpp
 */
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;

// Modulus applied to every stored count and printed answer.
const lint mod = 1000000007LL;

// Capacity of every array in this program (three strings plus separators).
#define maxn 300100

// Scratch buffers used by da(): wa/wb hold the two rank arrays it swaps between
// rounds, wv holds the keys being counting-sorted, Ws holds the bucket counts.
int wa[maxn], wb[maxn], wv[maxn], Ws[maxn];

// Tells whether positions a and b have identical key pairs
// (r[a], r[a + l]) and (r[b], r[b + l]); yields 1 if so and 0 otherwise.
// The keys at offset l are only inspected once the first keys agree.
int cmp(int *r, int a, int b, int l)
{
    const bool firstEqual = (r[a] == r[b]);
    if(!firstEqual)
        return 0;
    const bool secondEqual = (r[a + l] == r[b + l]);
    return secondEqual ? 1 : 0;
}

// Builds the suffix array of r[0..n-1] by prefix doubling, using a counting
// sort on the current ranks in every round.
//   r  : input keys; every r[i] is used as an index into Ws, so it must lie in [0, m)
//   sa : output; sa[k] receives the start index of the k-th smallest suffix
//   n  : number of keys to sort (main() in this program passes N + 1, so the
//        trailing 0 terminator is included)
//   m  : exclusive upper bound on the key values; replaced by the number of
//        distinct ranks after each round
// NOTE(review): cmp() reads index a + l / b + l without a bounds check; this
// presumably relies on the last key being a unique smallest terminator — confirm.
void da(int *r, int *sa, int n, int m)
{
    int *x = wa, *y = wb, *t, i, j, p;
    // Round 0: counting sort of the suffixes by their first key.
    for(i = 0; i < m; i++) Ws[i] = 0;
    for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;
    for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
    for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;
    // Each pass doubles the compared prefix length j; p counts the distinct
    // ranks produced, and the loop stops once all n suffixes have distinct ranks.
    for(j = 1, p = 1; p < n; j *= 2, m = p)
    {
        // y = suffix start indices ordered by their second half (offset j):
        // the last j suffixes have no second half, so they are placed first...
        for(p = 0, i = n - j; i < n; i++) y[p++] = i;
        // ...followed by the remaining suffixes in the order given by the current sa.
        for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
        // Counting sort of y by first-half rank; filling from the back keeps
        // the second-half order among equal keys.
        for(i = 0; i < n; i++) wv[i] = x[y[i]];
        for(i = 0; i < m; i++) Ws[i] = 0;
        for(i = 0; i < n; i++) Ws[wv[i]]++;
        for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
        for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];
        // Swap the rank buffers, then assign new ranks in sa order: neighbours
        // that cmp() reports as equal share a rank, otherwise the rank advances.
        for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
    return;
}

int rank[maxn], height[maxn];
void calheight(int *r, int *sa, int n)
{
    int i, j, k = 0;
    for(i = 1; i <= n; i++) rank[sa[i]] = i;
    for(i = 0; i < n; height[rank[i++]] = k)
        for(k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k++);
    return;
}


// in: input buffer, holds one of the three strings at a time.
char in[maxn];
// s      : the concatenated text as integer keys (letters 1..26, separators, 0 terminator)
// sa     : suffix array of s
// belong : for each position of s, the number (1..3) of the string it came from, -1 for separators
// N      : after main() has built s, the index of the 0 terminator
int s[maxn], sa[maxn], belong[maxn], N;
// Only referenced by the commented-out loop inside dfs().
lint ans[maxn];

// Fenwick-tree storage indexed 1..N; add() and query() below operate on it.
lint C[maxn];
// Isolates the lowest set bit of x (0 when x is 0).
int lowbit(int x)
{
    const int negated = -x;
    return x & negated;
}

// Adds value (mod `mod`) to C at index L and at every later index reached by
// repeatedly adding lowbit, up to and including N. L must be at least 1,
// otherwise the index never advances.
void add(int L, lint value)
{
    for(int idx = L; idx <= N; idx += lowbit(idx))
        C[idx] = (C[idx] + value) % mod;
}

void update(int L, int R, lint value)//区间更新[L, R] += value
{
    add(L, value), add(R + 1, (-value + mod) % mod);
}

// Point query: returns the accumulated value at position pos, i.e. the sum
// (mod `mod`) of C over the indices obtained by repeatedly stripping lowbit.
lint query(int pos)
{
    lint total = 0;
    for(int idx = pos; idx != 0; idx -= lowbit(idx))
        total = (total + C[idx]) % mod;
    return total;
}

void dfs(int L, int R, int h)
{
    int i = L;
    while(i <= R)
    {
        while(i <= R && height[i] == h) i++;
        if(i > R) break;
        lint cnt[4]; memset(cnt, 0, sizeof(cnt));
        int j = i;
        cnt[belong[sa[j - 1]]]++;
        int nexh = height[i];
        while(j <= R && height[j] > h)
            cnt[belong[sa[j]]]++, nexh = min(nexh, height[j]), j++;
    //    for(int k = h + 1; k <= nexh; k++)
    //        ans[k] = (ans[k] + cnt[1]*cnt[2]*cnt[3] % mod) % mod;
        update(h + 1, nexh, cnt[1]*cnt[2]*cnt[3] % mod);
        dfs(i, j - 1, nexh);
        i = j;
    }
    return;
}

// Runs the recursive decomposition over all ranks 1..N starting from height 0,
// then prints the accumulated answer for every length 1..mlen, each followed
// by a space.
void solve(int mlen)
{
    dfs(1, N, 0);
    // query() returns lint (long long), so %lld is the matching conversion;
    // %I64d is an MSVCRT-only extension.
    for(int i = 1; i <= mlen; i++)
        printf("%lld ", query(i));
}

/*
 * Reads three lowercase strings, joins them into one text with distinct
 * separators, builds the suffix array and height array, and hands the length
 * of the shortest string to solve(), which prints the answers.
 */
int main()
{
    int minlen = 1e9;
    N = 0;
    for(int i = 1; i <= 3; i++)
    {
        // Stop on truncated input instead of silently reusing the old buffer.
        if(scanf("%s", in) != 1) return 0;
        int len = strlen(in);
        minlen = min(minlen, len);
        for(int j = 0; j < len; j++)
        {
            belong[N] = i;
            s[N++] = in[j] - 'a' + 1;   // letters map to 1..26
        }
        // A distinct separator (27, 28, 29) follows each string so that no
        // common prefix can extend across a string boundary.
        belong[N] = -1;
        s[N++] = 26 + i;
    }
    // The last separator is turned into the 0 terminator; N is its index.
    N--;
    s[N] = 0;
    da(s, sa, N + 1, 30);
    calheight(s, sa, N);
    solve(minlen);
    return 0;
}



你可能感兴趣的:(后缀数组,并查集,codeforces,strings,three,452E)