Codeforces Round #305 (Div. 1)E. Mike and Friends【后缀数组+线段树】

传送门:Codeforces Round #305 (Div. 1)E. Mike and Friends

这题既然求得是子串,就可以用后缀数组来做(也可以用后缀自动机balabala)
我的方法是 O(nlog2n) O(nlogn) 的方法就是把线段树换成主席树来实现,但是我并不会高冷的主席树。
首先用类似于求N个串的最长公共子串那样,将所有串接在一起,然后用后缀数组得到sa数组。我们知道连接起来的串,每个下标都属于一个串,或者是分隔符,然后我们以sa数组作为下标建立线段树,线段树每个节点所表示的区间就是rank的区间啦。把这个rank区间内所有的下标(sa[i])属于的原本的串的编号提取出来,排个序,以待之后查找使用。
对于一个询问包含的串K,找到K的起点在连接起来的大串上的位置,我们可以二分其在sa数组上能延伸的最左端点以及最右端点,这个时候使用的是rmq,求最长公共前缀的方法。这样就可以使得我们得到的这个区间之间的所有串与K的LCP都大于等于K的串长,也就是说K是这个区间内所有的串的子串。
然后我们只要在线段树内统计编号大于等于i且小于等于j的数个数就好了,这个我们只要在每个被完全包含的区间内二分答案,然后累加即可。

// whn6325689
// Mr.Phoebe
// http://blog.csdn.net/u013007900
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")

using namespace std;

#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62
#define speed std::ios::sync_with_stdio(false);

typedef long long ll;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;

#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))

#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))

#define MID(x,y) (x+((y-x)>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n1

template<class T>
inline bool read(T &n)
{
    T x = 0, tmp = 1;
    char c = getchar();
    while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
    if(c == EOF) return false;
    if(c == '-') c = getchar(), tmp = -1;
    while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
    n = x*tmp;
    return true;
}
template <class T>
inline void write(T n)
{
    if(n < 0)
    {
        putchar('-');
        n = -n;
    }
    int len = 0,data[20];
    while(n)
    {
        data[len++] = n%10;
        n /= 10;
    }
    if(!len) data[len++] = 0;
    while(len--) putchar(data[len]+48);
}
//-----------------------------------

const int MAXN=500010;
const int LOGF=20;

vi t[MAXN<<2];
int sa[MAXN],rank[MAXN],height[MAXN];
int t1[MAXN],t2[MAXN],xy[MAXN],c[MAXN];

int len[MAXN],belong[MAXN],start[MAXN];
char str[MAXN];
int s[MAXN],dp[MAXN][LOGF],logn[MAXN];
int n,m;

int cmp(int *str,int a,int b,int d)
{
    return str[a]==str[b] && str[a+d]==str[b+d];
}

void get_height(int n,int k=0)
{
    for(int i=0;i<=n;i++)
        rank[sa[i]]=i;
    for(int i=0;i<n;i++)
    {
        if(k)k--;
        int j=sa[rank[i]-1];
        while(s[i+k]==s[j+k]) k++;
        height[rank[i]]=k;
    }
}

void gao(int n,int m,int t1[],int t2[])
{
    int *x=t1,*y=t2;
    for(int i=0;i<m;i++)    c[i]=0;
    for(int i=0;i<n;i++)    c[x[i]=s[i]]++;
    for(int i=1;i<m;i++)    c[i]+=c[i-1];
    for(int i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
    for(int d=1,p=0;p<n;d<<=1,m=p)
    {
        p=0;
        for(int i=n-d;i<n;i++)  y[p++]=i;
        for(int i=0;i<n;i++)    if(sa[i]>=d)    y[p++]=sa[i]-d ;
        for(int i=0;i<m;i++)    c[i]=0;
        for(int i=0;i<n;i++)    c[xy[i]=x[y[i]]]++ ;
        for(int i=1;i<m;i++)    c[i]+=c[i-1];
        for(int i=n-1;i>=0;i--) sa[--c[xy[i]]]=y[i] ;
        swap(x,y) ;
        p=0 ;
        x[sa[0]]=p++;
        for(int i=1;i<n;i++)    x[sa[i]]=cmp(y,sa[i-1],sa[i],d) ? p-1:p++;
    }
    get_height(n-1);
}

void init_rmq(int n)
{
    for(int i=1;i<=n;i++)   dp[i][0]=height[i];
    logn[1]=0;
    for(int i=2;i<=n;i++)   logn[i]=logn[i-1]+(i==lowbit(i));
    for(int j=1;(1<<j)<n;j++)
        for(int i=1;i+(1<<j)-1<=n;i++)
            dp[i][j]=min(dp[i][j-1],dp[i+(1<<(j-1))][j-1]);
}

int rmq(int l,int r)
{
    int k=logn[r-l+1];
    return min(dp[l][k],dp[r-(1<<k)+1][k]);
}

void build(int idx,int l,int r)
{
    t[idx].clear();
    for(int i=l;i<=r;i++)
        t[idx].pb(belong[sa[i]]);
    t[idx].pb(MAXN);
    sort(t[idx].begin(),t[idx].end());
    if(l==r)    return;
    int mid=MID(l,r);
    build(lson);build(rson);
}

int query(int L,int R,int x,int y,int idx,int l,int r)
{
    if(L==l && r==R)
    {
        int dw=lower_bound(t[idx].begin(),t[idx].end(),x)-t[idx].begin()-1;
        int up=lower_bound(t[idx].begin(),t[idx].end(),y+1)-t[idx].begin()-1 ;
        return up-dw;
    }
    int mid=MID(l,r);
    if(R<=mid) return query(L,R,x,y,lson);
    else if(L>mid) return query(L,R,x,y,rson);
    else    return query(L,mid,x,y,lson)+query(mid+1,R,x,y,rson);
}

int maxl(int len,int l,int r)
{
    int rr=r;
    while(l<r)
    {
        int mid=MID(l,r);
        if(rmq(mid+1,rr)>=len)  r=mid;
        else    l=mid+1;
    }
    return l;
}

int maxr(int len,int l,int r)
{
    int ll=l;
    while(l<r)
    {
        int mid=MID(l,r+1);
        if(rmq(ll+1,mid)>=len)  l=mid;
        else    r=mid-1;
    }
    return r;
}


int main()
{
// freopen("data.txt","r",stdin);
    while(read(n)&&read(m))
    {
        int n1=0,n2=27;
        for(int i=1;i<=n;i++)
        {
            scanf("%s",str);
            len[i]=strlen(str);
            start[i]=n1;
            for(int j=0;j<len[i];j++)
            {
                belong[n1]=i;
                s[n1++]=str[j]-'a'+1;
            }
            belong[n1]=0;
            s[n1++]=n2++;
        }
        s[--n1]=0;
        gao(n1+1,n2,t1,t2);
        init_rmq(n1);
        build(root);
        int x,y,k;
        while(m--)
        {
            read(x),read(y),read(k);
            int L=maxl(len[k],1,rank[start[k]]);
            int R=maxr(len[k],rank[start[k]],n1);
            int ans=query(L,R,x,y,root);
            write(ans),putchar('\n');
        }
    }
    return 0;
}

你可能感兴趣的:(后缀)