给出n*m的方格,有些格子不能铺线,其它格子必须铺,形成一个闭合回路。问有多少种铺法?
比赛时基本做不出来,就学个新算法玩玩。
学习链接
代码对于我这个不会hash_table 的不太友好,先自己封装了一个用着舒服的hash_table,当然也可以直接用STL里的 unordered_map
,初学算法我认为直接使用后者更好,循序渐进。
插头dp简单的说还是轮廓线的状压dp?,多考虑了一种连通性问题。
用括号表示法表示轮廓线后,然后考虑状态转移。
枚举每点,考虑左边的朝右插头 r
,上边朝下插头 d
,左插头1表示,右插头2表示
所有的情况就这么多,然后每点考虑插头形状转移即可
一个坑点 | ^
运算符优先级不同,不是从左到右运算,然后找了半天bug,还是用 + -
比较稳。
STL unordered_map
代码
991ms
吸氧后 228ms
#include
using namespace std;
#define ll long long
unordered_map<ll,ll> dp[3];
// chatou 0 null, 1 right, 2 down
ll n, m, endx, endy, a[15][15], bit[15];
#define prel (1ll<
#define prer (1ll<
void solve() {
ll cur = 0, ans = 0;
dp[cur].clear();
dp[0][0] = 1;
for(ll i = 1; i <= n; ++i) {
dp[2].clear();
for(auto j : dp[cur]) dp[2][j.first<<2] = j.second;
dp[cur] = dp[2];
for(ll j = 1; j <= m; ++j) {
cur ^= 1;
dp[cur].clear();
for(auto k : dp[cur^1]) {
ll sta = k.first;
ll w = k.second;
ll d = (sta>>bit[j])&3ll;
ll r = (sta>>bit[j-1])&3ll;
// printf("%lld %lld - %lld %lld\n",i,j,sta,w);
// printf("%lld = %lld\n\n",r,d);
if(!a[i][j]) {
if(!r && !d) dp[cur][sta] += w;
}
else if(!r && !d) {
if(a[i+1][j] && a[i][j+1]) dp[cur][sta + prel + (2*prer)] += w;
}
else if(r && !d) {
if(a[i+1][j]) dp[cur][sta] += w;
if(a[i][j+1]) dp[cur][sta - (r*prel) + (r*prer)] += w;
}
else if(!r && d) {
if(a[i+1][j]) dp[cur][sta + (d*prel) - (d*prer)] += w;
if(a[i][j+1]) dp[cur][sta] += w;
}
else if(r == 1 && d == 1) {
ll cnt = 1;
for(ll p = j+1; p <= m; ++p) {
if(((sta>>bit[p])&3ll) == 1) ++cnt;
if(((sta>>bit[p])&3ll) == 2) --cnt;
if(!cnt) {
dp[cur][(sta - (r*prel) - (d*prer)) - (1<<bit[p])] += w;
break;
}
}
}
else if(r == 2 && d == 2) {
ll cnt = 1;
for(ll p = j-2; p >= 0; --p) {
if(((sta>>bit[p])&3ll) == 1) --cnt;
if(((sta>>bit[p])&3ll) == 2) ++cnt;
if(!cnt) {
dp[cur][(sta - (r*prel) - (d*prer)) + (1<<bit[p])] += w;
break;
}
}
}
else if(r == 2 && d == 1) {
dp[cur][sta - (r*prel) - (d*prer)] += w;
}
else if(r == 1 && d == 2) { // ok
if(i == endx && j == endy) ans += w;
}
}
}
}
printf("%lld\n",ans);
}
int main() {
scanf("%lld%lld",&n,&m);
for(ll i = 1; i <= n; ++i) {
char s[15];
scanf("%s",s+1);
for(ll j = 1; j <= m; ++j) {
if(s[j] == '.') {
a[i][j] = 1;
endx = i;
endy = j;
}
}
}
for(ll i = 1; i <= 13; ++i) bit[i] = i<<1;
solve();
return 0;
}
手写hash_table
代码
577ms
吸氧后 562ms
#include
using namespace std;
#define ll long long
struct hash_table {
ll hash_mod = 590027;
ll state[600000], ans[600000], up;
ll tot, first[600000], nxt[600000], w[600000];
void init() {
memset(first, 0, sizeof(first));
tot = 0;
up = 0;
}
ll ins(ll sta, ll val) {
ll key = sta%hash_mod;
for(ll i = first[key]; i; i = nxt[i]) {
if(state[w[i]] == sta) return ans[w[i]] += val;
}
state[++up] = sta;
ans[up] = val;
nxt[++tot] = first[key];
w[tot] = up;
first[key] = tot;
return val;
}
}dp[2];
/*hash_table*/
// chatou 0 null, 1 right, 2 down
ll n, m, endx, endy, a[15][15], bit[15];
#define prel (1ll<
#define prer (1ll<
void solve() {
ll cur = 0, ans = 0;
dp[cur].init();
dp[0].ins(0,1);
for(ll i = 1; i <= n; ++i) {
for(ll j = 1; j <= dp[cur].up; ++j) dp[cur].state[j] <<= 2;
for(ll j = 1; j <= m; ++j) {
cur ^= 1;
dp[cur].init();
for(ll k = 1; k <= dp[cur^1].up; ++k) {
ll sta = dp[cur^1].state[k];
ll w = dp[cur^1].ans[k];
ll d = (sta>>bit[j])&3ll;
ll r = (sta>>bit[j-1])&3ll;
// printf("%lld %lld - %lld %lld\n",i,j,sta,w);
// printf("%lld = %lld\n\n",r,d);
if(!a[i][j]) {
if(!r && !d) dp[cur].ins(sta,w);
}
else if(!r && !d) {
if(a[i+1][j] && a[i][j+1]) dp[cur].ins(sta + prel + (2*prer),w);
}
else if(r && !d) {
if(a[i+1][j]) dp[cur].ins(sta,w);
if(a[i][j+1]) dp[cur].ins(sta - (r*prel) + (r*prer),w);
}
else if(!r && d) {
if(a[i+1][j]) dp[cur].ins(sta + (d*prel) - (d*prer),w);
if(a[i][j+1]) dp[cur].ins(sta,w);
}
else if(r == 1 && d == 1) {
ll cnt = 1;
for(ll p = j+1; p <= m; ++p) {
if(((sta>>bit[p])&3ll) == 1) ++cnt;
if(((sta>>bit[p])&3ll) == 2) --cnt;
if(!cnt) {
dp[cur].ins((sta - (r*prel) - (d*prer)) - (1<<bit[p]),w);
break;
}
}
}
else if(r == 2 && d == 2) {
ll cnt = 1;
for(ll p = j-2; p >= 0; --p) {
if(((sta>>bit[p])&3ll) == 1) --cnt;
if(((sta>>bit[p])&3ll) == 2) ++cnt;
if(!cnt) {
dp[cur].ins((sta - (r*prel) - (d*prer)) + (1<<bit[p]),w);
break;
}
}
}
else if(r == 2 && d == 1) {
dp[cur].ins(sta - (r*prel) - (d*prer),w);
}
else if(r == 1 && d == 2) { // ok
if(i == endx && j == endy) ans += w;
}
}
}
}
printf("%lld\n",ans);
}
int main() {
scanf("%lld%lld",&n,&m);
for(ll i = 1; i <= n; ++i) {
char s[15];
scanf("%s",s+1);
for(ll j = 1; j <= m; ++j) {
if(s[j] == '.') {
a[i][j] = 1;
endx = i;
endy = j;
}
}
}
for(ll i = 1; i <= 13; ++i) bit[i] = i<<1;
solve();
return 0;
}