[手动搬家自网易博客 原发表日期:
2009-07-24 15:13 ]
刚在网上查了一下usaco prime3,发现还是有不少优化,可惜我都没用,只是写了一个比较暴力的数据结构(好像叫字典树?)+明显的剪枝。程序写了320行(包括一些debug用的东西),汗……
无辜的代码飘过:
/*
 * $File: prime3.cpp
 * $Date: Fri Jul 24 14:58:57 2009
ID: gy_jk2
LANG: C++
TASK: prime3
 */

/*
 * Fills a 5x5 grid of digits so that every row (left to right), every column
 * (top to bottom) and both diagonals (left to right) spell a 5-digit prime
 * whose digits add up to a given sum, with a given digit in the top-left
 * cell.  All qualifying primes are stored in a digit trie; a depth-first
 * search fills the grid column by column while walking the trie for the
 * current column, for every row and for both diagonals at the same time.
 */

#define INPUT "prime3.in"
#define OUTPUT "prime3.out"

//#define DEBUG 1
#ifdef DEBUG
#include <iostream>
using namespace std;
#endif

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>

namespace Solve {

/*
 * One node of the digit trie.  The children of a node form a singly linked
 * list (child -> brother -> brother ...) kept sorted by ascending val.
 * Deleting a node deletes its whole subtree and all following brothers.
 */
struct Tree_node {
    int val;             // digit stored in this node (-1 for the trie root)
    Tree_node *child;    // first (smallest-digit) child, or NULL
    Tree_node *brother;  // next sibling with a larger digit, or NULL
    Tree_node *parent;   // node one level up, or NULL for the root

    Tree_node(int v, Tree_node *b = NULL, Tree_node *p = NULL)
        : val(v), child(NULL), brother(b), parent(p) {}

    ~Tree_node() {
        if (child != NULL)
            delete child;
        if (brother != NULL)
            delete brother;
    }
};

// Trie of all candidate primes, keyed from most to least significant digit.
Tree_node *prime_tree;

// Inserts the lowest rest_depth decimal digits of num below root, least
// significant digit first, i.e. in !reverse! order (e.g. 123 -> 3, 2, 1).
void insert_into_tree(int num, int rest_depth, Tree_node *root);

// Returns the child of root holding digit num, or NULL if nothing found.
Tree_node *find_in_cur_level(int num, Tree_node *root);

// Bitmask with bit d set for every digit d that is a child of node.
inline int child_mask(const Tree_node *node) {
    int mask = 0;
    for (const Tree_node *p = node->child; p != NULL; p = p->brother)
        mask |= 1 << p->val;
    return mask;
}

#ifdef DEBUG
// Prints every number stored in the trie below root; n is the prefix so far.
void print_tree(Tree_node *root, int n) {
    if (root->child == NULL)
        cout << n << endl;
    for (Tree_node *ptr = root->child; ptr != NULL; ptr = ptr->brother)
        print_tree(ptr, n * 10 + ptr->val);
}
#endif

// Half-open range [PRIME_RANGE_BEGIN, PRIME_RANGE_END) of 5-digit numbers.
const int PRIME_RANGE_BEGIN = 10000,
          PRIME_RANGE_END = PRIME_RANGE_BEGIN * 10;

// Bitmask allowing every decimal digit 0..9.
const int ALL_DIGITS = 0x3FF;

// Builds prime_tree from all primes in the range whose digit sum equals sum.
void init_prime(int sum);

int square[5][5];                 // grid currently being filled by dfs()
Tree_node *root_row[5];           // trie position reached by each row prefix
Tree_node *root_diagonal[2];      // [0]: row == col, [1]: row + col == 4
int digit_usable[5][5];           // per-cell bitmask of still-allowed digits

// Tries every allowed digit for cell (row, col); root is the trie position
// reached by the digits already placed above it in the same column.
void dfs(int row, int col, Tree_node *root);

const char *NO_ANSWER = "NONE\n";

/*
 * One finished grid, stored as a node of a singly linked list.  The
 * constructor snapshots the current contents of Solve::square.
 */
struct Answer {
    int square[5][5];
    Answer *next;

    Answer(Answer *n) : next(n) {
        memcpy(this->square, Solve::square, sizeof(Solve::square));
    }

    // Row-major lexicographic comparison of the two grids.
    bool operator<(const Answer &n) const;
};

Answer *answer, **answer_array;   // list head / array used for sorting
int nanswer;                      // number of nodes in the answer list

// Ordering predicate for std::sort over Answer pointers.
inline bool answer_ptr_less(const Answer *a, const Answer *b) {
    return *a < *b;
}

// Reads "sum first-digit" from fin and writes all solutions to fout.
void solve(FILE *fin, FILE *fout);

}  // namespace Solve

/*
 * Reads the digit sum and the top-left digit from fin, enumerates all valid
 * grids and writes them to fout in ascending row-major order, separated by a
 * blank line.  Writes "NONE" when there is no solution.  On unreadable input
 * a message goes to stderr and nothing is written to fout.
 */
void Solve::solve(FILE *fin, FILE *fout) {
    int n0, sum;
    // The read must not live inside assert(): with NDEBUG it would vanish.
    if (fscanf(fin, "%d %d", &sum, &n0) != 2) {
        fprintf(stderr, "prime3: cannot read the two input integers\n");
        return;
    }

    init_prime(sum);
#ifdef DEBUG
    print_tree(prime_tree, 0);
#endif

    for (int i = 0; i < 5; i++) {
        for (int j = 0; j < 5; j++)
            digit_usable[i][j] = ALL_DIGITS;
        root_row[i] = prime_tree;
    }
    // The top-left cell is fixed by the input.  An out-of-range digit yields
    // an empty mask (and therefore "NONE") instead of an undefined shift.
    digit_usable[0][0] = (n0 >= 0 && n0 <= 9) ? (1 << n0) : 0;
    root_diagonal[0] = root_diagonal[1] = prime_tree;

    dfs(0, 0, prime_tree);

    if (answer != NULL) {
        // dfs() finds grids in column-major order; the output has to be
        // sorted row-major, so collect the list into an array and sort it.
        answer_array = new Answer *[nanswer];
        int pos = 0;
        for (Answer *ptr = answer; ptr != NULL; ptr = ptr->next)
            answer_array[pos++] = ptr;
        std::sort(answer_array, answer_array + nanswer, answer_ptr_less);

        for (int i = 0; i < nanswer; i++) {
            if (i)
                fprintf(fout, "\n");
            for (int x = 0; x < 5; x++) {
                for (int y = 0; y < 5; y++)
                    fprintf(fout, "%d", answer_array[i]->square[x][y]);
                fprintf(fout, "\n");
            }
            delete answer_array[i];
        }
        delete[] answer_array;
        answer_array = NULL;
        answer = NULL;
        nanswer = 0;
    } else {
        fprintf(fout, "%s", NO_ANSWER);
    }

    delete prime_tree;
    prime_tree = NULL;
}

bool Solve::Answer::operator<(const Answer &n) const {
    for (int i = 0; i < 5; i++)
        for (int j = 0; j < 5; j++)
            if (square[i][j] != n.square[i][j])
                return square[i][j] < n.square[i][j];
    return false;
}

Solve::Tree_node *Solve::find_in_cur_level(int num, Tree_node *root) {
    if (root == NULL)
        return NULL;
    // Children are sorted ascending, so the scan can stop at the first
    // digit that is not smaller than the one we look for.
    Tree_node *pos = root->child;
    while (pos != NULL && pos->val < num)
        pos = pos->brother;
    if (pos == NULL || pos->val > num)
        return NULL;
    return pos;
}

/*
 * Cells are visited column by column, top to bottom.  Passing row == 5 means
 * "continue with the top of the next column"; reaching column 5 means the
 * grid is complete and is recorded as an answer.
 */
void Solve::dfs(int row, int col, Tree_node *root) {
    if (row == 5) {
        row = 0;
        col++;
        if (col == 5) {
            answer = new Answer(answer);
            nanswer++;
            return;
        }
        root = prime_tree;  // a new column starts at the top of the trie
    }

    const int du = digit_usable[row][col];
    int &sq = square[row][col];

    if (col == 4) {
        // Last column: the masks written while filling column 3 already
        // restrict this cell to digits that complete the row (and, for the
        // two corner cells, the diagonal), so only the column trie is walked.
        for (Tree_node *ptr = root->child; ptr != NULL; ptr = ptr->brother)
            if (du & (1 << ptr->val)) {
                sq = ptr->val;
                dfs(row + 1, col, ptr);
            }
        return;
    }

    const bool d0 = (row == col);       // cell lies on the main diagonal
    const bool d1 = (row + col == 4);   // cell lies on the anti-diagonal

    // Masks of the cells this one constrains, saved so they can be restored.
    const int u0 = digit_usable[row][col + 1];
    const int u1 = d0 ? digit_usable[row + 1][col + 1] : 0;
    const int u2 = d1 ? digit_usable[row - 1][col + 1] : 0;

    for (Tree_node *ptr = root->child; ptr != NULL; ptr = ptr->brother) {
        if (!(du & (1 << ptr->val)))
            continue;
        sq = ptr->val;

        // The digit must extend the row prefix and any diagonal prefix.
        Tree_node *r0 = find_in_cur_level(ptr->val, root_row[row]);
        if (r0 == NULL)
            continue;
        Tree_node *r1 = NULL, *r2 = NULL;
        if (d0 && (r1 = find_in_cur_level(ptr->val, root_diagonal[0])) == NULL)
            continue;
        if (d1 && (r2 = find_in_cur_level(ptr->val, root_diagonal[1])) == NULL)
            continue;

        // Prune: the next cell of the row / diagonal must keep at least one
        // usable digit once restricted to the possible continuations.
        const int m0 = u0 & child_mask(r0);
        if (m0 == 0)
            continue;
        const int m1 = d0 ? (u1 & child_mask(r1)) : 0;
        if (d0 && m1 == 0)
            continue;
        const int m2 = d1 ? (u2 & child_mask(r2)) : 0;
        if (d1 && m2 == 0)
            continue;

        // Commit the digit, recurse, then undo the trie positions.
        Tree_node *const saved_row = root_row[row];
        Tree_node *const saved_d0 = root_diagonal[0];
        Tree_node *const saved_d1 = root_diagonal[1];

        root_row[row] = r0;
        digit_usable[row][col + 1] = m0;
        if (d0) {
            root_diagonal[0] = r1;
            digit_usable[row + 1][col + 1] = m1;
        }
        if (d1) {
            root_diagonal[1] = r2;
            digit_usable[row - 1][col + 1] = m2;
        }

        dfs(row + 1, col, ptr);

        root_row[row] = saved_row;
        root_diagonal[0] = saved_d0;
        root_diagonal[1] = saved_d1;
    }

    // Restore the masks of the constrained neighbours for the caller.
    digit_usable[row][col + 1] = u0;
    if (d0)
        digit_usable[row + 1][col + 1] = u1;
    if (d1)
        digit_usable[row - 1][col + 1] = u2;
}

void Solve::insert_into_tree(int num, int rest_depth, Tree_node *root) {
    if (rest_depth == 0)
        return;

    const int digit = num % 10;

    // Find the first child whose digit is not smaller than ours.
    Tree_node *pos = root->child, *last = NULL;
    while (pos != NULL && pos->val < digit) {
        last = pos;
        pos = pos->brother;
    }

    if (pos == NULL || pos->val != digit) {
        // Not present yet: splice a new node in front of pos, keeping the
        // sibling list sorted.
        Tree_node *node = new Tree_node(digit, pos, root);
        if (last == NULL)
            root->child = node;
        else
            last->brother = node;
        pos = node;
    }
    insert_into_tree(num / 10, rest_depth - 1, pos);
}

/*
 * Sieves the primes below PRIME_RANGE_END and inserts every prime that is at
 * least PRIME_RANGE_BEGIN and has digit sum `sum` into a fresh prime_tree.
 * Each number is digit-reversed first because insert_into_tree() consumes
 * the least significant digit first; the trie therefore ends up keyed from
 * the leading digit downwards.  With DEBUG defined a fixed list of primes is
 * inserted instead and `sum` is ignored.
 */
void Solve::init_prime(int sum) {
    prime_tree = new Tree_node(-1);

#ifdef DEBUG
    (void)sum;
    const int P[] = {11351, 14033, 30323, 53201, 13313, 14303, 13331};
    for (int i = 0; i < 7; i++) {
        int n1 = 0;
        for (int j = P[i]; j; j /= 10)
            n1 = n1 * 10 + (j % 10);
        insert_into_tree(n1, 5, prime_tree);
    }
#else
    // One bit per number; a set bit marks a composite.
    unsigned int composite[PRIME_RANGE_END / 32 + 1] = {0};

    for (int i = 2; i < PRIME_RANGE_END; i++) {
        if (composite[i >> 5] & (1u << (i & 31)))
            continue;

        if (i >= PRIME_RANGE_BEGIN) {
            int s = 0, n1 = 0;
            for (int j = i; j; j /= 10) {
                s += j % 10;
                n1 = n1 * 10 + (j % 10);
            }
            if (s == sum)
                insert_into_tree(n1, 5, prime_tree);
        }

        for (int j = i * 2; j < PRIME_RANGE_END; j += i)
            composite[j >> 5] |= 1u << (j & 31);
    }
#endif
}

int main() {
    FILE *fin = fopen(INPUT, "r");
    if (fin == NULL) {
        perror(INPUT);
        return 1;
    }
    FILE *fout = fopen(OUTPUT, "w");
    if (fout == NULL) {
        perror(OUTPUT);
        fclose(fin);
        return 1;
    }

    Solve::solve(fin, fout);

    fclose(fin);
    fclose(fout);
    return 0;
}