又是愣把串总长复杂度的题写成了给Trie大小的题
(我又可以造题啦!开森
做法是这样的,考虑一个串会有很多种拼接办法,我们取第二个前缀最短的办法。
对应到AC自动机上就是,你先枚举第二个前缀,假设在AC自动机上的节点是x,那么你不能选择一个y,存在一个从y出发的(从x跳到fail[x]丢失的那个前缀)的转移。考虑补集转化,就是统计有多少节点存在这样的转移(根除外,因为第一个串非空),换言之,就是统计这个串在Trie中出现了几次。这就是个广义SAM,直接上即可。特判fail[x]=rt的情况即可。
#include
#include
#include
#include
#include
#define SIG 26
#define LEN 32
#define N 300010
#define lint long long
#define debug(x) cerr<<#x<<"="<
#define sp <<" "
#define ln <
using namespace std;
queue<int> q;int stc[N],stc_sz;
char s[LEN];
struct SAM{
int val[N*2],node_cnt,rt,fa[N*2],ch[N*2][SIG];
vector<int> g[N*2];int sz[N*2];
SAM() { node_cnt=rt=new_node(0); }
inline int new_node(int v)
{
int x=++node_cnt;memset(ch[x],0,sizeof ch[x]);
return val[x]=v,fa[x]=sz[x]=0,node_cnt;
}
inline int extend(int las,int w)//return last
{
int p=las,np=new_node(val[p]+1);
while(p&&!ch[p][w]) ch[p][w]=np,p=fa[p];
if(!p) fa[np]=rt;
else{
int q=ch[p][w],v=val[p]+1;
if(val[q]==v) fa[np]=q;
else{
int nq=new_node(v);
fa[nq]=fa[q],fa[q]=fa[np]=nq;
memcpy(ch[nq],ch[q],sizeof ch[q]);
while(p&&ch[p][w]==q) ch[p][w]=nq,p=fa[p];
}
}
return sz[np]=1,np;
}
inline int get_sz(int x=0)
{
if(!x)
{
for(int i=1;i<=node_cnt;i++)
if(fa[i]) g[fa[i]].push_back(i);
return get_sz(rt);
}
for(int i=0;i<(int)g[x].size();i++)
sz[x]+=get_sz(g[x][i]);return sz[x];
}
}ts;
struct Trie{
int node_cnt,rt,las[N],dpt[N],pos[N],ch[N][SIG],fail[N];
Trie() { rt=node_cnt=new_node(); }
inline int new_node()
{
int x=++node_cnt;memset(ch[x],0,sizeof ch[x]);
return las[x]=dpt[x]=0,node_cnt;
}
inline int insert(char *s,int n)
{
for(int i=1,x=rt;i<=n;i++)
{
int c=s[i]-'a';
if(!ch[x][c]) ch[x][c]=new_node();
dpt[ch[x][c]]=dpt[x]+1,x=ch[x][c];
}
return 0;
}
inline int build_SAM(SAM &t)//before get_fail
{
while(!q.empty()) q.pop();
for(q.push(rt),las[rt]=t.rt;!q.empty();q.pop())
for(int x=q.front(),i=0,y;iif (y=ch[x][i]) las[y]=t.extend(las[x],i),q.push(y);
return t.get_sz();
}
inline int get_pos(const SAM &t)//before get_fail
{
pos[rt]=t.rt;while(!q.empty()) q.pop();
for(q.push(rt);!q.empty();q.pop())
for(int x=q.front(),i=0,y;iif (y=ch[x][i]) pos[y]=t.ch[pos[x]][i],q.push(y);
return 0;
}
inline int get_fail()
{
while(!q.empty()) q.pop();
for(int i=0;iint &y=ch[rt][i];if(y) fail[y]=rt,q.push(y);else y=rt; }
for(;!q.empty();q.pop())
for(int i=0,x=q.front();iint &y=ch[x][i],f=fail[x],c=ch[f][i];
if(y) fail[y]=c,q.push(y);else y=c;
}
return 0;
}
inline lint get_ans(int *sz,int x=0)
{
if(!x) return stc_sz=0,get_ans(sz,rt);
lint ans=(x!=rt)*(node_cnt-1);
if(x!=rt&&fail[x]!=rt) ans-=sz[pos[stc[dpt[x]-dpt[fail[x]]]]]-1;
for(int i=0,y;iif(dpt[y=ch[x][i]]==dpt[x]+1)
ans+=get_ans(sz,stc[++stc_sz]=y),stc_sz--;
return ans;
}
}t;
int main()
{
int n;scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%s",s+1),t.insert(s,(int)strlen(s+1));
t.build_SAM(ts),t.get_pos(ts),t.get_fail();
return !printf("%lld\n",t.get_ans(ts.sz));
}