题意:询问一个区间内要么只包含3,要么只包含8的数有多少个?
解法:数位DP,可惜比赛的时候忘了怎么写了,随便写了个DP没有过,后来才调过了。正确解法是开设一个状态:
f[i][0]表示剩下i+1位数中且暂时不含有3,8的满足要求的数的个数(i位数字可从全0取到全9,后同)
f[i][1]表示剩下i+1位数中且暂时只含有3的满足要求的数的个数
f[i][2]表示剩下i+1位数中且暂时只含有8的满足要求的数的个数
f[i][3]表示剩下i+1位数中且已经含有3和8的满足要求的数的个数,该结果恒为零
标程的解法使用了位运算来计算状态,非常清爽。
代码如下:
#include <cstdlib> #include <cstdio> #include <cstring> #include <algorithm> #include <iostream> using namespace std; int f[15][4], bit[15]; int new_s(int s, int i) { if (i == 3) return s | 1; if (i == 8) return s | 2; return s; } int dfs(int p, int s, int e) { if (p == -1) { return s == 1 || s == 2; } if (!e && ~f[p][s]) return f[p][s]; int res = 0; int u = e ? bit[p] : 9; for (int i = 0; i <= u; ++i) { res += dfs(p-1, new_s(s, i), e&&i==u); } return e ? res : f[p][s] = res; } int cal(int x) { int idx = 0; while (x) { bit[idx++] = x % 10; x /= 10; } return dfs(idx-1, 0, 1); } int main() { int T, l, r; memset(f, 0xff, sizeof (f)); scanf("%d", &T); while (T--) { scanf("%d %d", &l, &r); printf("%d\n", cal(r) - cal(l-1)); for (int i = 0; i < 10; ++i) { printf("%d ", f[i][3]); } puts(""); } return 0; }
自己写的搓代码,但是总归也过了:
#include <cstdlib> #include <cstring> #include <cstdio> #include <iostream> using namespace std; int dp[15][10][3]; /* dp[i][j][0] 表示第i位值为j没有3和8的数有多少个 dp[i][j][1] 表示第i位值为j只含有3的数有多少个 dp[i][j][2] 表示第i位值为j只含有8的数有多少个 */ int bit[15], idx; int dfs(int pos, int num, int sta, int full) { if (!full && dp[pos][num][sta] != -1) { return dp[pos][num][sta]; } if (pos == 0) { if (num == 3) return dp[pos][num][sta] = (sta==1); else if (num == 8) return dp[pos][num][sta] = (sta==2); else return dp[pos][num][sta] = (sta==0); } int end = full ? bit[pos-1] : 9; int temp = 0; for (int i = 0; i <= end; ++i) { if (sta == 0) { if (i == 3 || i == 8) continue; temp += dfs(pos-1, i, 0, full && i==end); } else if (sta == 1) { if (i == 8) continue; temp += dfs(pos-1, i, 1, full && i==end); if (num == 3 && i != 3) { temp += dfs(pos-1, i, 0, full && i==end); } } else { if (i == 3) continue; temp += dfs(pos-1, i, 2, full && i==end); if (num == 8 && i != 8) { temp += dfs(pos-1, i, 0, full && i==end); } } } if (!full) dp[pos][num][sta] = temp; return temp; } int cal(int x) { int ret = 0; idx = 0; while (x) { bit[idx++] = x % 10; x /= 10; } for (int i = 0; i <= bit[idx-1]; ++i) { if (i != 8) { ret += dfs(idx-1, i, 1, i==bit[idx-1]); // 枚举最高位 } if (i != 3) { ret += dfs(idx-1, i, 2, i==bit[idx-1]); } } return ret; } int main() { int T, a, b; memset(dp, 0xff, sizeof (dp)); int x; scanf("%d", &T); while (T--) { scanf("%d %d", &a, &b); printf("%d\n", cal(b) - cal(a-1)); } return 0; };