Codeforces Round #146 (Div. 1) C - Cyclical Quest (后缀自动机SAM)

http://codeforces.com/problemset/problem/235/C

陈立杰出的后缀自动机。

题目大意:给一个字符串S,再给一个字符串T,设T的长度为len,问T的循环串在S中出现的次数,这里循环串的定义是:对于一个长度为len的字符串,我们把它首尾相接,然后从任意位置开始走len步所得到的串我们叫做T的循环串。如abaa的循环串有 abaa,baaa,aaab,aaba。(注意如果重复只算一次。比如aaa的循环串只有一个aaa)


思路:对字符串S构造后缀自动机。对于每一个询问串T(长度为len),设T'为T去掉最后一个字符所得到的字符串,构造TT'并在S的后缀自动机上进行匹配,这样可以求出以TT'的每一个位置结尾时能够匹配的最大长度。当匹配长度大于等于len时,设当前所在状态为p,沿fa链向上跳,直到父状态的最长长度小于len为止,此时所在的状态记为q(即代码中当val[fa[p]] >= len时不断令p = fa[p]);设状态q所表示的子串出现的次数为q->num(代码中记作sc),则ans += q->num。num的计算还是通过拓扑排序(按val计数排序),自底向上累加即可。注意这里有可能有重复(不同的起点可能得到同一个循环串),所以我们还得在每一个状态里设一个标记flag,表示当前询问中该状态是否已被计算过,若已计算过则跳过即可。



数组写法(代码量少,空间小,访问速度快)

//155 ms 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;

// Contest shorthand macros.
// REP(i, n): loop header that runs i over 1, 2, ..., n (1-based, inclusive).
#define REP(i, n) for (int i = 1; i <= int(n); ++i)
// MP / PB: short spellings of make_pair and push_back.
#define MP make_pair
#define PB push_back
// SZ(x): size of a container converted to int.
#define SZ(x) (int((x).size()))
// ALL(x): the begin/end iterator pair of x, for calls that take a range.
#define ALL(x) (x).begin(), (x).end()
// X / Y: short spellings of the two members of a pair.
#define X first
#define Y second
// Lowers `a` to `b` when `b` is strictly smaller.
// Returns true iff `a` was changed.
// NOTE(review): the other listings in this post call this helper chkmin; the
// name here looks like a search-and-replace artifact, but it is kept as is.
template <class T>
inline bool sonkmin(T &a, const T &b) {
  if (a > b) {
    a = b;
    return true;
  }
  return false;
}
// Raises `a` to `b` when `b` is strictly larger.
// Returns true iff `a` was changed.
// NOTE(review): the other listings in this post call this helper chkmax; the
// name here looks like a search-and-replace artifact, but it is kept as is.
template <class T>
inline bool sonkmax(T &a, const T &b) {
  if (a < b) {
    a = b;
    return true;
  }
  return false;
}

// Numeric shorthands.
typedef long long LL;
typedef long double LD;
// Large sentinel value (the byte 0x3f repeated four times).
const int INF = 0x3f3f3f3f;

// Capacity bound used to size the fixed buffers; the automaton arrays and
// `flag` are allocated with twice this many slots.
const int N = 1e6 + 10;


// Scratch list of the state ids collected by SAM::solve for the current
// query; solve empties it again before returning.
std::vector<int> ans;
// flag[p] is true while state p is already present in `ans`.
bool flag[N << 1];
namespace SAM {
  int sz = 0, rt = 0, last = 0;
  int son[N << 1][26], fa[N << 1], val[N << 1], sc[N << 1];
  void init() {
    for(int i = 0; i <= sz; i ++) {
        memset(son[i], 0, sizeof(son[i]));
        fa[i] = val[i] = sc[i] = 0;
    }
    sz = 0; rt = ++ sz; last = rt;
  }
  void add(int c) {
    int p = last, np = ++ sz;
    last = np; val[np] = val[p] + 1;
    sc[np] = 1;
    for (; p && !son[p][c]; p = fa[p]) son[p][c] = np;
    if (p) {
      int q = son[p][c];
      if (val[p] + 1 == val[q]) fa[np] = q;
      else {
        int nq = ++ sz; 
        memcpy(son[nq], son[q], sizeof(son[q]));
        fa[nq] = fa[q], val[nq] = val[p] + 1;
        fa[q] = fa[np] = nq;
        for (; p && son[p][c] == q; p = fa[p]) son[p][c] = nq;
      }
    }
    else fa[np] = rt;
  }
  void getRight(char *s, int n) {
    static int Q[N << 1];
    static int cnt[N];
    for (int i = 0; i <= n; ++ i) cnt[i] = 0;
    for (int p = rt; p <= sz; ++ p) cnt[val[p]] ++;
    for (int i = 1; i <= n; ++ i) cnt[i] += cnt[i - 1];
    for (int p = rt; p <= sz; ++ p) Q[-- cnt[val[p]]] = p;
    for (int i = sz - 1; i >= 0; -- i) {
        int p = Q[i]; if (fa[p]) sc[fa[p]] += sc[p];
    }
  }
  void build(char *s, int n) {
    init();
    for (int i = 0; i < n; ++ i) add(s[i] - 'a');
    getRight(s, n);
  }

  int solve(char *s, int n) {
    int res = 0, len = (n + 1) / 2;
    int p = rt;
    int matson_len = 0;
    for(int i = 0; i < n; i ++) {
        int c = s[i] - 'a';
        while(p && !son[p][c]) p = fa[p];
        if(p) {
            matson_len = min(matson_len, val[p]) + 1;
            p = son[p][c];
        } else p = rt, matson_len = 0;
        if(matson_len >= len) {
            while(val[fa[p]] >= len) {
                p = fa[p];
                matson_len = min(matson_len, val[p]);
            }
            if(flag[p]) continue;
            flag[p] = true;
            ans.push_back(p);
        }
    }
    for(int i = 0; i < ans.size(); i ++)
        flag[ans[i]] = 0, res += sc[ans[i]];
    ans.clear();
    return res;
  }
}

// Shared input buffer: first the text, then each query in turn. It is sized
// 2 * N because a query of length len is extended in place to 2 * len - 1
// characters; a buffer of only N is overrun by any query longer than
// (N + 1) / 2. Assumes each input token is shorter than N characters.
char str[N << 1];

// Reads the text, builds its suffix automaton, then answers m queries, one
// integer per line.
int main() {
    if (scanf("%s", str) != 1) return 0;
    SAM::build(str, static_cast<int>(strlen(str)));
    int m;
    if (scanf("%d", &m) != 1) return 0;
    while (m--) {
        if (scanf("%s", str) != 1) break;
        int len = static_cast<int>(strlen(str));
        // Append the first len - 1 characters so that every rotation of the
        // query is a length-len window of str[0 .. 2 * len - 2].
        for (int i = 0; i < len - 1; i++) str[len + i] = str[i];
        printf("%d\n", SAM::solve(str, 2 * len - 1));
    }
    return 0;
}


结构体写法(结构清晰,代码量长,访问较慢)

//280 ms
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;

// Contest shorthand macros.
// REP(i, n): loop header that runs i over 1, 2, ..., n (1-based, inclusive).
#define REP(i, n) for (int i = 1; i <= int(n); ++i)
// MP / PB: short spellings of make_pair and push_back.
#define MP make_pair
#define PB push_back
// SZ(x): size of a container converted to int.
#define SZ(x) (int((x).size()))
// ALL(x): the begin/end iterator pair of x, for calls that take a range.
#define ALL(x) (x).begin(), (x).end()
// X / Y: short spellings of the two members of a pair.
#define X first
#define Y second
// Lowers `a` to `b` when `b` is strictly smaller.
// Returns true iff `a` was changed.
template <class T>
inline bool chkmin(T &a, const T &b) {
  if (a > b) {
    a = b;
    return true;
  }
  return false;
}
// Raises `a` to `b` when `b` is strictly larger.
// Returns true iff `a` was changed.
template <class T>
inline bool chkmax(T &a, const T &b) {
  if (a < b) {
    a = b;
    return true;
  }
  return false;
}

// Numeric shorthands.
typedef long long LL;
typedef long double LD;
// Large sentinel value (the byte 0x3f repeated four times).
const int INF = 0x3f3f3f3f;

// Pair of ints. NOTE(review): the template arguments were missing from this
// line; <int, int> is the conventional meaning of `pii` — confirm.
typedef std::pair<int, int> pii;

// Capacity bound used to size the fixed buffers; the state pool and `flag`
// are allocated with twice this many slots.
const int N = 1e6 + 10;


// One automaton state. ch[c] is the id of the state reached on letter c
// (0 = no transition), fa the suffix link, val the longest length of the
// state and sc its occurrence counter. States live in the fixed array `pool`
// and are referred to by index.
struct Node {
  int ch[26];
  int fa;
  int val;
  int sc;

  // A freshly constructed state is blank.
  Node() { clear(); }

  // Resets every field to zero.
  void clear() {
    for (int c = 0; c < 26; ++c) ch[c] = 0;
    fa = 0;
    val = 0;
    sc = 0;
  }
} pool[N << 1];
// Scratch list of the state ids collected by SAM::solve for the current
// query; solve empties it again before returning.
std::vector<int> ans;
// flag[p] is true while state p is already present in `ans`.
bool flag[N << 1];
namespace SAM {
  int sz = 0, rt = 0, last = 0;
  void init() {
    for(int i = 0; i <= sz; i ++) pool[i].clear();
    sz = 0; rt = ++ sz; last = rt;
  }
  void add(int c) {
    int p = last, np = ++ sz;
    last = np; pool[np].val = pool[p].val + 1;
    pool[np].sc = 1;
    for (; p && !pool[p].ch[c]; p = pool[p].fa) pool[p].ch[c] = np;
    if (p) {
      int q = pool[p].ch[c];
      if (pool[p].val + 1 == pool[q].val) pool[np].fa = q;
      else {
        int nq = ++ sz; 
        memcpy(&pool[nq], &pool[q], sizeof(pool[q]));
        pool[nq].val = pool[p].val + 1, pool[nq].sc = 0;
        pool[q].fa = nq; pool[np].fa = nq;
        for (; p && pool[p].ch[c] == q; p = pool[p].fa) pool[p].ch[c] = nq;
      }
    }
    else pool[np].fa = rt;
  }
  void getRight(char *s, int n) {
    static int Q[N << 1];
    static int cnt[N];
    for (int i = 0; i <= n; ++ i) cnt[i] = 0;
    for (int p = rt; p <= sz; ++ p) cnt[pool[p].val] ++;
    for (int i = 1; i <= n; ++ i) cnt[i] += cnt[i - 1];
    for (int p = rt; p <= sz; ++ p) Q[-- cnt[pool[p].val]] = p;
    for (int i = sz - 1; i >= 0; -- i) {
        int p = Q[i]; if (pool[p].fa) pool[pool[p].fa].sc += pool[p].sc;
    }
  }
  void build(char *s, int n) {
    init();
    for (int i = 0; i < n; ++ i) add(s[i] - 'a');
    getRight(s, n);
  }

  int solve(char *s, int n) {
    int res = 0, len = (n + 1) / 2;
    int p = rt;
    int match_len = 0;
    for(int i = 0; i < n; i ++) {
        int c = s[i] - 'a';
        while(p && !pool[p].ch[c]) p = pool[p].fa;
        if(p) {
            match_len = min(match_len, pool[p].val) + 1;
            p = pool[p].ch[c];
        } else p = rt, match_len = 0;
        if(match_len >= len) {
            while(pool[pool[p].fa].val >= len) {
                p = pool[p].fa;
                match_len = min(match_len, pool[p].val);
            }
            if(flag[p]) continue;
            flag[p] = true;
            ans.push_back(p);
        }
    }
    for(int i = 0; i < ans.size(); i ++)
        flag[ans[i]] = 0, res += pool[ans[i]].sc;
    ans.clear();
    return res;
  }
}

// Shared input buffer: first the text, then each query in turn. It is sized
// 2 * N because a query of length len is extended in place to 2 * len - 1
// characters; a buffer of only N is overrun by any query longer than
// (N + 1) / 2. Assumes each input token is shorter than N characters.
char str[N << 1];

// Reads the text, builds its suffix automaton, then answers m queries, one
// integer per line.
int main() {
    if (scanf("%s", str) != 1) return 0;
    SAM::build(str, static_cast<int>(strlen(str)));
    int m;
    if (scanf("%d", &m) != 1) return 0;
    while (m--) {
        if (scanf("%s", str) != 1) break;
        int len = static_cast<int>(strlen(str));
        // Append the first len - 1 characters so that every rotation of the
        // query is a length-len window of str[0 .. 2 * len - 2].
        for (int i = 0; i < len - 1; i++) str[len + i] = str[i];
        printf("%d\n", SAM::solve(str, 2 * len - 1));
    }
    return 0;
}

结构体指针写法(结构清晰,代码量小,访问慢,指针空间消耗大)


//280 ms
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;

// Contest shorthand macros.
// REP(i, n): loop header that runs i over 1, 2, ..., n (1-based, inclusive).
#define REP(i, n) for (int i = 1; i <= int(n); ++i)
// MP / PB: short spellings of make_pair and push_back.
#define MP make_pair
#define PB push_back
// SZ(x): size of a container converted to int.
#define SZ(x) (int((x).size()))
// ALL(x): the begin/end iterator pair of x, for calls that take a range.
#define ALL(x) (x).begin(), (x).end()
// X / Y: short spellings of the two members of a pair.
#define X first
#define Y second
// Lowers `a` to `b` when `b` is strictly smaller.
// Returns true iff `a` was changed.
template <class T>
inline bool chkmin(T &a, const T &b) {
  if (a > b) {
    a = b;
    return true;
  }
  return false;
}
// Raises `a` to `b` when `b` is strictly larger.
// Returns true iff `a` was changed.
template <class T>
inline bool chkmax(T &a, const T &b) {
  if (a < b) {
    a = b;
    return true;
  }
  return false;
}

// Numeric shorthands.
typedef long long LL;
typedef long double LD;
// Large sentinel value (the byte 0x3f repeated four times).
const int INF = 0x3f3f3f3f;

// Pair of ints. NOTE(review): the template arguments were missing from this
// line; <int, int> is the conventional meaning of `pii` — confirm.
typedef std::pair<int, int> pii;

// Capacity bound used to size the fixed buffers; the state pool and `flag`
// are allocated with twice this many slots.
const int N = 1e6 + 10;


// One automaton state. ch[c] points to the state reached on letter c (null =
// no transition), fa is the suffix link, val the longest length of the state
// and sc its occurrence counter. States live in the fixed array `pool`;
// `rt` and `last` point at the root and at the most recently added state.
struct Node {
  Node *ch[26];
  Node *fa;
  int val;
  int sc;

  // A freshly constructed state is blank.
  Node() { clear(); }

  // Nulls every pointer and zeroes both counters.
  void clear() {
    for (int c = 0; c < 26; ++c) ch[c] = 0;
    fa = 0;
    val = 0;
    sc = 0;
  }
} pool[N << 1], *rt, *last;
vector ans;
bool flag[N << 1];
namespace SAM {
  Node *sz = pool;
  void init() {
    if (sz != pool) {
      for (Node *p = pool; p < sz; ++ p) p->clear();
    }
    sz = pool; rt = sz ++; last = rt;
  }
  void add(int c) {
    Node *p = last, *np = sz ++;
    last = np; np->val = p->val + 1;
    np->sc = 1;
    for (; p && !p->ch[c]; p = p->fa) p->ch[c] = np;
    if (p) {
      Node *q = p->ch[c];
      if (p->val + 1 == q->val) np->fa = q;
      else {
        Node *nq = sz ++; *nq = *q;
        nq->sc = 0;
        nq->val = p->val + 1;
        q->fa = nq; np->fa = nq;
        for (; p && p->ch[c] == q; p = p->fa) p->ch[c] = nq;
      }
    }
    else np->fa = rt;
  }
  void getRight(char *s, int n) {
    static Node* Q[N << 1];
    static int cnt[N];
    for (int i = 0; i <= n; ++ i) cnt[i] = 0;
    for (Node *p = pool; p < sz; ++ p) cnt[p->val] ++;
    for (int i = 1; i <= n; ++ i) cnt[i] += cnt[i - 1];
    for (Node *p = pool; p < sz; ++ p) Q[-- cnt[p->val]] = p;
    for (int i = (sz - pool) - 1; i >= 0; -- i) {
      Node *p = Q[i]; if (p->fa) p->fa->sc += p->sc;
    }
  }
  void build(char *s, int n) {
    init();
    for (int i = 0; i < n; ++ i) add(s[i] - 'a');
    getRight(s, n);
  }

  int solve(char *s, int n) {
    int res = 0, len = (n + 1) / 2;
    Node *p = rt;
    int match_len = 0;
    for(int i = 0; i < n; i ++) {
        int c = s[i] - 'a';
        while(p && !p->ch[c]) p = p->fa;
        if(p) {
            match_len = min(match_len, p->val) + 1;
            p = p->ch[c];
        } else p = rt, match_len = 0;
        if(match_len >= len) {
            while(p != rt && p->fa->val >= len) {
                p = p->fa;
                match_len = min(match_len, p->val);
            }
            if(flag[p - pool]) continue;
            flag[p - pool] = true;
            ans.push_back(p);
        }
    }
    for(int i = 0; i < ans.size(); i ++)
        flag[ans[i] - pool] = 0, res += ans[i]->sc;
    ans.clear();
    return res;
  }
}

// Shared input buffer: first the text, then each query in turn. It is sized
// 2 * N because a query of length len is extended in place to 2 * len - 1
// characters; a buffer of only N is overrun by any query longer than
// (N + 1) / 2. Assumes each input token is shorter than N characters.
char str[N << 1];

// Reads the text, builds its suffix automaton, then answers m queries, one
// integer per line.
int main() {
    if (scanf("%s", str) != 1) return 0;
    SAM::build(str, static_cast<int>(strlen(str)));
    int m;
    if (scanf("%d", &m) != 1) return 0;
    while (m--) {
        if (scanf("%s", str) != 1) break;
        int len = static_cast<int>(strlen(str));
        // Append the first len - 1 characters so that every rotation of the
        // query is a length-len window of str[0 .. 2 * len - 2].
        for (int i = 0; i < len - 1; i++) str[len + i] = str[i];
        printf("%d\n", SAM::solve(str, 2 * len - 1));
    }
    return 0;
}


你可能感兴趣的:(后缀自动机)