题目链接: http://poj.org/problem?id=3415
题目大意:给定最大长度为10w的两个串,问两个串中子串长度均大于K并且相等的对数,比如a和aa,就是2.
解题思路:把两个串拼接起来,中间用神奇地‘$'字符隔开。然后用倍增算法求sa数组、rank数组、height数组,然后利用height数组统计第一个串称A串的所有后缀与第二个串B串的所有后缀所共有的长度大于k的子串。
上面的逻辑过程容易理解,用aa和aa就能推出为什么答案是5,子串a有2个,aa有1个,2*2 + 1*1 = 5.问题的关键是要找出各个子串总匹配数(子串相等称为匹配)。height数组表示的是和其他串所有最长公共前缀中的最长公共前缀,也是和前一个串的公共前缀,用这个数组就可以求出总匹配数。height[i]-k+1代表和前一个子串相比符合条件的匹配数,如果height[i]大于它属于组内(根据height[i]是否大于k分组,height数组是波浪形,很容易分组)前面的每一个height[j]的话,就可以一次一次和前面的height[j]比较如果是大于,就增加height[i]-k+1,这就代表从sa[i]开始的后缀与sa[j]开始的后缀有height[i]-k+1个相等的串,一旦小于,就应该做些处理,把多出来的那部分去掉,因为这部分以后再也不会有相等的串了,没什么价值了。但这样,复杂度是O(n^2),可用单调栈优化,具体实现见代码。
上面一段话是几个月前写的,其实是对于单调栈的理解不深所以没细讲,今天好好想了一下,恍然大悟。A串对B串,B串对A串的统计过程是一样的,分析下A串对B串然后以此类推到B串对A串即可。先说下统计公共子串的原理,其实是利用A串的后缀与排名比它高的B串的后缀的最长公共前缀l进行统计,如果l大于k,就增加l-k+1个公共子串。
然后要说到这个最长公共前缀,子串b与当前串(排名比他低的A串后缀)的最长公共前缀是它的高一位排名到当前后缀之间最小的那个height,我们要统计前面所有B串的子串和当前A串后缀所共有的公共子串只需要用到几个height值就好,而这些height值是递增的。为什么是递增的,比如0,1,2,3,1,2,3,那么统计到第一个1的时候,0的和它的最长公共前缀为1,统计到2的时候前两个的最长公共前缀是1,2,统计到3的时候是1,2,3。统计到1的时候则有一个很大的变动,因为前4个的最长公共前缀都可以是这个1,所有只用到这个1就好...
正因为上面一段的分析,我们可以用一个栈和一个值total来动态维护前面所有后缀和当前串的最长公共前缀长度。栈里存的是height值,total其实就是栈里递增的所有height值和当前串所共有的大于k的公共子串。当我们的某个更小的height值进栈时,需要维护这个total值,因为栈顶的几个height相对大些,表示的公共子串多些,当遇到更小的height值,最长公共前缀长度应该是更小的那个,所以要地减去那一部分多出来的,然后用个num[to]表示栈顶的这个height代表几个大灯与它的height。
测试数据:
2
ccbb
代码:
#include <stdio.h> #include <string.h> #define MAX 200100 int k,n,top; int len,len1,len2; __int64 tot,ans,num[MAX]; int st[MAX],arr[MAX]; int wa[MAX],wb[MAX]; int wv[MAX],wn[MAX]; int sa[MAX],rank[MAX],h[MAX]; int cmp(int *r,int a,int b,int l) { return r[a] == r[b] && r[a+l] == r[b+l]; } void Da(int *r,int n,int m) { int i,j,k,p,*t; int *x = wa,*y = wb; for (i = 0; i < m; ++i) wn[i] = 0; for (i = 0; i < n; ++i) wn[x[i]=r[i]]++; for (i = 1; i < m; ++i) wn[i] += wn[i-1]; for (i = n - 1; i >= 0; --i) sa[--wn[x[i]]] = i; for (j = 1,p = 1; p < n; j *= 2,m = p) { for (p = 0,i = n - j; i < n; ++i) y[p++] = i; for (i = 0; i < n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j; for (i = 0; i < n; ++i) wv[i] = x[y[i]]; for (i = 0; i < m; ++i) wn[i] = 0; for (i = 0; i < n; ++i) wn[wv[i]]++; for (i = 1; i < m; ++i) wn[i] += wn[i-1]; for (i = n - 1; i >= 0; --i) sa[--wn[wv[i]]] = y[i]; t = x,x = y,y = t,p = 1; for (x[sa[0]] = 0,i = 1; i < n; ++i) x[sa[i]] = cmp(y,sa[i-1],sa[i],j) ? p - 1 : p++; } } void CalHeight(int *r,int n) { int i,j,k = 0; for (i = 1; i <= n; ++i) rank[sa[i]] = i; for (i = 0; i < n; h[rank[i++]] = k) for (k ? k-- : 0,j = sa[rank[i]-1]; r[i+k] == r[j+k];k++); } __int64 Solve(int n,int k) { __int64 i,j,tp,ans = 0; for (i = 1; i <= n; ++i) { //单调栈处理 if (h[i] < k) tot = top = 0; //分组,小于k的就分为一组,不多做处理 else { tp = 0; //默认tp = 0 if (sa[i-1] > len1) //如果前面一段是字符串B tp = 1,tot += h[i] - k + 1; //tp是累计可增加的贡献值 while (top > 0 && st[top] >= h[i]) { tot -= num[top] * (st[top] - h[i]); tp += num[top],top--; } st[++top] = h[i],num[top] = tp; if (sa[i] < len1) ans += tot; } } for (i = 1; i <= n; ++i) { //单调栈处理 if (h[i] < k) tot = top = 0; else { tp = 0; if (sa[i-1] < len1) tp = 1,tot += h[i] - k + 1; while (top > 0 && st[top] >= h[i]) { tot -= num[top] * (st[top] - h[i]); tp += num[top],top--; } st[++top] = h[i],num[top] = tp; if (sa[i] > len1) ans += tot; } } return ans; } int main() { int i,j,t,cas = 0; char str1[MAX],str2[MAX]; while (scanf("%d",&k),k) { scanf("%s%s",str1,str2); for (i = 0; str1[i]; ++i) arr[i] = str1[i]; arr[i] = '$',len1 = i,i++; for (j = 0; str2[j];j++) arr[i+j] = str2[j]; arr[i+j] = 0,len = i + j; Da(arr,len+1,150); CalHeight(arr,len); ans = Solve(len,k); printf("%I64d\n",ans); } }