一道sam练习题

题意

有两个字符串A,B,求这两个字符串长度大于等于k的公共子串数量

数据范围

串 长 ≤ 1 e 5 , ∣ ∑ ∣ 小 写 字 母 串长\le 1e5 ,|\sum|小写字母 1e5,

解法

sam板子题…具体来讲,首先求出A的sam,然后把B串放在上面跑LCS的操作,求出对于B的每个前缀在A上的LCS.然后记下这个LCS的长度和位置.

然后做dfs,求出sam上每个节点的贡献就可以了,具体实现看代码

#include
using namespace std;
const int maxn=3e5+5;
inline int read(){
	char c=getchar();int t=0,f=1;
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
	return t*f;
}
int k,n,m;
long long ans;
char A[maxn],B[maxn];
struct sam{
	int ch[maxn][26],len[maxn],sz[maxn],fa[maxn],ed,ct,vis[maxn];
	vector<int> son[maxn],flag[maxn];
	inline void insert(int c){
		int np=++ct,p=ed;
		ed=ct;len[np]=len[p]+1;sz[np]=1;
		for(;p&&(!ch[p][c]);p=fa[p])ch[p][c]=np;
		if(!p)fa[np]=1;
		else{
			int q=ch[p][c];
			if(len[q]==len[p]+1){fa[np]=q;}
			else{
				int nq=++ct;
				len[nq]=len[p]+1;fa[nq]=fa[q];
				fa[q]=fa[np]=nq;
				for(int i=0;i<26;i++)ch[nq][i]=ch[q][i];
				for(;p&&(ch[p][c]==q);p=fa[p])ch[p][c]=nq;
			}
		}
	}
	void clear(){
		for(int i=1;i<=ct;i++){son[i].clear();flag[i].clear();}
		ed=ct=1;
		memset(ch,0,sizeof(ch));
		memset(fa,0,sizeof(fa));
		memset(len,0,sizeof(len));
		memset(sz,0,sizeof(sz));
		memset(vis,0,sizeof(vis));
	}
	void dfs(int u){
		for(int i=0;i<son[u].size();i++){
			int v=son[u][i];
			dfs(v);
			sz[u]+=sz[v];
		}
	}
	void dfs2(int u){
		for(int i=0;i<son[u].size();i++){
			int v=son[u][i];
			dfs2(v);vis[u]+=vis[v];
		}
		if(len[u]>=k){
			vis[u]-=flag[u].size();
			for(int i=0;i<flag[u].size();i++){
				ans=ans+1ll*(flag[u][i]-max(len[fa[u]]+1,k)+1)*sz[u];//计算u这个节点自己的答案 
			}
			ans=ans+1ll*vis[u]*(len[u]-max(len[fa[u]]+1,k)+1)*sz[u];//计算u的后辈节点带来的影响(这里的vis[u]是u的后辈节点的vis[u]之和) 
			vis[u]+=flag[u].size();//这里记得改回来 
		}
	}
}a;
signed main(){
	freopen("1.in","r",stdin);
	freopen("1.out","w",stdout);
	k=read();
	while(k){
		scanf("%s",A+1);
		scanf("%s",B+1);
		a.clear();
		ans=0;
		n=strlen(A+1);
		m=strlen(B+1);
		for(int i=1;i<=n;i++){
			a.insert(A[i]-'a');
		}
		for(int i=2;i<=a.ct;i++)a.son[a.fa[i]].push_back(i);
		a.dfs(1);
		int p=1,tot=0;
		for(int i=1;i<=m;i++){
			int c=B[i]-'a';
			if(a.ch[p][c]){
				p=a.ch[p][c];tot++;
			}
			else{
				while(p&&(!a.ch[p][c])){
					p=a.fa[p];
				}
				if(!p){
					p=1;tot=0;
				}
				else{
					tot=a.len[p]+1;p=a.ch[p][c];
				}
			}//求出B串的每个前缀和A串的LCS长度,以及对应的节点位置 
			if(tot>=k){
				a.flag[p].push_back(tot);
				a.vis[p]++;
			}
		}
		a.dfs2(1);
		printf("%lld\n",ans);
		k=read();
	}
	return 0;
}

你可能感兴趣的:(sam)