练习一下字符串,做一下这道题。
首先是关于一个字符串有多少不同子串的问题,串由小到大排起序来应该是按照sa[i]的顺序排出来的产生的。
好像abbacd,排序出来的后缀是这样的
1---abbacd 第一个串产生的6个前缀都是新的子串
2---acd 第二个串除了和上一个串的前缀1 3-1=2 产生了2个子串
3---bacd 4-0=4
4---bbacd 5-1=4
5---cd 2-0=0
6---d 1-0=0
所以所有不同的前缀应该是(len-sa[i])-lcp[i-1]的和,即串长减去与上一个串的最长公共前缀,然后求和。
所以我们可以预处理出dp[i]表示sa[i]的后缀所产生的新串的个数,然后对dp[i]求一次前缀和,那么每次询问第k大的串的时候就可以直接lower_bound,找出串的左端和右端。但是这个(l,r)不一定是最小的,最小的可能是在sa[i+1]..sa[i+2]...里产生,所以我们首先要二分出合法的sa边界,即sa[i]....sa[j]里都可以产生的这个串,然后sa[i]...sa[j]的最小值即是我们要求的。写二分总是要跪要跪的- -0
#include <iostream> #include <cstdio> #include <cstring> #include <vector> #include <algorithm> #include <string> #include <numeric> #include <cassert> using namespace std; #define maxn 120000 #define ll long long struct SuffixArray { int n; int m[2][maxn]; int sa[maxn]; char s[maxn]; void indexSort(int sa[], int ord[], int id[], int nId){ static int cnt[maxn]; memset(cnt, 0, sizeof(0)*nId); for (int i = 0; i < n; i++){ cnt[id[i]]++; } partial_sum(cnt, cnt + nId, cnt); for (int i = n - 1; i >= 0; i--){ sa[--cnt[id[ord[i]]]] = ord[i]; } } int *id, *oId; void init(){ n = strlen(s) + 1; static int w[maxn]; for (int i = 0; i <= n; i++) w[i] = s[i]; sort(w, w + n); int nId = unique(w, w + n) - w; id = m[0], oId = m[1]; for (int i = 0; i < n; i++){ id[i] = lower_bound(w, w + nId, s[i]) - w; } static int ord[maxn]; for (int i = 0; i < n; i++){ ord[i] = i; } indexSort(sa, ord, id, nId); for (int k = 1; k <= n&&nId < n; k <<= 1){ int cur = 0; for (int i = n - k; i < n; i++){ ord[cur++] = i; } for (int i = 0; i < n; i++){ if (sa[i] >= k) ord[cur++] = sa[i] - k; } indexSort(sa, ord, id, nId); cur = 0; swap(oId, id); for (int i = 0; i < n; i++){ int c = sa[i], p = i ? sa[i - 1] : 0; id[c] = (i == 0 || oId[c] != oId[p] || oId[c + k] != oId[p + k]) ? cur++ : cur - 1; } nId = cur; } } // lcp relevant int rk[maxn], lcp[maxn]; void getlcp(){ for (int i = 0; i < n; i++) rk[sa[i]] = i; int h = 0; lcp[0] = 0; for (int i = 0; i < n; i++){ int j = sa[rk[i] - 1]; for (h ? h-- : 0; i + h < n&&j + h < n&&s[i + h] == s[j + h]; h++); lcp[rk[i] - 1] = h; } } // lcp query relevant int d[maxn + 50][25]; int mi[maxn+50][25]; void getrmq(){ for (int i = 0; i < n; i++) d[i][0] = lcp[i]; for (int j = 1; (1 << j) < n; j++){ for (int i = 0; (i + (1 << j) - 1) < n; i++){ d[i][j] = min(d[i][j - 1], d[i + (1 << (j - 1))][j - 1]); } } for(int i=0;i<n;i++) mi[i][0]=sa[i]; for (int j = 1; (1 << j) < n; j++){ for (int i = 0; (i + (1 << j) - 1) < n; i++){ mi[i][j] = min(mi[i][j - 1], mi[i + (1 << (j - 1))][j - 1]); } } } int rmq_query3(int l,int r){ if(l==r) return mi[l][0]; int k=0;int len=r-l+1; while((1<<(k+1))<len) ++k; return min(mi[l][k], mi[r - (1 << k) + 1][k]); } int rmq_query(int l, int r){ if(l==r) return n-1-sa[l]; if (l > r) swap(l, r); r -= 1; int k = 0; int len = r - l + 1; while ((1 << (k + 1)) < len) k++; return min(d[l][k], d[r - (1 << k) + 1][k]); } int rmq_query2(int l, int r){ l = rk[l], r = rk[r]; if (l > r) swap(l, r); r -= 1; int k = 0; int len = r - l + 1; while ((1 << (k + 1)) < len) k++; return min(d[l][k], d[r - (1 << k) + 1][k]); } }sa; int nQ; ll dp[maxn]; int n; int main() { while(~scanf("%s",sa.s)){ sa.init(); sa.getlcp(); sa.getrmq(); n=sa.n-1; dp[0]=0; for(int i=1;i<=n;++i){ dp[i]=n-sa.sa[i]-sa.lcp[i-1]; dp[i]+=dp[i-1]; } ll ansl=0,ansr=0; ll ki; scanf("%d",&nQ); while(nQ--){ scanf("%I64d",&ki); ki=(ki^ansl^ansr)+1; if(ki>dp[n]){ ansl=ansr=0; printf("%d %d\n",ansl,ansr); continue; } int tl,tr; int id=lower_bound(dp,dp+n+1,ki)-dp; tl=sa.sa[id]; tr=tl+sa.lcp[id-1]+ki-dp[id-1]-1; int len=tr-tl+1; int lf=id,rf=n; while(lf<rf){ int mid=(lf+rf+1)>>1; if(sa.rmq_query(id,mid) >= len) lf=mid; else rf=mid-1; } ansl=sa.rmq_query3(id,lf)+1; ansr=ansl+len-1; printf("%I64d %I64d\n",ansl,ansr); } } return 0; }