zju3545

AC自动机(Aho–Corasick automaton)+ 状态压缩DP(bitmask DP)

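注意:同一个模式串可能在输入中出现多次;匹配成功时,需要把它每次出现对应的权值全部累加。(Note: the same pattern may appear several times in the input; when it matches, add up the weights of all its occurrences.)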

zju3545
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>

using namespace std;



// Debug hook: D(expr) expands to nothing, so every debug statement compiles away.
#define D(x)

// Number of child slots per automaton node (one per letter, see Trie::get_id).
const int MAX_CHILD_NUM = 4;

// Capacity limits for the input.
const int MAX_N = 15;    // maximum number of patterns (size of the weight array)
const int MAX_LEN = 105; // pattern buffer size, including the terminating NUL

// Node budget for the trie: one node per pattern character for MAX_N patterns
// of at most MAX_LEN - 1 characters; the leftover slack also covers the root.
const int MAX_NODE_NUM = MAX_N * MAX_LEN;



// Aho-Corasick automaton over a four-letter alphabet (see get_id).
// count[node] is a bitmask of pattern ids. insert() sets bit `id` on the
// pattern's end node, and build() ORs in the mask of each node's fail target.
// Afterwards count[node] holds every pattern that is a suffix of the string
// spelled by the path to that node.
// Usage order: 1.init() 2.insert() for each pattern 3.build() 4.query()
struct Trie
{
    int next[MAX_NODE_NUM][MAX_CHILD_NUM]; // children; -1 = none (until build())
    int fail[MAX_NODE_NUM];                // fail links, valid after build()
    int count[MAX_NODE_NUM];               // bitmask of matched pattern ids
    int node_cnt;                          // number of nodes in use
    bool vis[MAX_NODE_NUM];                // per-query scratch marks
    int root;

    // Resets the automaton to a trie holding only the root node.
    void init()
    {
        node_cnt = 0;
        root = newnode();
        memset(vis, 0, sizeof(vis));
    }

    // Appends a node with no children and an empty mask; returns its index.
    // Not bounds-checked: callers must keep node_cnt below MAX_NODE_NUM.
    int newnode()
    {
        for (int i = 0; i < MAX_CHILD_NUM; i++)
            next[node_cnt][i] = -1;
        count[node_cnt] = 0;
        return node_cnt++;
    }

    // Child slot for a character: 'A'->0, 'T'->1, 'C'->2, anything else ->3
    // (presumably 'G').
    int get_id(char a)
    {
        if (a == 'A')
            return 0;
        if (a == 'T')
            return 1;
        if (a == 'C')
            return 2;
        return 3;
    }

    // Inserts the NUL-terminated pattern `buf` and sets bit `id` on its end node.
    // Inserting an identical string under several ids sets several bits, so the
    // weight of each occurrence is counted when the string is matched.
    // Requires 0 <= id < 31 so that (1 << id) is well defined.
    void insert(char buf[], int id)
    {
        int now = root;
        for (int i = 0; buf[i]; i++)
        {
            int c = get_id(buf[i]); // distinct name: must not shadow the `id` parameter
            if (next[now][c] == -1)
                next[now][c] = newnode();
            now = next[now][c];
        }
        count[now] |= (1 << id);
    }

    // Computes fail links breadth-first. Every missing edge is filled with the
    // matching edge of the node's fail target, which turns `next` into a
    // complete transition table. Nodes leave the queue in non-decreasing depth
    // order, so a child's (shallower) fail target already has its final mask
    // when that mask is OR-ed into the child.
    void build()
    {
        std::queue<int> Q;
        fail[root] = root;
        for (int i = 0; i < MAX_CHILD_NUM; i++)
        {
            if (next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        }
        while (!Q.empty())
        {
            int now = Q.front();
            Q.pop();
            for (int i = 0; i < MAX_CHILD_NUM; i++)
            {
                int child = next[now][i];
                if (child == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[child] = next[fail[now]][i];
                    count[child] |= count[fail[child]];
                    Q.push(child);
                }
            }
        }
    }

    // Returns the bitmask of pattern ids occurring in the NUL-terminated text
    // `buf`. Must be called after build(). vis[] ensures each node's fail chain
    // is walked at most once per call.
    int query(char buf[])
    {
        int now = root;
        int res = 0;

        memset(vis, 0, sizeof(vis));
        for (int i = 0; buf[i]; i++)
        {
            now = next[now][get_id(buf[i])];
            for (int temp = now; temp != root && !vis[temp]; temp = fail[temp])
            {
                // Masks are sets of ids: combine them with OR. Adding them
                // would double-count bits shared along the fail chain.
                res |= count[temp];
                vis[temp] = true;
            }
        }
        return res;
    }

    // Dumps every node: index, fail link, pattern mask and child links.
    void debug()
    {
        for (int i = 0; i < node_cnt; i++)
        {
            printf("id = %3d,fail = %3d,end = %3d,chi = [", i, fail[i], count[i]);
            for (int j = 0; j < MAX_CHILD_NUM; j++)
                printf("%2d", next[i][j]);
            printf("]\n");
        }
    }
} ac;



// Size of the mask dimension of dp. Masks range over [0, 1 << n), so this
// accommodates n <= 10 (the extra 20 entries are slack).
const int MAX_STATUS = (1 << 10) + 20;

// Current test case: number of patterns and required string length.
int n;
int len;

char st[MAX_LEN]; // scratch buffer for reading one pattern
int w[MAX_N];     // w[i] = weight of pattern i (bit i of a mask)

// dp[parity][node][mask]: true when some string of the current length ends in
// automaton node `node` with exactly the pattern set `mask` seen so far.
bool dp[2][MAX_NODE_NUM][MAX_STATUS];



// Total weight of a pattern set: sums w[bit] for every bit in [0, n) that is
// set in `status`. Bits at positions >= n are ignored.
int cal(int status)
{
    int total = 0;
    for (int bit = 0; bit < n; bit++)
        total += ((status >> bit) & 1) ? w[bit] : 0;
    return total;
}



int work()

{

    int ret = -1;

    memset(dp, 0, sizeof(dp));

    dp[0][ac.root][0] = true;

    for (int i = 0; i < len; i++)

    {

        for (int j = 0; j < ac.node_cnt; j++)

            for (int status = 0; status < (1 << n); status++)

                dp[(i + 1) & 1][j][status] = false;

        D(printf("%d\n", dp[(i + 1) & 1][2][0]));

        for (int j = 0; j < ac.node_cnt; j++)

        {

            for (int status = 0; status < (1 << n); status++)

            {

                if (!dp[i & 1][j][status])

                    continue;

                for (int k = 0; k < 4; k++)

                {

                    int v = ac.next[j][k];

                    dp[(i + 1) & 1][v][status | ac.count[v]] = true;

                    D(printf("%d %d\n", j, status));

                    D(printf("%d %d %d %d\n", (i + 1) & 1, v, status | ac.count[v], dp[(i + 1) & 1][0][1044]));

                }

            }

        }

    }

    for (int i = 0; i < ac.node_cnt; i++)

        for (int status = 0; status < (1 << n); status++)

        {

            if (dp[len & 1][i][status])

            {

                if (dp[len & 1][i][0] && status == 0)

                {

                    D(printf("*%d %d\n", i, ac.count[i]));

                }

                ret = max(ret, cal(status));

            }

        }

    return ret;

}



int main()

{

    while (scanf("%d%d", &n, &len) != EOF)

    {

        ac.init();

        for (int i = 0; i < n; i++)

        {

            scanf("%s%d", st, &w[i]);

            ac.insert(st, i);

        }

        ac.build();

        int ans = work();

        if (ans < 0)

            puts("No Rabbit after 2012!");

        else

            printf("%d\n", ans);

    }

    return 0;

}
View Code

 

你可能感兴趣的:(zju3545)