【2020HDU多校】第三场1001(HDU6767)Tokitsukaze, CSL and Palindrome Game——回文自动机+树上倍增+Hash

题目链接
(因为Windows的栈和Linux的栈不同,我被卡了3个小时,又因为HDU那个老化的机子,我又被卡了3个小时(标程本地跑1.7s,HDU上能跑4s,我的代码本地跑2.1s HDU上直接爆掉6s,只能说6s的时限很合理,但是HDU的机子实在太差))

题意

给你一个回文串,然后 sjfcsl 各从中选出一个子串,保证子串也是回文。
随后还有另外一个串,这个串起始的时候为空,然后不断的随机向这个串中加入字符串。请问 sjfcsl 两人选的串哪一个会先出现

分析

基础结论

首先是根据《浅谈生成函数在掷骰子问题上的应用》可以得到题解中的那样一大段话

b o r d e r border border: 对于一个长度为 L L L 的序列 A A A,若 A [ 1 , i ] = A [ L − i + 1 , L ] A[1, i] = A[L − i + 1, L] A[1,i]=A[Li+1,L],则称 A [ 1 , i ] A[1, i] A[1,i] A A A 的一个 b o r d e r border border
本题只需要比较 S a . . b S_{a..b} Sa..b S c . . d S_{c..d} Sc..d b o r d e r border border 长度组成的序列的字典序即可。

具体可以参考这篇文章

接下来是怎么实现的过程

过程分析

首先是求算 b o r d e r border border
很明显是回文树就可以解决问题,找到当前区间的右端点所在回文树上的节点位置,然后沿着 fail 树向根前进,直到找到一个节点,节点的 len[i] 和给出的要求的长度相同,即找到了所给出的串。
根据这节点,我们沿着 fail 继续前进,可以得到一系列由 len 组成数列,这些数列就是这个字符串的 b o r d e r border border 序列

但是考虑到查找到 len[i] 和给出的要求的长度相同的过程耗时非常长,而整条 fail 树的路径是满足 len 递减的,所以考虑树上倍增的方法去找到对应的节点。

然后是比较两个序列的整条 fail 路径上的点的大小。这仍然非常耗时,所以继续考虑倍增的思路,去查找 len 值第一个不相等的地方。但是虽然能保证 len 是递减,但是没有办法保证倍增路径上的点相等可以推导出整条路径相等。所以增加一层判断。

我们对整个串进行 h a s h hash hash 使得所有的 fail 路径变成一条 h a s h hash hash 路径。我们定义 h a s h hash hash 函数为

h a s h [ i ] = ( a ∗ h a s h [ i − 1 ] + l e n [ i ] ) m o d   b hash[i] = (a*hash[i - 1] + len[i]) mod \space b hash[i]=(ahash[i1]+len[i])mod b

则对于任意一段区间( [ l , r ] [l, r] [l,r])的 h a s h hash hash 值为

h a s h [ l , r ] = h a s h [ r ] − h a s h [ l ] ∗ a r − l + 1 hash[l, r] = hash[r] - hash[l] * a^{r - l + 1} hash[l,r]=hash[r]hash[l]arl+1

预处理 a r − l + 1 a^{r - l + 1} arl+1 部分即可

为了保证准确率,通常会采用两组不同的 a , b a, b a,b 来保证算法的准确性,但是HDU的机子实在是有点差,我只能卡一组 a , b a, b a,b 的方式通过。

AC code

HDU上能够AC的code

#pragma GCC optimize(3, "Ofast", "inline")

#include 

using namespace std;

#define MAXN 100100
#define LOG 18
#define CHAR_NUM 30             // 字符集个数,注意修改下方的 (-'a')
#define ll long long
#define HASH 2

template<class __T>
inline void FastRead(__T &x) {
    x = 0;
    int ch = getchar();
    while ((ch > '9' || ch < '0') && ~ch) {
        ch = getchar();
    }
    while (ch <= '9' && ch >= '0') {
        x = x * 10 + ch - 48;
        ch = getchar();
    }
}

template<class T>
inline void FastPrint(T &x) {
    if (x > 9) FastPrint(x / 10);
    putchar(x % 10 + '0');
}

struct FUCK {
    int fa[MAXN][LOG], depth[MAXN];
    vector<int> link[MAXN];
    int len[MAXN];
//    ll hash[MAXN][HASH];
//    ll a[HASH], b[HASH], base[MAXN][HASH];
    int hash[MAXN][HASH];
    int a[HASH], b[HASH], base[MAXN][HASH];

    void dfs(int cur, int pre) {
        memset(fa[cur], 0, sizeof(fa[cur]));
        fa[cur][0] = pre;

        hash[cur][0] = (hash[pre][0] * a[0] + len[cur]) % b[0];
        hash[cur][1] = (hash[pre][1] * a[1] + len[cur]) % b[1];

        int p = 1;
        while (fa[cur][p - 1]) {
            fa[cur][p] = fa[fa[cur][p - 1]][p - 1];
            p++;
        }
        fa[cur][p] = 0;
        for (auto item : link[cur]) {
            depth[item] = depth[cur] + 1;
            dfs(item, cur);
        }
    }

    inline void init() {
        a[0] = 131;
        a[1] = 233;
        b[0] = 998244353;
        b[1] = 1e9 + 7;
        base[0][0] = 1;
        base[0][1] = 1;
        for (int i = 1; i < MAXN; ++i) {
            base[i][0] = int((ll)base[i - 1][0] * a[0] % b[0]);
            base[i][1] = int((ll)base[i - 1][1] * a[1] % b[1]);
        }
        depth[0] = 0;
        len[0] = -1;
    }

    void build() {
        queue<pair<int, int>> q;
        q.push({0, 0});
        while (!q.empty()) {
            auto cur = q.front();
            q.pop();
            memset(fa[cur.first], 0, sizeof(fa[cur.first]));
            fa[cur.first][0] = cur.second;

            hash[cur.first][0] = int(((ll)hash[cur.second][0] * a[0] + len[cur.first]) % b[0]);
            hash[cur.first][1] = int(((ll)hash[cur.second][1] * a[1] + len[cur.first]) % b[1]);

            int p = 1;
            while (fa[cur.first][p - 1]) {
                fa[cur.first][p] = fa[fa[cur.first][p - 1]][p - 1];
                p++;
            }
            fa[cur.first][p] = 0;
            for (auto item : link[cur.first]) {
                depth[item] = depth[cur.first] + 1;
                q.push({item, cur.first});
            }
        }
    }

    inline int find(int begin, int le) {
        if (len[begin] == le) return begin;
        for (int i = LOG - 1; i >= 0; --i)
            if (len[fa[begin][i]] > le)
                begin = fa[begin][i];
//        assert(len[fa[begin][0]] == le);
        return fa[begin][0];
    }

    inline int get(int l, int r, int id) {
        int lens = depth[r] - depth[l];
//        return (hash[r][id] + b[id] - (hash[l][id] * base[lens][id]) % b[id]) % b[id];
        int tmp = hash[r][id] - ((ll)hash[l][id] * base[lens][id]) % b[id];
        return tmp > 0 ? tmp : tmp + b[id];
    }

    inline void comp(int x, int y) {
        if (len[x] != len[y]) {
            if (len[x] < len[y]) /*cout << "sjfnb" << endl;*/ puts("sjfnb");
            else if (len[x] > len[y]) /*cout << "cslnb" << endl;*/ puts("cslnb");
            return;
        }
        for (int i = LOG; i >= 0; --i) {
            if (get(fa[x][i], x, 0) == get(fa[y][i], y, 0)) {
                x = fa[x][i];
                y = fa[y][i];
            }
        }
        if (len[x] < len[y]) /*cout << "sjfnb" << endl;*/ puts("sjfnb");
        else if (len[x] > len[y]) /*cout << "cslnb" << endl;*/ puts("cslnb");
        else /*cout << "draw" << endl;*/ puts("draw");
    }
} fuck;

struct PAM {
    int len[MAXN];
    int link[MAXN];
    int next[MAXN][CHAR_NUM];
    int pos[MAXN];
    int last;
    int tot;

    void init() {
        last = 1;
        tot = 2;
        link[0] = len[0] = -1;
        link[1] = len[1] = 0;
        memset(next[0], 0, sizeof(next[0]));
        memset(next[1], 0, sizeof(next[1]));
    }

    void build(char *s, int n) {
        for (int i = 0; i < n; ++i) {
            int ch = s[i] - 'a';
            int p = last, curLen;
            while (true) {
                curLen = len[p];
                if (i - curLen - 1 >= 0 && s[i - curLen - 1] == s[i])
                    break;
                p = link[p];
            }
            if (next[p][ch]) {
                last = next[p][ch];
                pos[i] = last;
                continue;
            }
            last = tot++;
            pos[i] = last;
            len[last] = len[p] + 2;

            memset(next[last], 0, sizeof(next[last]));

            next[p][ch] = last;

            if (len[last] == 1) {
                link[last] = 1;
                continue;
            }

            while (true) {
                p = link[p];
                curLen = len[p];
                if (i - curLen - 1 >= 0 && s[i - curLen - 1] == s[i]) {
                    link[last] = next[p][ch];
                    break;
                }
            }
        }
    }

    void clearLCA() {
        for (int i = 0; i < tot; ++i) {
            fuck.link[i].clear();
        }
    }

    void buildLCA() {
        for (int i = 1; i < tot; ++i) {
            fuck.link[link[i]].push_back(i);
            fuck.len[i] = len[i];
        }
        fuck.build();
    }
} pam;

char ss[MAXN];

void solve() {
    int T;
//    cin >> T;
//    scanf("%d", &T);
    FastRead(T);
    fuck.init();
    for (int ts = 0; ts < T; ++ts) {
        int n;
//        cin >> n >> ss;
//        scanf("%d%s", &n, &ss);
        FastRead(n);
        scanf("%s", &ss);
        pam.init();
        pam.build(ss, n);
        pam.clearLCA();
        pam.buildLCA();
        int q;
//        cin >> q;
//        scanf("%d", &q);
        FastRead(q);
        for (int qs = 0; qs < q; ++qs) {
            int a, b, c, d;
//            cin >> a >> b >> c >> d;
//            scanf("%d%d%d%d", &a, &b, &c, &d);
            FastRead(a);
            FastRead(b);
            FastRead(c);
            FastRead(d);
            fuck.comp(fuck.find(pam.pos[b - 1], b - a + 1),
                      fuck.find(pam.pos[d - 1], d - c + 1));
        }
    }
}

signed main() {
//    ios_base::sync_with_stdio(false);
//    cin.tie(nullptr);
//    cout.tie(nullptr);
#ifdef ACM_LOCAL
    freopen("1001.in", "r", stdin);
//    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    signed test_index_for_debug = 1;
    char acm_local_for_debug = 0;
    do {
        if (acm_local_for_debug == '$') exit(0);
        if (test_index_for_debug > 20)
            throw runtime_error("Check the stdin!!!");
        auto start_clock_for_debug = clock();
        solve();
        auto end_clock_for_debug = clock();
        cout << "Test " << test_index_for_debug << " successful" << endl;
        cerr << "Test " << test_index_for_debug++ << " Run Time: "
             << double(end_clock_for_debug - start_clock_for_debug) / CLOCKS_PER_SEC << "s" << endl;
        cout << "--------------------------------------------------" << endl;
    } while (cin >> acm_local_for_debug && cin.putback(acm_local_for_debug));
#else
    solve();
#endif
    return 0;
}

更加准确的解

#pragma GCC optimize(3, "Ofast", "inline")

#include 

using namespace std;

#define MAXN 100100
#define LOG 18
#define CHAR_NUM 30             // 字符集个数,注意修改下方的 (-'a')
#define ll long long
#define HASH 2

struct FUCK {
    int fa[MAXN][LOG], depth[MAXN];
    vector<int> link[MAXN];
    int len[MAXN];
    int hash[MAXN][HASH];
    int a[HASH], b[HASH], base[MAXN][HASH];

    void dfs(int cur, int pre) {
        memset(fa[cur], 0, sizeof(fa[cur]));
        fa[cur][0] = pre;

        hash[cur][0] = (hash[pre][0] * a[0] + len[cur]) % b[0];
        hash[cur][1] = (hash[pre][1] * a[1] + len[cur]) % b[1];

        int p = 1;
        while (fa[cur][p - 1]) {
            fa[cur][p] = fa[fa[cur][p - 1]][p - 1];
            p++;
        }
        fa[cur][p] = 0;
        for (auto item : link[cur]) {
            depth[item] = depth[cur] + 1;
            dfs(item, cur);
        }
    }

    inline void init() {
        a[0] = 131;
        a[1] = 233;
        b[0] = 998244353;
        b[1] = 1e9 + 7;
        base[0][0] = 1;
        base[0][1] = 1;
        for (int i = 1; i < MAXN; ++i) {
            base[i][0] = int((ll) base[i - 1][0] * a[0] % b[0]);
            base[i][1] = int((ll) base[i - 1][1] * a[1] % b[1]);
        }
        depth[0] = 0;
        len[0] = -1;
    }

    void build() {
        queue<pair<int, int>> q;
        q.push({0, 0});
        while (!q.empty()) {
            auto cur = q.front();
            q.pop();
            memset(fa[cur.first], 0, sizeof(fa[cur.first]));
            fa[cur.first][0] = cur.second;

            hash[cur.first][0] = int(((ll) hash[cur.second][0] * a[0] + len[cur.first]) % b[0]);
            hash[cur.first][1] = int(((ll) hash[cur.second][1] * a[1] + len[cur.first]) % b[1]);

            int p = 1;
            while (fa[cur.first][p - 1]) {
                fa[cur.first][p] = fa[fa[cur.first][p - 1]][p - 1];
                p++;
            }
            fa[cur.first][p] = 0;
            for (auto item : link[cur.first]) {
                depth[item] = depth[cur.first] + 1;
                q.push({item, cur.first});
            }
        }
    }

    inline int find(int begin, int le) {
        if (len[begin] == le) return begin;
        for (int i = LOG - 1; i >= 0; --i)
            if (len[fa[begin][i]] > le)
                begin = fa[begin][i];
        return fa[begin][0];
    }

    inline int get(int l, int r, int id) {
        int lens = depth[r] - depth[l];
        int tmp = hash[r][id] - ((ll) hash[l][id] * base[lens][id]) % b[id];
        return tmp > 0 ? tmp : tmp + b[id];
    }

    inline void comp(int x, int y) {
        if (len[x] != len[y]) {
            if (len[x] < len[y]) cout << "sjfnb" << endl;
            else if (len[x] > len[y]) cout << "cslnb" << endl;
            return;
        }
        for (int i = LOG; i >= 0; --i) {
            if (get(fa[x][i], x, 0) == get(fa[y][i], y, 0) &&
                get(fa[x][i], x, 1) == get(fa[y][i], y, 1)) {
                x = fa[x][i];
                y = fa[y][i];
            }
        }
        if (len[x] < len[y]) cout << "sjfnb" << endl;
        else if (len[x] > len[y]) cout << "cslnb" << endl;
        else cout << "draw" << endl;
    }
} fuck;

struct PAM {
    int len[MAXN];
    int link[MAXN];
    int next[MAXN][CHAR_NUM];
    int pos[MAXN];
    int last;
    int tot;

    void init() {
        last = 1;
        tot = 2;
        link[0] = len[0] = -1;
        link[1] = len[1] = 0;
        memset(next[0], 0, sizeof(next[0]));
        memset(next[1], 0, sizeof(next[1]));
    }

    void build(char *s, int n) {
        for (int i = 0; i < n; ++i) {
            int ch = s[i] - 'a';
            int p = last, curLen;
            while (true) {
                curLen = len[p];
                if (i - curLen - 1 >= 0 && s[i - curLen - 1] == s[i])
                    break;
                p = link[p];
            }
            if (next[p][ch]) {
                last = next[p][ch];
                pos[i] = last;
                continue;
            }
            last = tot++;
            pos[i] = last;
            len[last] = len[p] + 2;

            memset(next[last], 0, sizeof(next[last]));

            next[p][ch] = last;

            if (len[last] == 1) {
                link[last] = 1;
                continue;
            }

            while (true) {
                p = link[p];
                curLen = len[p];
                if (i - curLen - 1 >= 0 && s[i - curLen - 1] == s[i]) {
                    link[last] = next[p][ch];
                    break;
                }
            }
        }
    }

    void clearLCA() {
        for (int i = 0; i < tot; ++i) {
            fuck.link[i].clear();
        }
    }

    void buildLCA() {
        for (int i = 1; i < tot; ++i) {
            fuck.link[link[i]].push_back(i);
            fuck.len[i] = len[i];
        }
        fuck.build();
    }
} pam;

char ss[MAXN];

void solve() {
    int T;
    cin >> T;
//    scanf("%d", &T);
//    FastRead(T);
    fuck.init();
    for (int ts = 0; ts < T; ++ts) {
        int n;
        cin >> n >> ss;
//        scanf("%d%s", &n, &ss);
//        FastRead(n);
//        scanf("%s", &ss);
        pam.init();
        pam.build(ss, n);
        pam.clearLCA();
        pam.buildLCA();
        int q;
        cin >> q;
//        scanf("%d", &q);
//        FastRead(q);
        for (int qs = 0; qs < q; ++qs) {
            int a, b, c, d;
            cin >> a >> b >> c >> d;
//            scanf("%d%d%d%d", &a, &b, &c, &d);
//            FastRead(a);
//            FastRead(b);
//            FastRead(c);
//            FastRead(d);
            fuck.comp(fuck.find(pam.pos[b - 1], b - a + 1),
                      fuck.find(pam.pos[d - 1], d - c + 1));
        }
    }
}

signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
#ifdef ACM_LOCAL
    freopen("1001.in", "r", stdin);
//    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    signed test_index_for_debug = 1;
    char acm_local_for_debug = 0;
    do {
        if (acm_local_for_debug == '$') exit(0);
        if (test_index_for_debug > 20)
            throw runtime_error("Check the stdin!!!");
        auto start_clock_for_debug = clock();
        solve();
        auto end_clock_for_debug = clock();
        cout << "Test " << test_index_for_debug << " successful" << endl;
        cerr << "Test " << test_index_for_debug++ << " Run Time: "
             << double(end_clock_for_debug - start_clock_for_debug) / CLOCKS_PER_SEC << "s" << endl;
        cout << "--------------------------------------------------" << endl;
    } while (cin >> acm_local_for_debug && cin.putback(acm_local_for_debug));
#else
    solve();
#endif
    return 0;
}

你可能感兴趣的:(ACM,#,树上倍增,#,字符串)