POJ 1185 — 炮兵阵地 (Artillery Positions)

状态压缩dp (state-compression / bitmask dynamic programming)

poj1185
#include <cstdio>

#include <cstring>

#include <cstdlib>

#include <iostream>

using namespace std;



// Capacity constants.  They are sized for the judge's board limits
// (presumably N <= 100 rows, M <= 10 columns — confirm against the statement).
const int MAX_COL_NUM = 15;   // size of the per-row input buffer
const int MAX_ROW_NUM = 105;  // number of slots in grid[]
const int MAX_STAT = 2050;    // exclusive bound on any row bitmask (2^11 = 2048 fits)

// One admissible arrangement for three consecutive rows.
// row_bit[0], row_bit[1], row_bit[2] are the piece bitmasks of rows r, r+1, r+2.
// make_stat() builds them pairwise disjoint, and each passes the horizontal
// ok() test.  last_bit_count caches the number of pieces in row_bit[2].
struct Status
{
    int row_bit[3];
    int last_bit_count;
};
Status status[100000];

int row_num;             // number of board rows read by input()
int col_num;             // number of board columns read by input()
int grid[MAX_ROW_NUM];   // grid[i]: bitmask of the 'H' cells of row i

// f[k][x][y]: best piece count found for a prefix whose two most recent rows
// carry masks x and y.  k is a rolling layer index (row position mod 3).
int f[3][MAX_STAT][MAX_STAT];

int ans;         // best piece count seen so far
int status_cnt;  // number of filled entries in status[]



void input()

{

    scanf("%d%d", &row_num, &col_num);

    for (int i = 0; i < row_num; i++)

    {

        char row[MAX_COL_NUM];

        scanf("%s", row);

        int row_bit = 0;

        for (int j = 0; row[j]; j++)

        {

            if (row[j] == 'H')

            {

                row_bit = (row_bit << 1) | 1;

            }else

            {

                row_bit = (row_bit << 1) | 0;

            }

        }

        grid[i] = row_bit;

    }

}



// Returns the number of 1-bits in `a`.  Zero and negative values yield 0.
int count_bit(int a)
{
    if (a <= 0)
        return 0;
    int bits = 0;
    // a &= a - 1 clears the lowest set bit, so the loop runs once per 1-bit.
    for (; a != 0; a &= a - 1)
        ++bits;
    return bits;
}



// Returns true when no two set bits of `bit` are adjacent, and no two are
// exactly two positions apart.  Zero and negative masks are reported as
// acceptable.
bool ok(int bit)
{
    if (bit <= 0)
        return true;
    const bool adjacent = (bit & (bit >> 1)) != 0;
    const bool two_apart = (bit & (bit >> 2)) != 0;
    return !adjacent && !two_apart;
}



// Returns true when mask `bit` shares no set bit with the 'H' mask of
// grid row `row`.
bool ok(int bit, int row)
{
    const int hills = grid[row];
    return !(bit & hills);
}



// Clears the DP table and seeds layer 0 with every legal (row 0, row 1) pair.
// Each i in [0, 3^col_num) is decoded in base 3, one digit per column:
//   digit 0 -> piece in row 0;
//   digit 1 -> piece in row 1;
//   digit 2 -> column left empty.
// The two masks are therefore disjoint.  A pair that passes both the hill
// check (grid[0], grid[1]) and the horizontal ok() check gets
// f[0][row0][row1] = total pieces, and ans is raised to the best such value.
void init()
{
    memset(f, 0, sizeof(f));
    int col_stat = 1;  // 3^col_num
    for (int i = 0; i < col_num; i++)
        col_stat *= 3;
    for (int i = 0; i < col_stat; i++)
    {
        int temp = i;
        // All three slots start at 0.  Digit 2 ORs into row_bit[2], and an
        // uninitialised slot would be an indeterminate read (undefined behaviour).
        int row_bit[3] = {0, 0, 0};
        for (int j = 0; j < col_num; j++)
        {
            row_bit[0] <<= 1;
            row_bit[1] <<= 1;
            row_bit[temp % 3] |= 1;
            temp /= 3;
        }
        const int r0 = row_bit[0];
        const int r1 = row_bit[1];
        if (!ok(r0, 0) || !ok(r1, 1) || !ok(r0) || !ok(r1))
        {
            f[0][r0][r1] = 0;
            continue;
        }
        f[0][r0][r1] = count_bit(r0) + count_bit(r1);
        ans = max(ans, f[0][r0][r1]);
    }
}



void make_stat()

{

    int col_stat = 1;

    status_cnt = 0;

    for (int i = 0; i < col_num; i++)

        col_stat *= 4;

    for (int i = 0; i < col_stat; i++)

    {

        int temp = i;

        int row_bit[4];

        row_bit[2] = row_bit[1] = row_bit[0] = 0;

        for (int j = 0; j < col_num; j++)

        {

            row_bit[0] <<= 1;

            row_bit[1] <<= 1;

            row_bit[2] <<= 1;

            row_bit[temp % 4] |= 1;

            temp /= 4;

        }

        bool conflict = false;

        for (int j = 0; j < 3; j++)

            if (!ok(row_bit[j]))

            {

                conflict = true;

                break;

            }

        if (conflict)

            continue;

        status[status_cnt].last_bit_count = count_bit(row_bit[2]);

        for (int j = 0; j < 3; j++)

            status[status_cnt].row_bit[j] = row_bit[j];

        status_cnt++;

    }

}



void work()

{

    for (int i = 0; i < row_num - 2; i++)

    {

        for (int j = 0; j < status_cnt; j++)

        {

            int a = status[j].row_bit[0];

            int b = status[j].row_bit[1];

            int c = status[j].row_bit[2];

            bool conflict = false;

            for (int k = 0; k < 3; k++)

                if (!ok(status[j].row_bit[k], i + k))

                {

                    conflict = true;

                    break;

                }

            if (conflict)

                continue;

            f[(i + 1) % 3][b][c] = max(f[(i + 1) % 3][b][c], f[i % 3][a][b] + status[j].last_bit_count);

            ans = max(ans, f[(i + 1) % 3][b][c]);

        }

    }

}



// Reads the board and prints the maximum number of pieces that can be placed.
// Boards are handled by size:
//   - no rows: the answer is 0;
//   - one row: every mask is enumerated directly;
//   - two rows: init() alone covers them, and work() performs no iterations;
//   - three or more rows: init() seeds the DP and work() rolls it forward.
int main()
{
    input();
    ans = 0;
    if (row_num <= 0)
    {
        // Empty or unreadable board.  Without this guard, init() would score
        // two phantom hill-free rows and print a nonzero answer.
        printf("%d\n", ans);
        return 0;
    }
    if (row_num == 1)
    {
        for (int mask = 0; mask < (1 << col_num); mask++)
        {
            if (ok(mask) && ok(mask, 0))
                ans = max(ans, count_bit(mask));
        }
        printf("%d\n", ans);
        return 0;
    }
    init();
    make_stat();
    work();
    printf("%d\n", ans);
    return 0;
}
View Code

 

你可能感兴趣的:(poj)