先对原串建立后缀自动机。考虑每个询问。询问的子串相当于考虑它的n个循环串,因此我们把询问串连接在询问串,那么这个新串的长度为n的子串就是所要求的串。把新串放在后缀自动机上面跑,同时记录现在新串匹配的最长后缀len。如果len>=当前串长度n,那么我们就沿着fa指针跳到满足len>=n且长度最小的节点。那么原点到这个节点必然有一条路径的字符串是当前匹配的长度为n的后缀。这个节点的right集合大小就是所求的答案。但是可能会出现重复,每个节点记录一个vis值就可以了。。累加答案即可。。。
#include <iostream> #include <queue> #include <stack> #include <map> #include <set> #include <bitset> #include <cstdio> #include <algorithm> #include <cstring> #include <climits> #include <cstdlib> #include <cmath> #include <time.h> #define maxn 1000005 #define maxm 2000005 #define eps 1e-7 #define mod 1000000007 #define INF 0x3f3f3f3f #define PI (acos(-1.0)) #define lowbit(x) (x&(-x)) #define mp make_pair #define ls o<<1 #define rs o<<1 | 1 #define lson o<<1, L, mid #define rson o<<1 | 1, mid+1, R #define pii pair<int, int> #pragma comment(linker, "/STACK:16777216") typedef long long LL; typedef unsigned long long ULL; //typedef int LL; using namespace std; LL qpow(LL a, LL b){LL res=1,base=a;while(b){if(b%2)res=res*base;base=base*base;b/=2;}return res;} LL powmod(LL a, LL b){LL res=1,base=a;while(b){if(b%2)res=res*base%mod;base=base*base%mod;b/=2;}return res;} // head const int alpha = 26; struct node { int len, cnt, vis; node *ch[alpha], *fa; }*last, *tail, pool[maxm], *root; node *tp[maxm]; char ss[maxm]; char s[maxm]; int c[maxm]; int n, m; node* newnode(int len) { tail->len = len; tail->cnt = tail->vis = 0; tail->fa = NULL; memset(tail->ch, 0, sizeof tail->ch); return tail++; } void init() { tail = pool; root = last = newnode(0); memset(c, 0, sizeof c); } void add(int c) { node *p = last, *np = newnode(p->len + 1); last = np; for(; p && !p->ch[c]; p = p->fa) p->ch[c] = np; if(!p) np->fa = root; else { node *q = p->ch[c]; if(q->len == p->len + 1) np->fa = q; else { node *nq = newnode(p->len + 1); *nq = *q; nq->len = p->len + 1; np->fa = q->fa = nq; for(; p && p->ch[c] == q; p = p->fa) p->ch[c] = nq; } } } void solve(int T) { int ans = 0; node *p = root; int len = 0; int n = strlen(s); for(int i = n; i < 2 * n; i++) s[i] = s[i-n]; for(int i = 0; i < 2 * n; i++) { int t = s[i] - 'a'; while(p && !p->ch[t]) { p = p->fa; if(p) len = p->len; else len = 0; } if(p) p = p->ch[t], len++; else p = root, len = 0; if(len >= n) { while(p->fa->len >= n) p = p->fa, len = p->len; if(p->vis != T) p->vis = T, ans += p->cnt; } } printf("%d\n", ans); } void work() { scanf("%s", ss); init(); for(int i = 0; ss[i]; i++) add(ss[i] - 'a'); node *o = root; for(int i = 0; ss[i]; i++) o = o->ch[ss[i] - 'a'], o->cnt++; int n = strlen(ss); for(node *p = pool; p != tail; p++) c[p->len]++; for(int i = 1; i <= n; i++) c[i] += c[i-1]; for(node *p = pool; p != tail; p++) tp[--c[p->len]] = p; int tot = tail - pool; for(int i = tot - 1; i > 0; i--) tp[i]->fa->cnt += tp[i]->cnt; scanf("%d", &m); for(int i = 1; i <= m; i++) { scanf("%s", s); solve(i); } } int main() { work(); return 0; }