题目大意:
给出三个只包含小写字母 'a' ~ 'z' 的字符串 s1, s2, s3, 它们的总长度不超过 3*10^5. 对于每一个 L (1 <= L <= min(|s1|, |s2|, |s3|), 即 L 不超过三个串中最短串的长度), 求满足 s1[i ~ i + L - 1] == s2[j ~ j + L - 1] == s3[k ~ k + L - 1] 的三元组 (i, j, k) 的个数, 结果对 10^9 + 7 取模后输出.
大致思路:
首先不难想到用后缀数组处理三个串拼接起来的总串, 并记录每个字符来自哪个原串, 也就是记录每个后缀的来源. 然后按 height 数组从大到小的顺序, 用并查集合并 sa 中相邻的后缀区间并计算答案. 注意两个区间合并之后, 只有 (i, j, k) 三者不全来自同一个原区间的三元组才是新增的, 所以稍微容斥一下即可: 先减去两个原区间各自的 cnt1*cnt2*cnt3, 再加上合并后区间的 cnt1*cnt2*cnt3.
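由于事先把 height 数组从大到小排了序, 当处理到某个 height 值时, 所有 height 更大的相邻后缀对都已经合并完毕, 此时的 result 就是所有长度 L 落在 (当前 height, 上一个 height] 之间的答案. 所以也不需要树状数组之类的数据结构来辅助更新答案, 直接利用排序后 height 数组的单调性即可.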
之前想过一个按 height 由小到大切割区间、进行分治 dfs 的方法, 然后用树状数组辅助更新答案. 但是每一层递归都要重新扫描整个区间, 而递归深度可以很大, 复杂度还是太高了 (见下面 Time Limit Exceeded 的代码)... 果然还是需要用并查集.
代码如下:
Result : Accepted Memory : 27500 KB Time : 202 ms
/*
 * Author: Gatevin
 * Created Time: 2015/3/18 16:10:35
 * File Name: Kotori_Itsuka.cpp
 *
 * Accepted solution: suffix array + union-find.
 *
 * The three strings are concatenated with separator values 27 and 28 between
 * them and the sentinel 0 at the very end. belong[pos] records which input
 * string (0, 1 or 2) position pos came from, or -1 for separators / sentinel.
 * Adjacent suffix-array entries are merged in decreasing order of their LCP
 * (height). For each merged group, cnt[root][0..2] counts suffixes from each
 * string. `result` is the sum over groups of cnt0 * cnt1 * cnt2 (mod 1e9+7),
 * which is the answer for every L not exceeding the heights merged so far.
 */
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
#define maxn 300010

int wa[maxn], wb[maxn], wv[maxn], Ws[maxn];

// Compares two positions by the rank pair (r[pos], r[pos + l]) from the
// previous doubling round; equal pairs mean equal prefixes of length 2*l.
int cmp(int *r, int a, int b, int l)
{
    return r[a] == r[b] && r[a + l] == r[b + l];
}

// Prefix-doubling suffix array construction.
// r[0..n-1]: input values in [0, m); main() makes r[n-1] the unique sentinel 0.
// sa[0..n-1]: receives the starting positions of the suffixes in sorted order.
void da(int *r, int *sa, int n, int m)
{
    int *x = wa, *y = wb, *t, i, j, p;
    // Counting sort by the first character.
    for(i = 0; i < m; i++) Ws[i] = 0;
    for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;
    for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
    for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;
    // Each round doubles the compared length until all ranks are distinct.
    for(j = 1, p = 1; p < n; j *= 2, m = p)
    {
        // y: positions ordered by their second key x[pos + j]; positions whose
        // second half runs past the end have an empty second key and go first.
        for(p = 0, i = n - j; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
        // Stable counting sort of y by the first key.
        for(i = 0; i < n; i++) wv[i] = x[y[i]];
        for(i = 0; i < m; i++) Ws[i] = 0;
        for(i = 0; i < n; i++) Ws[wv[i]]++;
        for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
        for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];
        // Re-rank; p ends up as the number of distinct ranks.
        for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
    return;
}

// Not named `rank`: with `using namespace std;` that name is ambiguous with
// std::rank from <type_traits> (C++11 and later) and fails to compile.
int rnk[maxn], height[maxn];

// Fills rnk[] and height[i] = LCP(suffix sa[i - 1], suffix sa[i]) for
// i in 1..n; sa[0] is the sentinel suffix, so height[1] == 0.
void calheight(int *r, int *sa, int n)
{
    int i, j, k = 0;
    for(i = 1; i <= n; i++) rnk[sa[i]] = i;
    // Going through suffixes in text order, the LCP drops by at most one, so
    // k is reused (minus one) from the previous suffix.
    for(i = 0; i < n; height[rnk[i++]] = k)
        for(k ? k-- : 0, j = sa[rnk[i] - 1]; r[i + k] == r[j + k]; k++);
    return;
}

int f[maxn];

// Union-find root lookup with path compression. It is iterative because
// unions always hang `br` under `bl` without balancing, so parent chains can
// get long and a recursive version could recurse very deeply.
int findRoot(int x)
{
    int root = x;
    while(f[root] != root) root = f[root];
    while(f[x] != root)
    {
        int next = f[x];
        f[x] = root;
        x = next;
    }
    return root;
}

// Orders height indices by decreasing height.
bool cmp2(int a, int b)
{
    return height[a] > height[b];
}

char in[maxn];
int s[maxn], sa[maxn], p[maxn], belong[maxn], N;
lint cnt[maxn][3], ans[maxn];
const lint mod = 1000000007LL;

int main()
{
    int mlen = 1e9;
    N = 0;
    // Build s = s1 (27) s2 (28) s3 (29); the last separator becomes the sentinel.
    for(int i = 0; i < 3; i++)
    {
        scanf("%s", in);
        int len = strlen(in);
        mlen = min(len, mlen);
        for(int j = 0; j < len; j++)
        {
            belong[N] = i;
            s[N++] = in[j] - 'a' + 1;
        }
        belong[N] = -1;
        s[N++] = 27 + i;
    }
    N--;
    s[N] = 0;
    da(s, sa, N + 1, 30);
    calheight(s, sa, N);

    // Each position starts as its own group counting only its own string.
    for(int i = 0; i <= N; i++) p[i] = f[i] = i;
    for(int i = 0; i <= N; i++)
        if(belong[i] != -1) cnt[i][belong[i]]++;
    sort(p + 1, p + N + 1, cmp2);

    lint result = 0;
    for(int i = 1; i <= N; i++)
    {
        // Once the height drops, every pair with LCP >= L (for L between the
        // new and the previous height) is merged: record the answers for those L.
        if(i > 1 && height[p[i]] != height[p[i - 1]])
            for(int j = height[p[i]] + 1; j <= height[p[i - 1]]; j++)
                ans[j] = result;
        // Merge the groups of the adjacent suffixes sa[p[i] - 1] and sa[p[i]].
        int bl = findRoot(sa[p[i]]), br = findRoot(sa[p[i] - 1]);
        // Inclusion-exclusion: drop both old group products, add the merged one.
        result = (result - cnt[bl][0] * cnt[bl][1] * cnt[bl][2] % mod + mod) % mod;
        result = (result - cnt[br][0] * cnt[br][1] * cnt[br][2] % mod + mod) % mod;
        for(int j = 0; j < 3; j++) cnt[bl][j] = (cnt[bl][j] + cnt[br][j]) % mod;
        f[br] = bl;
        result = (result + cnt[bl][0] * cnt[bl][1] * cnt[bl][2]) % mod;
    }
    // %lld is the standard specifier for long long (%I64d is MSVC-only and is
    // misparsed by glibc as a flag plus a width).
    for(int i = 1; i <= mlen; i++) printf("%lld ", ans[i]);
    return 0;
}
Result : Time Limit Exceeded on test 42
/*
 * Author: Gatevin
 * Created Time: 2015/3/12 22:32:12
 * File Name: Kotori_Itsuka.cpp
 *
 * Earlier attempt (Time Limit Exceeded on test 42): divide and conquer over
 * the height array plus a Fenwick tree (range update / point query).
 *
 * dfs(L, R, h) splits [L, R] into maximal runs of height > h. Each run is one
 * group of suffixes sharing a prefix of length nexh (the run's minimum
 * height). For every length in [h + 1, nexh] the group contributes
 * cnt1 * cnt2 * cnt3 triples, and then the run is recursed into with h = nexh.
 *
 * NOTE(review): every recursion level rescans its whole range and the nesting
 * depth can be large (e.g. long runs of a single letter), so the total work
 * can degrade to quadratic. The deep recursion may also exhaust the stack.
 * That is why this version was replaced by the union-find solution.
 */
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
const lint mod = 1000000007LL;
#define maxn 300100

int wa[maxn], wb[maxn], wv[maxn], Ws[maxn];

// Compares two positions by the rank pair (r[pos], r[pos + l]) from the
// previous doubling round.
int cmp(int *r, int a, int b, int l)
{
    return r[a] == r[b] && r[a + l] == r[b + l];
}

// Prefix-doubling suffix array construction.
// r[0..n-1]: input values in [0, m); main() makes r[n-1] the unique sentinel 0.
void da(int *r, int *sa, int n, int m)
{
    int *x = wa, *y = wb, *t, i, j, p;
    for(i = 0; i < m; i++) Ws[i] = 0;
    for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;
    for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
    for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;
    for(j = 1, p = 1; p < n; j *= 2, m = p)
    {
        // Order by second key, then stable counting sort by first key.
        for(p = 0, i = n - j; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
        for(i = 0; i < n; i++) wv[i] = x[y[i]];
        for(i = 0; i < m; i++) Ws[i] = 0;
        for(i = 0; i < n; i++) Ws[wv[i]]++;
        for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
        for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];
        // Re-rank; p ends up as the number of distinct ranks.
        for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
    return;
}

// Not named `rank`: with `using namespace std;` that name is ambiguous with
// std::rank from <type_traits> (C++11 and later) and fails to compile.
int rnk[maxn], height[maxn];

// Fills rnk[] and height[i] = LCP(suffix sa[i - 1], suffix sa[i]) for i in 1..n.
void calheight(int *r, int *sa, int n)
{
    int i, j, k = 0;
    for(i = 1; i <= n; i++) rnk[sa[i]] = i;
    for(i = 0; i < n; height[rnk[i++]] = k)
        for(k ? k-- : 0, j = sa[rnk[i] - 1]; r[i + k] == r[j + k]; k++);
    return;
}

char in[maxn];
int s[maxn], sa[maxn], belong[maxn], N;
lint ans[maxn];
lint C[maxn];

int lowbit(int x)
{
    return -x & x;
}

// Fenwick tree: adds value at position L (1-based), modulo mod.
void add(int L, lint value)
{
    while(L <= N) C[L] = (C[L] + value) % mod, L += lowbit(L);
    return;
}

// Range update: adds value to every position in [L, R] (difference array).
void update(int L, int R, lint value)
{
    add(L, value), add(R + 1, (-value + mod) % mod);
}

// Point query: total value accumulated at position pos.
lint query(int pos)
{
    lint ret = 0;
    while(pos) ret = (ret + C[pos]) % mod, pos -= lowbit(pos);
    return ret;
}

// Processes height indices [L, R], all of which have height >= h.
void dfs(int L, int R, int h)
{
    int i = L;
    while(i <= R)
    {
        // Skip boundaries where the LCP is exactly h (they split groups).
        while(i <= R && height[i] == h) i++;
        if(i > R) break;
        // cnt[1..3]: suffixes of this run coming from s1, s2, s3.
        lint cnt[4];
        memset(cnt, 0, sizeof(cnt));
        int j = i;
        cnt[belong[sa[j - 1]]]++;
        int nexh = height[i];
        while(j <= R && height[j] > h)
            cnt[belong[sa[j]]]++, nexh = min(nexh, height[j]), j++;
        // Naive per-length version kept for reference:
        // for(int k = h + 1; k <= nexh; k++)
        //     ans[k] = (ans[k] + cnt[1]*cnt[2]*cnt[3] % mod) % mod;
        update(h + 1, nexh, cnt[1] * cnt[2] * cnt[3] % mod);
        dfs(i, j - 1, nexh);
        i = j;
    }
    return;
}

// Runs the decomposition and prints the answers for L = 1..mlen.
void solve(int mlen)
{
    dfs(1, N, 0);
    // %lld is the standard specifier for long long (%I64d is MSVC-only).
    for(int i = 1; i <= mlen; i++) printf("%lld ", query(i));
}

int main()
{
    int minlen = 1e9;
    N = 0;
    // Build s = s1 (27) s2 (28) s3 (29); belong[] is 1..3, -1 for separators.
    for(int i = 1; i <= 3; i++)
    {
        scanf("%s", in);
        int len = strlen(in);
        minlen = min(minlen, len);
        for(int j = 0; j < len; j++)
        {
            belong[N] = i;
            s[N++] = in[j] - 'a' + 1;
        }
        belong[N] = -1;
        s[N++] = 26 + i;
    }
    // The last separator becomes the unique sentinel 0.
    N--;
    s[N] = 0;
    da(s, sa, N + 1, 30);
    calheight(s, sa, N);
    solve(minlen);
    return 0;
}