[ZOJ 3817 Chinese Knot] 字符串hash+DP

题目

http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3817

分析

字符串hash+DP

将knot的每个loop上的字符串分别从两个方向求前缀hash(正向为hashLeft,反向为hashRight),对给定的模式串pattern从前往后求前缀hash

然后设f[i][k](k = 2·loop + dir,共8种状态,dir为0表示正向、1表示反向)表示模式串从第i个字符开始的后缀,能否从中心点出发、沿第loop个环的dir方向进入并完成匹配(走完该环后可回到中心点继续进入其他环)

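则f[i][k]容易递推:若剩余长度m-i+1不超过n,只需判断该后缀是否等于该环在该方向上的对应前缀;否则该环的n个字符必须全部匹配,并且存在某个状态k'使f[i+n][k']成立,其中k'不能是从刚离开的位置重新进入同一个环的状态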

然后枚举起点(所在的环、环上的位置以及方向),先用hash判断该环剩余部分是否与模式串的对应前缀匹配,再利用f判断能否继续匹配,即可很容易地判断是否有解并构造出解

代码

/**************************************************
 *        Problem:  ZOJ 3817
 *         Author:  clavichord93
 *          State:  Accepted
 **************************************************/

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using namespace std;

const int MAX_N = 100005;

// The knot: four loops, each a string of length n stored 1-based in
// s[loop][1..n].
int n;
char s[4][MAX_N];

// Polynomial prefix hashes with base 31. The arithmetic deliberately wraps
// modulo 2^64; that is only well-defined for unsigned types (signed overflow
// is undefined behavior), hence unsigned long long throughout.
unsigned long long hashLeft[MAX_N][4];   // hashLeft[k][loop]: first k chars of s[loop] read forward
unsigned long long hashRight[MAX_N][4];  // hashRight[k][loop]: first k chars of s[loop] read backward from s[loop][n]

// The pattern to spell (1-based, length m) and its prefix hashes.
int m;
char pattern[MAX_N];
unsigned long long hash[MAX_N];          // hash[k]: prefix hash of pattern[1..k]

unsigned long long pow31[MAX_N * 4];     // pow31[k] = 31^k (mod 2^64)

// f[i][2 * loop + dir]: whether pattern[i..m] can be matched by entering loop
// `loop` from the center in direction dir (0: from s[loop][1], 1: from s[loop][n]).
bool f[MAX_N][8];
// ans[k]: encoded position (loop * n + index) used for the k-th pattern char.
int ans[MAX_N];

/*
 * Reconstructs and prints one matching walk of length m.
 *
 * The walk starts on loop `i` at position `j`, moving towards position n when
 * dir == 0 and towards position 1 when dir == 1, and records loop * n + pos
 * into ans[1..m]. After stepping off either end of a loop, it continues in the
 * first state k (k = 2 * loop + direction) whose flag f[step + 1][k] is set,
 * skipping the single state that would start the same loop again at the
 * position just visited. A forward state starts at position 1 and a backward
 * state starts at position n.
 *
 * Output: ans[1], then ans[2..m] each preceded by a space, then a newline.
 */
void findPath(int i, int j, int dir) {
    int loop = i;
    int pos = j;
    int way = dir;

    for (int step = 1; step <= m; ++step) {
        ans[step] = loop * n + pos;

        pos += (way == 0) ? 1 : -1;
        const bool offTheEnd = (way == 0) ? (pos == n + 1) : (pos == 0);
        if (!offTheEnd) {
            continue;
        }

        // A forward walk leaves at position n, so the backward state of the
        // same loop (which starts at n) is excluded; symmetrically, a backward
        // walk leaves at position 1 and the forward state is excluded.
        const int banned = 2 * loop + (way == 0 ? 1 : 0);
        for (int k = 0; k < 8; ++k) {
            if (k == banned || !f[step + 1][k]) {
                continue;
            }
            way = k & 1;
            loop = k / 2;
            pos = (way == 0) ? 1 : n;
            break;
        }
    }

    printf("%d", ans[1]);
    for (int step = 2; step <= m; ++step) {
        printf(" %d", ans[step]);
    }
    putchar('\n');
}

int main() {
    #ifdef LOCAL_JUDGE
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #endif

    pow31[0] = 1;
    for (int i = 1; i <= 400000; i++) {
        pow31[i] = pow31[i - 1] * 31;
    }

    int T;
    scanf("%d", &T);
    for (int cas = 1; cas <= T; cas++) {
        memset(hash, 0, sizeof(hash));
        memset(hashLeft, 0, sizeof(hashLeft));
        memset(hashRight, 0, sizeof(hashRight));
        memset(f, 0, sizeof(f));

        scanf("%d %d", &n, &m);
        for (int i = 0; i < 4; i++) {
            scanf("%s", s[i] + 1);
            for (int j = 1; j <= n; j++) {
                hashLeft[j][i] = hashLeft[j - 1][i] + s[i][j] * pow31[j - 1];
            }
            for (int j = n; j >= 1; j--) {
                hashRight[n - j + 1][i] = hashRight[n - j][i] + s[i][j] * pow31[n - j];
            }
        }
        scanf("%s", pattern + 1);
        for (int i = 1; i <= m; i++) {
            hash[i] = hash[i - 1] + pattern[i] * pow31[i - 1];
        }

        for (int i = m; i >= 1; i--) {
            for (int j = 0; j < 4; j++) {
                int length = m - i + 1;

                if (length <= n) {
                    long long hashPattern = hash[m] - hash[i - 1];

                    long long hashValueL = hashLeft[length][j] * pow31[i - 1];
                    if (hashValueL == hashPattern) {
                        f[i][2 * j] = 1;
                    }

                    long long hashValueR = hashRight[length][j] * pow31[i - 1];
                    if (hashValueR == hashPattern) {
                        f[i][2 * j + 1] = 1;
                    }
                }
                else {
                    long long hashPattern = hash[i + n - 1] - hash[i - 1];
                    
                    long long hashValueL = hashLeft[n][j] * pow31[i - 1];
                    if (hashValueL == hashPattern) {
                        for (int k = 0; k < 8; k++) {
                            if (k != 2 * j + 1 && f[i + n][k]) {
                                f[i][2 * j] = 1;
                                break;
                            }
                        }
                    }

                    long long hashValueR = hashRight[n][j] * pow31[i - 1];
                    if (hashValueR == hashPattern) {
                        for (int k = 0; k < 8; k++) {
                            if (k != 2 * j && f[i + n][k]) {
                                f[i][2 * j + 1] = 1;
                                break;
                            }
                        }
                    }
                }
            }
        }

        //for (int i = 1; i <= m; i++) {
            //for (int j = 0; j < 8; j++) {
                //printf("%d ", f[i][j]);
            //}
            //printf("\n");
        //}

        bool ans = 0;
        for (int i = 0; i < 4; i++) {
            for (int j = 1; j <= n; j++) {
                int length = n - j + 1;
                
                if (length >= m) {
                    long long hashValue = hashLeft[j + m - 1][i] - hashLeft[j - 1][i];
                    long long hashPattern = hash[m] * pow31[j - 1];
                    if (hashValue == hashPattern) {
                        ans = 1;
                        findPath(i, j, 0);
                    }
                }
                else {
                    long long hashValue = hashLeft[n][i] - hashLeft[j - 1][i];
                    long long hashPattern = hash[length] * pow31[j - 1];
                    if (hashValue == hashPattern) {
                        for (int k = 0; k < 8; k++) {
                            if (k != 2 * i + 1 && f[length + 1][k]) {
                                ans = 1;
                                findPath(i, j, 0);
                                break;
                            }
                        }
                    }
                }

                if (ans) {
                    break;
                }

                length = j;
                if (length >= m) {
                    long long hashValue = hashRight[n - j + m][i] - hashRight[n - j][i];
                    long long hashPattern = hash[m] * pow31[n - j];
                    if (hashValue == hashPattern) {
                        ans = 1;
                        findPath(i, j, 1);
                    }
                }
                else {
                    long long hashValue = hashRight[n][i] - hashRight[n - j][i];
                    long long hashPattern = hash[length] * pow31[n - j];
                    if (hashValue == hashPattern) {
                        for (int k = 0; k < 8; k++) {
                            if (k != 2 * i && f[length + 1][k]) {
                                ans = 1;
                                findPath(i, j, 1);
                                break;
                            }
                        }
                    }
                }

                if (ans) {
                    break;
                }
            }
            
            if (ans) {
                break;
            }
        }

        if (!ans) {
            printf("No solution!\n");
        }
    }

    return 0;
}


你可能感兴趣的:([ZOJ 3817 Chinese Knot] 字符串hash+DP)