Common Substring poj3415


太弱了,无奈:看了论文后仍想不到维护的方法;看了题解才知道维护方法,可自己写又一直WA,最后只好照搬了别人的代码……维护方法见代码和注释。


这题除了用栈来维护,还可以用笛卡尔树(一种特殊的堆,见黑书P94)来统计:按照height的值建树(O(n)),再遍历统计一遍(O(n))。统计方法是:对于所有height >= K的节点,求

∑(左子树A后缀的个数×右子树B后缀的个数+左子树B后缀的个数×右子树A后缀的个数+与该节点不属于同一串的子节点的个数)×(height-K+1)

不过递归统计时可能会爆栈。


#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <climits>
#include <ctime>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <deque>
#include <list>
#include <map>
#include <set>
#include <bitset>
#include <algorithm>
#include <functional>
#include <utility>

using std::priority_queue;
using std::vector;
using std::swap;
using std::stack;
using std::sort;
using std::max;
using std::min;
using std::pair;
using std::map;
using std::string;
using std::cin;
using std::cout;
using std::set;
using std::queue;
using std::string;
using std::stringstream;
using std::make_pair;
using std::getline;
using std::greater;
using std::endl;
using std::multimap;
using std::deque;
using std::unique;
using std::lower_bound;
using std::random_shuffle;
using std::bitset;

typedef long long LL;
typedef unsigned long long ULL;
// NOTE(review): the template arguments of these two aliases were lost from
// this listing; <int, int> is a reconstruction. Neither alias is referenced
// elsewhere in this file.
typedef std::pair<int, int> PAIR;
typedef std::multimap<int, int> MMAP;
typedef LL TY;

// Capacity of every per-position array. It must be at least |A| + |B| + 2
// (both strings, the '*' separator and the terminating '\0'); 200010
// presumably allows two input strings of up to 1e5 characters each.
const int MAXN(200010);
// The remaining constants come from a general contest template. Apart from
// MAXH (used only by the commented-out RMQ code) this solution ignores them.
const int MAXM(50010);
const int MAXE(100010);
const int MAXK(6);
const int HSIZE(131313);
const int SIGMA_SIZE(52);
const int MAXH(19);
const int INFI((INT_MAX-1) >> 1);
const ULL BASE(31);
const LL LIM(10000000);
const int INV(-10000);
const int MOD(1000000007);
const double EPS(1e-7);

// Assign b to a when b is larger.
template <typename T> void checkmax(T &a, T b){if(b > a) a = b;}
// Assign b to a when b is smaller.
template <typename T> void checkmin(T &a, T b){if(b < a) a = b;}
// Absolute value for any type comparable with 0 that supports unary minus.
template <typename T> T ABS(const T &a){return a < 0? -a: a;}

// Minimum common-substring length; read once per test case in main().
int K;

// Suffix array of the text S[0..len], built by prefix doubling with counting
// sort, plus the adjacent-LCP (height) array.
// All len + 1 positions are sorted, including the terminator at S[len]. The
// caller is expected to pass tl = strlen(S), so the '\0' at S[len] acts as a
// unique smallest sentinel.
struct SA
{
	// Input text; S[len] is expected to be '\0'.
	char S[MAXN];
	// sa[r]: start position of the r-th smallest suffix.
	// t1/t2: rank / scratch buffers used by init().
	// cnt: counting-sort buckets.
	// len: text length; M: current alphabet (bucket count).
	int sa[MAXN], t1[MAXN], t2[MAXN], cnt[MAXN], len, M;
	// Build sa[] for S[0..tl]. Characters are used directly as bucket
	// indices, so every S[i] must lie in [0, tm).
	void init(int tl, int tm = 128)
	{
		len = tl;
		M = tm;
		// p1 holds the current rank of each position. p2 is scratch: first the
		// order by second key, then (after the swap) the previous ranks.
		int *p1 = t1;
		int *p2 = t2;
		// Initial counting sort by the first character alone.
		for(int i = 0; i < M; ++i) cnt[i] = 0;
		for(int i = 0; i <= len; ++i) ++cnt[p1[i] = S[i]];
		for(int i = 1; i < M; ++i) cnt[i] += cnt[i-1];
		for(int i = len; i >= 0; --i) sa[--cnt[p1[i]]] = i;
		// temp = number of distinct ranks; stop once all len + 1 suffixes differ.
		int temp = 1;
		for(int k = 1; temp <= len; k <<= 1)
		{
			temp = 0;
			// Order by second key (rank of position i + k). Positions with no
			// second half come first, in the order of this loop.
			for(int i = len-k+1; i <= len; ++i) p2[temp++] = i;
			for(int i = 0; i <= len; ++i)
				if(sa[i] >= k)
					p2[temp++] = sa[i]-k;
			// Stable counting sort by first key (current rank) keeps that order.
			for(int i = 0; i < M; ++i) cnt[i] = 0;
			for(int i = 0; i <= len; ++i) ++cnt[p1[p2[i]]];
			for(int i = 1; i < M; ++i) cnt[i] += cnt[i-1];
			for(int i = len; i >= 0; --i) sa[--cnt[p1[p2[i]]]] = p2[i];
			// Re-rank: two neighbours get the same rank iff both halves match.
			// The second comparison is only evaluated when the first halves
			// match; the unique sentinel is relied on to keep sa[i]+k <= len there.
			swap(p1, p2);
			temp = 1;
			p1[sa[0]] = 0;
			for(int i = 1; i <= len; ++i)
				p1[sa[i]] = p2[sa[i-1]] == p2[sa[i]] && p2[sa[i-1]+k] == p2[sa[i]+k]? temp-1: temp++;
			// Next round's alphabet is the set of distinct ranks.
			M = temp;
		}
	}
	// rank[p]: index of position p in sa[].
	// height[r]: LCP length of the suffixes at sa[r-1] and sa[r].
	int rank[MAXN], height[MAXN];
	// Kasai's algorithm: fill rank[] and height[] in O(len) from sa[].
	// Every rank except that of position len (the sentinel, expected at rank 0)
	// gets a height; height[0] is not written here.
	void getHeight()
	{
		int k = 0;
		for(int i = 0; i <= len; ++i)
			rank[sa[i]] = i;
		for(int i = 0; i < len; ++i)
		{
			// Moving from suffix i-1 to suffix i shrinks the LCP by at most one.
			if(k) --k;
			// Lexicographic predecessor of suffix i.
			int j = sa[rank[i]-1];
			// Extend the match; it stops at the latest at the unique '\0'.
			while(S[i+k] == S[j+k]) ++k;
			height[rank[i]] = k;
		}
	}
	// Disabled sparse-table RMQ for arbitrary lcp(a, b) queries (not needed
	// by this solution).
	/*
	int Log[MAXN];
	int table[MAXH][MAXN];
	void initLog()
	{
		Log[0] = -1;
		for(int i = 1; i < MAXN; ++i)
			Log[i] = (i&(i-1))?Log[i-1]: Log[i-1]+1;
	}
	void initRMQ()
	{
		for(int i = 1; i <= len; ++i)
			table[0][i] = height[i];
		for(int i = 1; (1 << i) <= len; ++i)
			for(int j = 1; j+(1 << i)-1 <= len; ++j)
				table[i][j] = min(table[i-1][j], table[i-1][j+(1 << (i-1))]);
	}
	int lcp(int a, int b)
	{
		a = rank[a];
		b = rank[b];
		if(a > b) swap(a, b);
		++a;
		int temp = Log[b-a+1];
		return min(table[temp][a], table[temp][b-(1 << temp)+1]);
	}
	*/
}sa;

int len1;	// length of the first string; the '*' separator sits at position len1
int len2;	// length of the whole concatenated text (first + '*' + second)

// Tell which input string owns the suffix starting at text position `ind`:
// 0 for positions 0..len1 (the first string and the separator), 1 afterwards.
inline int idx(int ind)
{
	if(ind > len1)
		return 1;
	return 0;
}

LL ans;
int I;
int cnt[MAXN][2];

struct STACK
{
	LL arr[MAXN][3];
	int top;
	void init(int ind)
	{
		top = 0;
		arr[top][0] = 0;        //arr[top][2]与前面的串1可以形成多少相同串
		arr[top][1] = 0;		//arr[top][2]与前面的串2可以形成多少相同串
		arr[top][2] = ind;
	}
	void push(int id, int ind)
	{
		while(top && sa.height[ind] <= sa.height[arr[top][2]])
			--top;
		LL t[2];
		t[0] = arr[top][0]+(cnt[ind-1][0]-cnt[arr[top][2]-1][0])*(sa.height[ind]-K+1);
		t[1] = arr[top][1]+(cnt[ind-1][1]-cnt[arr[top][2]-1][1])*(sa.height[ind]-K+1);
		ans += t[!id];
		++top;
		arr[top][0] = t[0];
		arr[top][1] = t[1];
		arr[top][2] = ind;
	}
}st;

int main()
{
	// Each test case is K followed by the two strings; K == 0 ends the input.
	// Also stop on EOF instead of looping forever on a stale K.
	while(scanf("%d", &K) == 1 && K)
	{
		// Concatenate as  A '*' B '\0'  so one suffix array covers both
		// strings. '*' is assumed not to occur in the input, so no common
		// prefix can run across the separator.
		scanf("%s", sa.S);
		len1 = strlen(sa.S);
		sa.S[len1] = '*';
		scanf("%s", sa.S+len1+1);
		len2 = strlen(sa.S);	// length of A '*' B
		sa.init(len2);
		sa.getHeight();
		ans = 0;
		// Walk the suffix array in maximal groups of consecutive positions
		// whose adjacent LCPs (height) are all >= K. Only suffixes in the same
		// group can share a common prefix of length >= K.
		for(I = 0; I <= len2; )
		{
			// A suffix shorter than K cannot take part in any match. With
			// K >= 1 this also skips I == 0 (sa[0] should be the '\0' suffix at
			// position len2), so cnt[I-1] below never indexes cnt[-1].
			if(sa.sa[I]+K > len2)
			{
				++I;
				continue;
			}
			cnt[I-1][0] = cnt[I-1][1] = 0;
			int id = idx(sa.sa[I]);
			cnt[I][id] = 1;
			cnt[I][!id] = 0;
			st.init(I);
			int j = I+1;
			while(j <= len2 && sa.height[j] >= K)
			{
				id = idx(sa.sa[j]);
				cnt[j][id] = cnt[j-1][id]+1;
				cnt[j][!id] = cnt[j-1][!id];
				// Adds lcp(i, j) - K + 1 to ans for every earlier suffix i of
				// this group that belongs to the other string.
				st.push(id, j);
				++j;
			}
			I = j;
		}
		// "%I64d" is a Microsoft-only length modifier; streaming the
		// long long is portable across C runtimes.
		std::cout << ans << '\n';
	}
	return 0;
}





你可能感兴趣的:(数据结构,串)