[csu1605]数独(精确覆盖问题)

题意:给定一个 9×9 数独的部分初始值(0 表示空格)。每个格子有一个权值:最外圈为 6,每向内一圈加 1,正中心为 10;一个格子的得分等于所填数字乘以该格的权值。求所有合法解中总得分的最大值,若数独无解则输出 -1。

思路:这是 NOIP 2009 提高组《靶形数独》的原题,高中时就写过,位运算也正是那个时候学会的。这题显然是暴力搜索,但需要注意两点:一是加一些常数优化,也就是用位运算记录每行、每列、每个九宫格已经用过的数字;二是剪枝——填完某个数后,如果发现某个空格已经没有数可填,就立即回溯换一个数;并且优先填可选数字种数最少的格子,这样能让矛盾尽可能在靠近搜索树根部的地方暴露,从而剪掉更大的子树。今天粗略学习了一下舞蹈链(Dancing Links,DLX),这个算法(准确地说是一种数据结构)可以比较高效地求解精确覆盖问题,稍作修改后也适用于重复覆盖问题。用 DLX 重写了一遍数独,发现效率比位运算版本略高,但差距不明显。

位运算:

  1 #pragma comment(linker, "/STACK:10240000,10240000")

  2 

  3 #include <iostream>

  4 #include <cstdio>

  5 #include <algorithm>

  6 #include <cstdlib>

  7 #include <cstring>

  8 #include <map>

  9 #include <queue>

 10 #include <deque>

 11 #include <cmath>

 12 #include <vector>

 13 #include <ctime>

 14 #include <cctype>

 15 #include <set>

 16 #include <bitset>

 17 #include <functional>

 18 #include <numeric>

 19 #include <stdexcept>

 20 #include <utility>

 21 

 using namespace std;

// ---------------------------------------------------------------------------
// Generic contest-template helpers.
// Function-like macros parenthesize their arguments wherever an expression
// argument could otherwise bind to the surrounding operators, e.g.
// rep_down0(i, n ? p : q) or test_print1(x & 1).
// ---------------------------------------------------------------------------
#define mem0(a) memset(a, 0, sizeof(a))
#define mem_1(a) memset(a, -1, sizeof(a))
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
#define define_m int m = (l + r) >> 1
#define rep_up0(a, b) for (int a = 0; a < (b); a++)
#define rep_up1(a, b) for (int a = 1; a <= (b); a++)
#define rep_down0(a, b) for (int a = (b) - 1; a >= 0; a--)
#define rep_down1(a, b) for (int a = (b); a > 0; a--)
#define all(a) (a).begin(), (a).end()
#define lowbit(x) ((x) & (-(x)))
#define constructInt4(name, a, b, c, d) name(int a = 0, int b = 0, int c = 0, int d = 0): a(a), b(b), c(c), d(d) {}
#define constructInt3(name, a, b, c) name(int a = 0, int b = 0, int c = 0): a(a), b(b), c(c) {}
#define constructInt2(name, a, b) name(int a = 0, int b = 0): a(a), b(b) {}
#define pchr(a) putchar(a)
#define pstr(a) printf("%s", a)
#define sstr(a) scanf("%s", a)
#define sint(a) scanf("%d", &(a))
#define sint2(a, b) scanf("%d%d", &(a), &(b))
#define sint3(a, b, c) scanf("%d%d%d", &(a), &(b), &(c))
#define pint(a) printf("%d\n", a)
#define test_print1(a) cout << "var1 = " << (a) << endl
#define test_print2(a, b) cout << "var1 = " << (a) << ", var2 = " << (b) << endl
#define test_print3(a, b, c) cout << "var1 = " << (a) << ", var2 = " << (b) << ", var3 = " << (c) << endl

typedef long long LL;
typedef pair<int, int> pii;
typedef vector<int> vi;

// Grid direction offsets: four orthogonal moves, then four diagonal moves.
const int dx[8] = {0, 0, -1, 1, 1, 1, -1, -1};
const int dy[8] = {-1, 1, 0, 0, 1, -1, 1, -1 };
const int maxn = 3e4 + 7;
const int md = 10007;
const int inf = 1e9 + 7;
const LL inf_L = 1e18 + 7;
const double pi = acos(-1.0);
const double eps = 1e-6;

// Greatest common divisor by Euclid's algorithm.
template<class T>T gcd(T a, T b){return b==0?a:gcd(b,a%b);}
// a = max(a, b); returns true iff a was changed.
template<class T>bool max_update(T &a,const T &b){if(b>a){a = b; return true;}return false;}
// a = min(a, b); returns true iff a was changed.
template<class T>bool min_update(T &a,const T &b){if(b<a){a = b; return true;}return false;}
// Function form of f ? a : b (both a and b are evaluated by the caller).
template<class T>T condition(bool f, T a, T b){return f?a:b;}
// Copies b[0..n) into a[0..n).
template<class T>void copy_arr(T a[], T b[], int n){rep_up0(i,n)a[i]=b[i];}
// Row-major index of cell (x, y) in a grid that is n cells wide.
int make_id(int x, int y, int n) { return x * n + y; }

 68 

 // Best score found for the current test case (-1 when there is no solution).
int ans;
// Current grid; 0 marks an empty cell.
int a[10][10];
// f[1 << d] == d: index of the single set bit of a power of two (filled by init()).
int f[1 << 13];
// Digits already used in each row / column / 3x3 block; bit d set means digit d is used.
int row[10], col[10], block[10];
// sp[m] == number of set bits of m (filled in main()).
int sp[1 << 13];

 70 

 // Weight of cell (i, j) on the 9x9 target board: 6 on the outermost ring,
// one more for each ring further in, and 10 at the centre cell (4, 4).
int getScore(int i, int j) {
    const int ringFromRowEdge = std::min(i, 8 - i);
    const int ringFromColEdge = std::min(j, 8 - j);
    return 6 + std::min(ringFromRowEdge, ringFromColEdge);
}

 74 

 75 void init() {

 76     rep_up0(i, 12) {

 77         f[1 << i] = i;

 78     }

 79 }

 80 

 81 void dfs(int k, int score) {

 82     if (k >= 81) {

 83         max_update(ans, score);

 84         return ;

 85     }

 86     int x, y, c = 10;

 87     rep_up0(i, 9) {

 88         bool ok = false;

 89         rep_up0(j, 9) {

 90             if (a[i][j]) continue;

 91             int tmp = row[i] | col[j] | block[make_id(i / 3, j / 3, 3)];

 92             int tot = 0x3fe ^ tmp;

 93             int cnt = 0;

 94             if (tot == 0) {

 95                 ok = true;

 96                 c = 0;

 97                 break;

 98             }

 99             cnt = sp[tot];

100             if (cnt < c) {

101                 x = i;

102                 y = j;

103                 c = cnt;

104             }

105         }

106         if (ok) break;

107     }

108     if (c == 0 || c == 10) return ;

109     int i = x, j = y;

110     int tmp = row[i] | col[j] | block[make_id(i / 3, j / 3, 3)];

111     int tot = 0x3fe ^ tmp;

112     while (tot) {

113         tmp = lowbit(tot);

114         row[i] ^= 1 << f[tmp];

115         col[j] ^= 1 << f[tmp];

116         block[make_id(i / 3, j / 3, 3)] ^= 1 << f[tmp];

117         a[i][j] = f[tmp];

118         dfs(k + 1, score + f[tmp] * getScore(i, j));

119         row[i] ^= 1 << f[tmp];

120         col[j] ^= 1 << f[tmp];

121         block[make_id(i / 3, j / 3, 3)] ^= 1 << f[tmp];

122         a[i][j] = 0;

123         tot -= tmp;

124     }

125 }

126 

127 int main() {

128     //freopen("in.txt", "r", stdin);

129     sp[0] = 0;

130     rep_up1(i, 1 << 10) {

131         sp[i] = sp[i - lowbit(i)] + 1;

132     }

133     int T;

134     init();

135     cin >> T;

136     while (T --) {

137         int sum = 0, cnt = 0, ok = true;

138         mem0(col);

139         mem0(row);

140         mem0(block);

141         rep_up0(i, 9) {

142             rep_up0(j, 9) {

143                 sint(a[i][j]);

144                 sum += a[i][j] * getScore(i, j);

145                 if (a[i][j]) {

146                     cnt ++;

147                     if (col[j] & (1 << a[i][j])) ok = false;

148                     if (row[i] & (1 << a[i][j])) ok = false;

149                     if (block[make_id(i / 3, j / 3, 3)] & (1 << a[i][j])) ok = false;

150                     col[j] |= 1 << a[i][j];

151                     row[i] |= 1 << a[i][j];

152                     block[make_id(i / 3, j / 3, 3)] |= 1 << a[i][j];

153                 }

154             }

155         }

156         ans = -1;

157         if (ok) dfs(cnt, sum);

158         cout << ans << endl;

159     }

160 }
View Code

DLX(模板):

  1 #pragma comment(linker, "/STACK:102400000,102400000")

  2 

  3 #include <iostream>

  4 #include <cstdio>

  5 #include <algorithm>

  6 #include <cstdlib>

  7 #include <cstring>

  8 #include <map>

  9 #include <queue>

 10 #include <deque>

 11 #include <cmath>

 12 #include <vector>

 13 #include <ctime>

 14 #include <cctype>

 15 #include <set>

 16 #include <bitset>

 17 #include <functional>

 18 #include <numeric>

 19 #include <stdexcept>

 20 #include <utility>

 21 

 using namespace std;

// ---------------------------------------------------------------------------
// Generic contest-template helpers.
// Function-like macros parenthesize their arguments wherever an expression
// argument could otherwise bind to the surrounding operators, e.g.
// rep_down0(i, n ? p : q) or test_print1(x & 1).
// ---------------------------------------------------------------------------
#define mem0(a) memset(a, 0, sizeof(a))
#define mem_1(a) memset(a, -1, sizeof(a))
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
#define define_m int m = (l + r) >> 1
#define rep_up0(a, b) for (int a = 0; a < (b); a++)
#define rep_up1(a, b) for (int a = 1; a <= (b); a++)
#define rep_down0(a, b) for (int a = (b) - 1; a >= 0; a--)
#define rep_down1(a, b) for (int a = (b); a > 0; a--)
#define all(a) (a).begin(), (a).end()
#define lowbit(x) ((x) & (-(x)))
#define constructInt4(name, a, b, c, d) name(int a = 0, int b = 0, int c = 0, int d = 0): a(a), b(b), c(c), d(d) {}
#define constructInt3(name, a, b, c) name(int a = 0, int b = 0, int c = 0): a(a), b(b), c(c) {}
#define constructInt2(name, a, b) name(int a = 0, int b = 0): a(a), b(b) {}
#define pchr(a) putchar(a)
#define pstr(a) printf("%s", a)
#define sstr(a) scanf("%s", a)
#define sint(a) scanf("%d", &(a))
#define sint2(a, b) scanf("%d%d", &(a), &(b))
#define sint3(a, b, c) scanf("%d%d%d", &(a), &(b), &(c))
#define pint(a) printf("%d\n", a)
#define test_print1(a) cout << "var1 = " << (a) << endl
#define test_print2(a, b) cout << "var1 = " << (a) << ", var2 = " << (b) << endl
#define test_print3(a, b, c) cout << "var1 = " << (a) << ", var2 = " << (b) << ", var3 = " << (c) << endl

typedef long long LL;
typedef pair<int, int> pii;
typedef vector<int> vi;

// Grid direction offsets: four orthogonal moves, then four diagonal moves.
const int dx[8] = {0, 0, -1, 1, 1, 1, -1, -1};
const int dy[8] = {-1, 1, 0, 0, 1, -1, 1, -1 };
const int maxn = 1e5 + 7;
const int md = 10007;
const int inf = 1e9 + 7;
const LL inf_L = 1e18 + 7;
const double pi = acos(-1.0);
const double eps = 1e-6;

// Greatest common divisor by Euclid's algorithm.
template<class T>T gcd(T a, T b){return b==0?a:gcd(b,a%b);}
// a = max(a, b); returns true iff a was changed.
template<class T>bool max_update(T &a,const T &b){if(b>a){a = b; return true;}return false;}
// a = min(a, b); returns true iff a was changed.
template<class T>bool min_update(T &a,const T &b){if(b<a){a = b; return true;}return false;}
// Function form of f ? a : b (both a and b are evaluated by the caller).
template<class T>T condition(bool f, T a, T b){return f?a:b;}
// Copies b[0..n) into a[0..n).
template<class T>void copy_arr(T a[], T b[], int n){rep_up0(i,n)a[i]=b[i];}
// Row-major index of cell (x, y) in a grid that is n cells wide.
int make_id(int x, int y, int n) { return x * n + y; }

 68 

 // DLX layout used below: rows are numbered from 1, columns are numbered 1..n,
// node 0 is the header node and nodes 1..n are the column-head sentinels.

// Maximum weighted score over all valid completions of the current test
// case (main() resets it to -1 before each search).
int result = 0;
// b[i][j]: weight of cell (i, j); filled in at the start of main().
int b[10][10] = {};

 72 

 // Packs (a, b, c), with b and c in [0, 9), into the 1-based id
// a * 81 + b * 9 + c + 1. Used both for DLX column ids and for row ids.
int encode(int a, int b, int c) {
    const int zeroBased = (a * 9 + b) * 9 + c;
    return zeroBased + 1;
}

 // Inverse of encode: splits a 1-based id into (a, b, c) with
// id - 1 == a * 81 + b * 9 + c and b, c in [0, 9).
void decode(int code, int &a, int &b, int &c) {
    const int zeroBased = code - 1;
    c = zeroBased % 9;
    b = (zeroBased / 9) % 9;
    a = zeroBased / 81;
}

 82 

 83 struct DLX

 84 {

 85     const static int maxn = 1050;

 86     const static int maxnode = 100007;

 87     int n , sz;                                                 // 行数,节点总数

 88     int S[maxn];                                                // 各列节点总数

 89     int row[maxnode],col[maxnode];                              // 各节点行列编号

 90     int L[maxnode],R[maxnode],U[maxnode],D[maxnode];            // 十字链表

 91 

 92     int ansd,ans[maxn];                                         //

 93 

 94     void init(int n )

 95     {

 96         this->n = n ;

 97         for(int i = 0 ; i <= n; i++ )

 98             {

 99               U[i] = i ;

100               D[i] = i ;

101               L[i] = i - 1;

102               R[i] = i + 1;

103         }

104         R[n] = 0 ;

105         L[0] = n;

106         sz = n + 1 ;

107         memset(S,0,sizeof(S));

108     }

109     void addRow(int r,vector<int> c1)

110     {

111         int first = sz;

112         for(int i = 0 ; i < c1.size(); i++ ){

113             int c = c1[i];

114             L[sz] = sz - 1 ; R[sz] = sz + 1 ; D[sz] = c ; U[sz] = U[c];

115             D[U[c]] = sz; U[c] = sz;

116             row[sz] = r; col[sz] = c;

117             S[c] ++ ; sz ++ ;

118         }

119         R[sz - 1] = first ; L[first] = sz - 1;

120     }

121     // 顺着链表A,遍历除s外的其他元素

122     #define FOR(i,A,s) for(int i = A[s]; i != s ; i = A[i])

123 

124     void remove(int c) {

125         L[R[c]] = L[c];

126         R[L[c]] = R[c];

127         FOR(i,D,c)

128             FOR(j,R,i) {U[D[j]] = U[j];D[U[j]] = D[j];--S[col[j]];}

129     }

130     void restore(int c) {

131         FOR(i,U,c)

132             FOR(j,L,i) {++S[col[j]];U[D[j]] = j;D[U[j]] = j; }

133         L[R[c]] = c;

134         R[L[c]] = c;

135     }

136     void update() {

137         int score = 0;

138         rep_up0(i, ansd) {

139             int r, c, v;

140             decode(ans[i], r, c, v);

141             score += (v + 1) * b[r][c];

142         }

143         max_update(result, score);

144     }

145     bool dfs(int d) {

146         if(R[0] == 0) {

147           ansd = d;

148           update();

149           return true;

150         }

151         // 找S最小的列c

152         int c = R[0];

153         FOR(i,R,0) if(S[i] < S[c]) c = i;

154 

155         remove(c);

156         FOR(i,D,c) {

157             ans[d] = row[i];

158             FOR(j,R,i) remove(col[j]);

159             //if(dfs(d + 1)) return true;

160             dfs(d + 1);

161             FOR(j,L,i) restore(col[j]);

162         }

163         restore(c);

164 

165         //return false;

166     }

167     bool solve(vector<int> & v) {

168         v.clear();

169         if(!dfs(0)) return false;

170         for(int i = 0 ; i < ansd ;i ++) v.push_back(ans[i]);

171         return true;

172     }

173 };

174 

175 DLX solver;

176 int a[12][12];

177 

178 

179 int main() {

180     //freopen("in.txt", "r", stdin);

181     rep_up0(i, 9) {

182         rep_up0(j, 9) {

183             b[i][j] = 6 + min(min(i, 8 - i), min(j, 8 - j));

184         }

185     }

186     int T, x;

187     cin >> T;

188     while (T --) {

189         solver.init(324);

190         rep_up0(i, 9) {

191             rep_up0(j, 9) {

192                 int x;

193                 sint(x);

194                 rep_up0(k, 9) {

195                     if (x == 0 || x == k + 1) {

196                         vector<int> col;

197                         col.push_back(encode(0, i, j));

198                         col.push_back(encode(1, i, k));

199                         col.push_back(encode(2, j, k));

200                         col.push_back(encode(3, make_id(i / 3, j / 3, 3), k));

201                         solver.addRow(encode(i, j, k), col);

202                     }

203                 }

204             }

205         }

206         result = -1;

207         solver.dfs(0);

208         cout << result << endl;

209     }

210     return 0;

211 }
View Code

 

你可能感兴趣的:(问题)