// 八数码问题 (8-puzzle problem): minimum number of moves between two boards, found by BFS.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* One puzzle board: 9 cells stored row-major (cell k is row k / 3, column k % 3). */
typedef int state_t[9];

/* Row / column offsets for moving the blank: up, down, left, right. */
const int dx[] = {-1, 1, 0, 0};
const int dy[] = {0, 0, -1, 1};

/* Capacity of the board pool (a 3x3 board has at most 9! = 362880 arrangements). */
const int MAX_STATE = 1000000;
/* Number of hash buckets used by the visited-board set. */
const int HASH_SIZE = 1000003;

/* st[] is both the board pool and the BFS queue; goal is the target board. */
state_t st[MAX_STATE];
state_t goal;
/* dist[i]: number of moves from the start board (st[1]) to st[i]. */
int dist[MAX_STATE];

/* Chained hash set over indices into st[]: head[b] is the first index in
   bucket b, next[i] the following index in the same chain; 0 ends a chain.
   NOTE(review): a global named `next` would clash with std::next if
   `using namespace std;` were ever added to this file. */
int head[HASH_SIZE];
int next[MAX_STATE];

void init() {
    memset(dist, 0, sizeof (dist));
}

void initLookupTable() {
    memset(head, 0, sizeof (head));
}

int _hash(state_t & s) {
    int v = 0;
    for (int i = 0; i < 9; i++)
        v = v * 10 + s[i];
    return v % HASH_SIZE;
}

/*
 * Add the board stored at st[s] to the visited set unless an identical board
 * is already present.
 *
 * Returns 1 if st[s] was new; it is then linked at the front of its bucket's
 * chain. Returns 0 if an equal board was inserted earlier, in which case the
 * table is left unchanged.
 */
int tryToInsert(int s) {
    const int bucket = _hash(st[s]);
    for (int cur = head[bucket]; cur != 0; cur = next[cur]) {
        if (!memcmp(st[cur], st[s], sizeof (st[s])))
            return 0;           /* already visited */
    }
    /* New board: push index s onto the front of the chain. */
    next[s] = head[bucket];
    head[bucket] = s;
    return 1;
}

/*
 * Breadth-first search from the board in st[1] toward `goal`.
 *
 * st[] is used as a plain FIFO queue: st[frt] is the board being expanded
 * and st[rear] is the next free slot. Every board is registered in the hash
 * set before it is queued (including the start board), so each distinct
 * board is expanded at most once. dist[1] must already be 0 (see init()).
 *
 * Returns the queue index of the goal board, so dist[result] is the minimum
 * number of moves; returns 0 if the goal is unreachable.
 */
int bfs() {
    initLookupTable();
    tryToInsert(1);                 /* the start board counts as visited */
    int frt = 1, rear = 2;
    /* The queue is linear, not circular: wrapping rear to 0 would end the
       loop early, put a board at index 0 (the hash-chain terminator) and
       overwrite boards the hash chains still point to. */
    while (frt < rear) {
        state_t & s = st[frt];
        if (memcmp(goal, s, sizeof (s)) == 0)
            return frt;
        int z;
        for (z = 0; z < 9; z++)
            if (s[z] == 0)
                break;
        if (z == 9) {               /* no blank cell: nothing can move, and s[z] would be out of bounds */
            frt++;
            continue;
        }
        int x = z / 3, y = z % 3;   /* row and column of the blank */
        for (int d = 0; d < 4; d++) {
            int newx = x + dx[d];
            int newy = y + dy[d];
            if (newx < 0 || newx >= 3 || newy < 0 || newy >= 3)
                continue;           /* move would leave the board */
            if (rear >= MAX_STATE)
                return 0;           /* pool exhausted; unreachable for 9 cells (at most 9! boards) */
            int newz = newx * 3 + newy;
            /* Build the successor in the next free slot; it is kept only if new. */
            state_t & t = st[rear];
            memcpy(&t, &s, sizeof (s));
            t[newz] = s[z];
            t[z] = s[newz];
            dist[rear] = dist[frt] + 1;
            if (tryToInsert(rear))
                rear++;
        }
        frt++;
    }
    return 0;
}

/*
 * Reads two boards of 9 integers each from stdin: first the start board,
 * then the goal board, row by row, with 0 marking the blank cell.
 * Prints the minimum number of moves between them, or -1 if the goal cannot
 * be reached. If either board is incomplete, an error goes to stderr and the
 * program exits with status 1 without searching.
 */
int main() {
    for (int i = 0; i < 9; i++) {
        if (scanf("%d", &st[1][i]) != 1) {
            fprintf(stderr, "expected 9 integers for the start board\n");
            return 1;
        }
    }
    for (int i = 0; i < 9; i++) {
        if (scanf("%d", &goal[i]) != 1) {
            fprintf(stderr, "expected 9 integers for the goal board\n");
            return 1;
        }
    }
    init();
    int ans = bfs();    /* bfs() resets the hash table itself */
    if (ans > 0)
        printf("%d\n", dist[ans]);
    else
        printf("-1\n");
    return 0;
}

// 你可能感兴趣的:(八数码问题) — scraped page footer ("You may also be interested in: 8-puzzle problem").