HDU 4436 str2int 后缀数组 + 前缀和预处理 或 后缀自动机

题目大意:

对于给出的n个字符串(n <= 1e5), 每个字符串包含字符'0' ~ '9', 对于每个字符串的子串, 对应一个整数值, 求所有n个字符串的所有子串转换成整数后出现的不同的整数的和对 2012 取模的值, 每组测试数据的所有字符串的总长度不超过1e6

例如串"101"子串有 1, 10, 101, 0, 01, 1, 对应的不同整数是1, 10, 101 (0对求和没有影响, 可以略去)


大致思路:

首先第一眼看得出和后缀数组有关, 但是将所有的串连接起来之后, 有一些计数上的细节需要考虑, 首先如果后缀以'0'开头则不需要加入计数(这样的后缀的所有子串对应的整数一定会在其他的后缀当中出现, 或者子串对应的值为0, 没有影响), 对于一个后缀数组sa, 从第一个到最后一个, 如果肯定需要计数的是 sa[i] + height[i] ~ i所在字符串的结尾位置,那么这样的字符串如何计数呢? 很明显如果对于每一个后缀从sa[i]开始的话, 依次计算sa[i], sa[i]sa[i] + 1...对应的子串转换成整数的值的话肯定会超时, 所以需要进行预处理

定义前缀和sigma[i] = s0 + s0s1 + s0s1s2 + ... + (s0s1s2s2s4...si)这一连续的整数值定义rest[i] = s0s1s2s3s4...si

也就是说sigma[i]是rest[i]的前缀和

例如对于字符串 s = “12345” sigma[0] = 1, sigma[1] = 1 + 12, sigma[3] = 1 + 12 + 123...

rest[0] = 1, rest[1] = 12, rest[3] = 123...

对于用未出现字符隔开的连起来的总串, 预处理出sigma, rest数组, 用tens[i] 表示(∑10^j) % 2012 (1 <= j <= i)

就和容易找到连续的起点在sa[i], 终点在sa[i] + height[i] ~i所在字符串的整数值的和了

例如对于“12345” 查询3 + 34 + 345那么 3 + 34 + 345 = (123 + 1234 + 12345) - 12*(10 + 100 + 1000) = (sigma[4] - sigma[1] + 2012 - rest[1]*tens[3]) % 2012

想到这个前缀和的关系剩下的就很好做了


另外还有后缀自动机的做法:

将所有串中间中10隔开连接起来建立后缀自动机, 然后从根节点开始按照拓扑序向下遍历, 计算出到达每一个结点的方案数(不能沿着10走, 根节点还不能沿着0走), 然后对于状态 s 经过边 j 到达 t 状态, t 状态中从s转移来的字符串的贡献是 s状态中字符串的贡献*10 + j*s中不同满足条件的字符串数量, 拓扑序遍历即可


代码如下:

后缀数组解法:

Result  :  Accepted     Memory  :  17392 KB     Time  :  187 ms

/*
 * Author: Gatevin
 * Created Time:  2015/3/9 15:33:02
 * File Name: Kotori_Itsuka.cpp
 */
#include<cstdio>
#include<cstring>//之前写了很多头文件导致rank数组模糊定义了..看来HDU上交后缀数组还是要注意一下
using namespace std;
const double eps(1e-8);
typedef long long lint;

const int mod = 2012;

#define maxn 1000010

int wa[maxn], wb[maxn], wv[maxn], Ws[maxn];

int cmp(int *r, int a, int b, int l)
{
    return r[a] == r[b] && r[a + l] == r[b + l];
}

void da(int *r, int *sa, int n, int m)
{
    int *x = wa, *y = wb, *t, i, j, p;
    for(i = 0; i < m; i++) Ws[i] = 0;
    for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;
    for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
    for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;
    for(j = 1, p = 1; p < n; j *= 2, m = p)
    {
        for(p = 0, i = n - j; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
        for(i = 0; i < n; i++) wv[i] = x[y[i]];
        for(i = 0; i < m; i++) Ws[i] = 0;
        for(i = 0; i < n; i++) Ws[wv[i]]++;
        for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
        for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];
        for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
    return;
}

int rank[maxn], height[maxn];
void calheight(int *r, int *sa, int n)
{
    int i, j, k = 0;
    for(i = 1; i <= n; i++) rank[sa[i]] = i;
    for(i = 0; i < n; height[rank[i++]] = k)
        for(k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k++);
    return;
}

int n, N;
char in[maxn];
int s[maxn], sa[maxn];
int end[maxn];
int sigma[maxn];
int rest[maxn];
int tens[maxn];//tens[i] = (∑10^j) % mod (1 <= j <= i)

int main()
{
    while(scanf("%d", &n) != EOF)
    {
        N = 0;
        for(int i = 0; i < n; i++)
        {
            scanf("%s", in);
            int tmp = strlen(in);
            int pos = N + tmp - 1;
            for(int j = 0; j < tmp; j++)
            {
                end[N] = pos;
                s[N++] = in[j] - '0' + 1;
            }
            end[N] = pos;
            s[N++] = 11;
        }
        N--;
        s[N] = 0;
        da(s, sa, N + 1, 12);
        calheight(s, sa, N);
        memset(sigma, 0, sizeof(sigma));
        memset(rest, 0, sizeof(rest));
        memset(tens, 0, sizeof(tens));
        rest[0] = (s[0] - 1) % mod;
        sigma[0] = rest[0];
        tens[0] = 0;
        int ten = 1;
        for(int i = 1; i <= N; i++)
        {
            rest[i] = (rest[i - 1] * 10 + s[i] - 1) % mod;
            sigma[i] = (sigma[i - 1] + rest[i]) % mod;
            ten = ten*10 % mod;
            tens[i] = (tens[i - 1] + ten) % mod;
        }
        int ans = 0;
        for(int i = 1; i <= N; i++)
        {
            if(s[sa[i]] == 1) continue;
            int start = sa[i] + height[i];
            int tail = end[sa[i]];
            if(tail < start) continue;
            if(start == 0)
                ans = (ans + sigma[tail]) % mod;
            else
                ans = (ans + sigma[tail] - sigma[start - 1] + mod - rest[sa[i] - 1] * (tens[tail - sa[i] + 1] - tens[start - sa[i]] + mod) % mod + mod) % mod;
        }
        printf("%d\n", ans);
    }
    return 0;
}

后缀自动机的做法:

Result  :  Accepted     Memory  :  16732 KB     Time  :  249 ms

/*
 * Author: Gatevin
 * Created Time:  2015/4/16 13:04:40
 * File Name: Rin_Tohsaka.cpp
 */
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
#define foreach(e, x) for(__typeof(x.begin()) e = x.begin(); e != x.end(); ++e)
#define SHOW_MEMORY(x) cout<<sizeof(x)/(1024*1024.)<<"MB"<<endl

#define maxn 222000
#define maxm 111000

const int mod = 2012;

struct Suffix_Automation
{
    struct State
    {
        State *par;
        State *go[12];
        int right, val, mi, sum, cnt;
        void init(int _val)
        {
            par = 0, val = _val, right = mi = sum = cnt = 0;
            memset(go, 0, sizeof(go));
        }
    };
    State *root, *last, *cur;
    State nodePool[maxn];
    State* newState(int val = 0)
    {
        cur->init(val);
        return cur++;
    }
    void initSAM()
    {
        cur = nodePool;
        root = newState();
        last = root;
    }
    void extend(int w, int len)
    {
        State *p = last;
        State *np = newState(p->val + 1);
        np->right = 1;
        while(p && p->go[w] == 0)
        {
            p->go[w] = np;
            p = p->par;
        }
        if(p == 0)
        {
            np->par = root;
        }
        else
        {
            State *q = p->go[w];
            if(q->val == p->val + 1)
            {
                np->par = q;
            }
            else
            {
                State *nq = newState(p->val + 1);
                memcpy(nq->go, q->go, sizeof(q->go));
                nq->par = q->par;
                q->par = nq;
                np->par = nq;
                while(p && p->go[w] == q)
                {
                    p->go[w] = nq;
                    p = p->par;
                }
            }
        }
        last = np;
    }
    int d[maxm];
    State *b[maxn];
    void topo()
    {
        int cnt = cur - nodePool;
        int maxVal = 0;
        memset(d, 0, sizeof(d));
        for(int i = 1; i < cnt; i++)
            maxVal = max(maxVal, nodePool[i].val), d[nodePool[i].val]++;
        for(int i = 1; i <= maxVal; i++) d[i] += d[i - 1];
        for(int i = 1; i < cnt; i++) b[d[nodePool[i].val]--] = &nodePool[i];
        b[0] = root;
    }
    /*
    void SAMInfo()
    {
        int cnt = cur - nodePool;
        State *p;
        for(int i = cnt - 1; i > 0; i--)
        {
            p = b[i];
            p->par->right += p->right;
            p->mi = p->par->val + 1;
        }
    }
    */
};

Suffix_Automation sam;
char s[maxn];
int pre[maxn];

int main()
{
    int n;
    while(~scanf("%d", &n))
    {
        sam.initSAM();
        while(n--)
        {
            scanf("%s", s);
            int len = strlen(s);
            for(int i = 0; i < len; i++)
                sam.extend(s[i] - '0', i + 1);
            sam.extend(10, -1);
        }
        sam.topo();
        //sam.SAMInfo();
        int ans = 0;
        int cnt = sam.cur - sam.nodePool;
        sam.b[0]->cnt = 1;
        for(int i = 0; i < cnt; i++)
        {
            for(int j = 0; j < 10; j++)//不沿着10走
            {
                if(!i && !j) continue;
                if(!sam.b[i]->go[j]) continue;
                sam.b[i]->go[j]->sum = (sam.b[i]->go[j]->sum + sam.b[i]->sum*10 + sam.b[i]->cnt*j) % mod;
                sam.b[i]->go[j]->cnt = (sam.b[i]->go[j]->cnt + sam.b[i]->cnt) % mod;
            }
            ans = (ans + sam.b[i]->sum) % mod;
        }
        printf("%d\n", ans);
    }
    return 0;
}



你可能感兴趣的:(后缀数组,HDU,后缀自动机,4436,str2int,前缀和预处理)