poj 3415 Common Substrings (字符串_后缀数组)

题目链接: http://poj.org/problem?id=3415


题目大意:给定最大长度为10w的两个串,问两个串中子串长度均大于K并且相等的对数,比如a和aa,就是2.


解题思路:把两个串拼接起来,中间用神奇地‘$'字符隔开。然后用倍增算法求sa数组、rank数组、height数组,然后利用height数组统计第一个串称A串的所有后缀与第二个串B串的所有后缀所共有的长度大于k的子串。

   上面的逻辑过程容易理解,用aa和aa就能推出为什么答案是5,子串a有2个,aa有1个,2*2 + 1*1 = 5.问题的关键是要找出各个子串总匹配数(子串相等称为匹配)。height数组表示的是和其他串所有最长公共前缀中的最长公共前缀,也是和前一个串的公共前缀,用这个数组就可以求出总匹配数。height[i]-k+1代表和前一个子串相比符合条件的匹配数,如果height[i]大于它属于组内(根据height[i]是否大于k分组,height数组是波浪形,很容易分组)前面的每一个height[j]的话,就可以一次一次和前面的height[j]比较如果是大于,就增加height[i]-k+1,这就代表从sa[i]开始的后缀与sa[j]开始的后缀有height[i]-k+1个相等的串,一旦小于,就应该做些处理,把多出来的那部分去掉,因为这部分以后再也不会有相等的串了,没什么价值了。但这样,复杂度是O(n^2),可用单调栈优化,具体实现见代码。


    上面一段话是几个月前写的,其实是对于单调栈的理解不深所以没细讲,今天好好想了一下,恍然大悟。A串对B串,B串对A串的统计过程是一样的,分析下A串对B串然后以此类推到B串对A串即可。先说下统计公共子串的原理,其实是利用A串的后缀与排名比它高的B串的后缀的最长公共前缀l进行统计,如果l大于k,就增加l-k+1个公共子串。

     然后要说到这个最长公共前缀,子串b与当前串(排名比他低的A串后缀)的最长公共前缀是它的高一位排名到当前后缀之间最小的那个height,我们要统计前面所有B串的子串和当前A串后缀所共有的公共子串只需要用到几个height值就好,而这些height值是递增的。为什么是递增的,比如0,1,2,3,1,2,3,那么统计到第一个1的时候,0的和它的最长公共前缀为1,统计到2的时候前两个的最长公共前缀是1,2,统计到3的时候是1,2,3。统计到1的时候则有一个很大的变动,因为前4个的最长公共前缀都可以是这个1,所有只用到这个1就好...

    正因为上面一段的分析,我们可以用一个栈和一个值total来动态维护前面所有后缀和当前串的最长公共前缀长度。栈里存的是height值,total其实就是栈里递增的所有height值和当前串所共有的大于k的公共子串。当我们的某个更小的height值进栈时,需要维护这个total值,因为栈顶的几个height相对大些,表示的公共子串多些,当遇到更小的height值,最长公共前缀长度应该是更小的那个,所以要地减去那一部分多出来的,然后用个num[to]表示栈顶的这个height代表几个大灯与它的height。


测试数据:

2
aaa
aaa
1
ababa
ababa
1
xx
xx
2
aababaa
abaabaa
1
ccbb

ccbb


代码:

#include <stdio.h>
#include <string.h>
#define MAX 200100


int k,n,top;
int len,len1,len2;
__int64 tot,ans,num[MAX];
int st[MAX],arr[MAX];
int wa[MAX],wb[MAX];
int wv[MAX],wn[MAX];
int sa[MAX],rank[MAX],h[MAX];


int cmp(int *r,int a,int b,int l) {
	
	return r[a] == r[b] && r[a+l] == r[b+l];
}
void Da(int *r,int n,int m) {
	
	int i,j,k,p,*t;
	int *x = wa,*y = wb;
	
	
	for (i = 0; i < m; ++i) wn[i] = 0;
	for (i = 0; i < n; ++i) wn[x[i]=r[i]]++;
	for (i = 1; i < m; ++i) wn[i] += wn[i-1];
	for (i = n - 1; i >= 0; --i) sa[--wn[x[i]]] = i;
	
	
	for (j = 1,p = 1; p < n; j *= 2,m = p) {
		
		for (p = 0,i = n - j; i < n; ++i) y[p++] = i;
		for (i = 0; i < n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j;
		
		
		for (i = 0; i < n; ++i) wv[i] = x[y[i]];
		for (i = 0; i < m; ++i) wn[i] = 0;
		for (i = 0; i < n; ++i) wn[wv[i]]++;
		for (i = 1; i < m; ++i) wn[i] += wn[i-1];
		for (i = n - 1; i >= 0; --i) sa[--wn[wv[i]]] = y[i];
		
		
		t = x,x = y,y = t,p = 1;
		for (x[sa[0]] = 0,i = 1; i < n; ++i)
			x[sa[i]] = cmp(y,sa[i-1],sa[i],j) ? p - 1 : p++;
	}
}
void CalHeight(int *r,int n) {
	
	int i,j,k = 0;
	for (i = 1; i <= n; ++i) rank[sa[i]] = i;
	for (i = 0; i < n; h[rank[i++]] = k)
		for (k ? k-- : 0,j = sa[rank[i]-1]; r[i+k] == r[j+k];k++);
}


__int64 Solve(int n,int k) {
	
	__int64 i,j,tp,ans = 0;
	for (i = 1; i <= n; ++i) {				//单调栈处理
		
		if (h[i] < k) tot = top = 0;		//分组,小于k的就分为一组,不多做处理
		else {

			tp = 0;							//默认tp = 0
			if (sa[i-1] > len1)				//如果前面一段是字符串B
				tp = 1,tot += h[i] - k + 1; //tp是累计可增加的贡献值

	
			while (top > 0 && st[top] >= h[i]) {

				tot -= num[top] * (st[top] - h[i]);
				tp += num[top],top--;
			}


			st[++top] = h[i],num[top] = tp;
			if (sa[i] < len1) ans += tot;
		}
	}
	

	for (i = 1; i <= n; ++i) { //单调栈处理
		
		if (h[i] < k) tot = top = 0;
		else {

			tp = 0;
			if (sa[i-1] < len1)
				tp = 1,tot += h[i] - k + 1;

	
			while (top > 0 && st[top] >= h[i]) {
				
				tot -= num[top] * (st[top] - h[i]);
				tp += num[top],top--;
			}


			st[++top] = h[i],num[top] = tp;
			if (sa[i] > len1) ans += tot;
		}
	}

	
	return ans;
}


int main()
{
	int i,j,t,cas = 0;
	char str1[MAX],str2[MAX];
	
	
	while (scanf("%d",&k),k) {
		
		scanf("%s%s",str1,str2);
		for (i = 0; str1[i]; ++i)
			arr[i] = str1[i];
		arr[i] = '$',len1 = i,i++;
		for (j = 0; str2[j];j++)
			arr[i+j] = str2[j];
		arr[i+j] = 0,len = i + j;
		
		
		Da(arr,len+1,150);
		CalHeight(arr,len);
		

		ans = Solve(len,k);
		printf("%I64d\n",ans);
	}
}

你可能感兴趣的:(优化,算法,测试)