数位DP

学了一下怎么写递归,发现确实比较简单;

dp[pos][][]对应dfs()中的参数的状态,记忆化当前状态的值,不用考虑这个状态表示什么意思;

然后就是设计好dfs()中的参数;

hdu 3555 http://acm.hdu.edu.cn/showproblem.php?pid=3555

题意:统计1~n之间含有49的数字的个数;

需要记录当前位置,前一位置放了那个数字,当前是否已经包含49,是否有上界;

dfs(pos,pre,istrue,limit);

 1 #include<cstdio>

 2 #include<cstring>

 3 #include<iostream>

 4 #include<algorithm>

 5 #include<iostream>

 6 using  namespace std;

 7 const int N = 20;

 8 typedef long long LL;

 9 int dig[N];

10 LL dp[N][10][2];

11 

12 LL dfs(int pos,int pre,int istrue,int limit) {

13     if (pos < 0) return istrue;

14     if (!limit && dp[pos][pre][istrue] != -1)

15         return dp[pos][pre][istrue];

16     int last = limit ? dig[pos] : 9;

17     LL ret = 0;

18     for (int i = 0; i <= last; i++) {

19         ret += dfs(pos-1,i,istrue || (pre == 4 && i == 9),limit && (i == last));

20     }

21     if (!limit) {

22         dp[pos][pre][istrue] = ret;

23     }

24     return ret;

25 }

26 LL solve(LL n) {

27     int len = 0;

28     while (n) {

29         dig[len++] = n % 10;

30         n /= 10;

31     }

32     return dfs(len-1,0,0,1);

33 }

34 int main(){

35     memset(dp,-1,sizeof(dp));

36     int T; scanf("%d",&T);

37     while (T--) {

38         LL n;

39         cin>>n;

40         cout<<solve(n)<<endl;

41     }

42     return 0;

43 }
View Code

 

 usetc 1307 http://acm.uestc.edu.cn/problem.php?pid=1307

题意:相邻两个数之差大于2;

dfs(pos,pre,limit,fg); fg表示前面是否全为0

 1 #include<cstdio>

 2 #include<iostream>

 3 #include<algorithm>

 4 #include<cmath>

 5 #include<cstdlib>

 6 #include<cstring>

 7 #include<iostream>

 8 using namespace std;

 9 typedef long long LL;

10 const int N = 20;

11 LL dp[N][10][2];

12 LL a,b;

13 int dig[N];

14 LL dfs(int pos,int pre,int limit,int fg) {

15     if (pos < 0) return 1;

16     if (!limit && dp[pos][pre][fg] != -1)

17         return dp[pos][pre][fg];

18     int last = limit ? dig[pos] : 9;

19     LL ret = 0;

20     for (int i = 0; i <= last; i++) {

21         if (fg == 0 || abs(i - pre) >= 2)

22         ret += dfs(pos-1,i,limit && (i == last),fg || i);

23     }

24     if (!limit) {

25         dp[pos][pre][fg] = ret;

26     }

27     return ret;

28 }

29 LL solve(LL n) {

30     int len = 0;

31     if (n < 0) return 0;

32     while (n) {

33         dig[len++] = n % 10;

34         n /= 10;

35     }

36     return dfs(len-1,0,1,0);

37 }

38 int main(){

39     memset(dp,-1,sizeof(dp));

40   //  cout<<solve(15)<<endl;

41     while (cin>>a>>b) {

42         cout<<solve(b)-solve(a-1)<<endl;

43     }

44     return 0;

45 }
View Code

 

hdu4352 http://acm.hdu.edu.cn/showproblem.php?pid=4352

题意:求[L,R]内最长递增子序列是k的个数;

分析:知道题意后,马上map<vector<>,LL> mp[][]搞了,然后华丽丽的T掉了,

vector<int> g,  g[i]表示最长递增子序列为长度为i结尾的值的最小值;

这是我对vector<>里面的值的性质没有思考,显然vector<>最多包含10个数,并且是严格递增的,

这样我们就可以直接状压存了,(1<<10)一个数值跟vector<>是一一对应的;

 1 #include<cstdio>

 2 #include<cstring>

 3 #include<iostream>

 4 #include<cmath>

 5 #include<algorithm>

 6 #include<queue>

 7 #include<vector>

 8 using namespace std;

 9 typedef long long LL;

10 LL dp[22][1<<10][12];

11 LL a,b;

12 int k;

13 int dig[22];

14 int cge(int sta,int k) {

15     if (sta & (1<<k)) return sta;

16     if ((1<<k) > sta) {

17         return sta | (1<<k);

18     }

19     sta |= 1<<k;

20     for (int i = k+1; i < 10; i++) {

21         if (sta & (1<<i)) {

22             return sta ^ (1<<i);

23         }

24     }

25 }

26 int get(int k) {

27     int ret = 0;

28     for (int i = 0; i < 10; i++) if (k & (1<<i)) ret++;

29     return ret;

30 }

31 LL dfs(int pos,int sta,int limit) {

32     if (pos < 0) return get(sta) == k;

33     if (!limit && dp[pos][sta][k] != -1)

34         return dp[pos][sta][k];

35     int last = limit ? dig[pos] : 9;

36     LL ret = 0;

37     for (int i = 0; i <= last; i++) {

38         ret += dfs(pos-1,sta || i ? cge(sta,i) : 0,limit && (i == last));

39     }

40     if (!limit) {

41         dp[pos][sta][k] = ret;

42     }

43     return ret;

44 }

45 LL solve(LL n) {

46     int len = 0;

47     while (n) {

48         dig[len++] = n % 10;

49         n /= 10;

50     }

51     return dfs(len-1,0,1);

52 }

53 int main(){

54     memset(dp,-1,sizeof(dp));

55     int T,cas = 0; scanf("%d",&T);

56     while (T--) {

57         scanf("%I64d%I64d%d",&a,&b,&k);

58         printf("Case #%d: ",++cas);

59         printf("%I64d\n",solve(b) - solve(a-1));

60     }

61     return 0;

62 }
View Code

 

hdu3886 http://acm.hdu.edu.cn/showproblem.php?pid=3886

题意:求[l,r]内满足题意条件的数的个数,给你一个字符串这里且称为标准串,要数值满足这个标准串(条件比较难以表述,看题);

分析:dfs(pos,pre,loc,cc,limit,fg);

分别表示当前位置,前一位的数值,当前匹配到标准串中位置,跟标准串中匹配了几个数值,是否有限制,标记有无前导零;

有个trick,找了好久,如果直接来的话,比如// 1234 1234会输出2,因为12 34 @ 123 4 被当成不同的数了,我们可以规定如果标准串中有连续相同的字符,并且满足转移到下个字符了的话一定先转移;

 1 #include<cstdio>

 2 #include<cstring>

 3 #include<iostream>

 4 #include<algorithm>

 5 #include<cstdlib>

 6 #include<cmath>

 7 using namespace std;

 8 typedef long long LL;

 9 const int N = 100+10;

10 const LL Mod = 100000000;

11 char a[N],b[N];

12 char stand[N];

13 int flag;

14 void init(){

15     int lena = strlen(a), lenb = strlen(b);

16     reverse(a,a+lena);

17     reverse(b,b+lenb);

18     flag = 1;

19     for (int i = 0; i < lena; i++) if (a[i] != '0') flag = 0;

20     if (lena == 1 && a[0] == '0') flag = 1;

21     int mark = 1;

22     if (!flag)

23     for (int i = 0; i < lena; i++) {

24         if (mark) {

25             if (a[i] > '0') {

26                 a[i] = a[i] - 1;

27                 mark = 0;

28                 break;

29             }

30             else {

31                 a[i] = '9';

32                 mark = 1;

33             }

34         }

35     }

36 }

37 int dig[N];

38 int end;

39 int dp[N][10][N][2][2];

40 

41 int check(int pre,int nw,int id) {

42     if (id >= end) return 0;

43     if (pre < nw && stand[id] == '/') return 1;

44     if (pre == nw && stand[id] == '-') return 1;

45     if (pre > nw && stand[id] == '\\') return 1;

46     return 0;

47 

48 }

49 int dfs(int pos,int pre,int loc,int cc,int limit,int fg) {

50     if (pos < 0) return loc == end - 1 && cc >= 2;

51     int t = 0;

52     if (cc >= 2) t = 1;

53     if (!limit && dp[pos][pre][loc][t][fg] != -1)

54         return dp[pos][pre][loc][t][fg];

55     int last = limit ? dig[pos] : 9;

56     int ret = 0;

57     for (int i = 0; i <= last; i++) {

58         if (!fg) {

59             ret = (ret + dfs(pos-1,i,loc,i != 0,limit && (i==last),fg || i)) % Mod;

60             continue;

61         }

62         if (cc >= 2 &&  check(pre,i,loc+1) && stand[loc] == stand[loc+1]){

63             ret = (ret + dfs(pos-1,i,loc+1,2,limit && (i==last),fg || i)) % Mod;

64             continue;

65         }

66 

67         if (cc >= 2 &&  check(pre,i,loc+1) && stand[loc] != stand[loc+1]){

68             ret = (ret + dfs(pos-1,i,loc+1,2,limit && (i==last),fg || i)) % Mod;

69         }

70         if (check(pre,i,loc))

71             ret = (ret + dfs(pos-1,i,loc,cc+1,limit && (i==last),fg || i) ) % Mod;

72     }

73     if (!limit ) {

74         dp[pos][pre][loc][t][fg] = ret;

75     }

76     return ret;

77 }

78 int solve(char *s) {

79     int len = strlen(s);

80     for (int i = 0; i < len; i++) {

81         dig[i] = s[i] - '0';

82     }

83     while (dig[len-1] == 0) len--;

84    // dig[len++] = 0;

85     return dfs(len-1,0,0,0,1,0);

86 }

87 int main(){

88     while (~scanf("%s%s%s",stand,a,b)) {

89         init();

90         memset(dp,-1,sizeof(dp));

91         end = strlen(stand);

92       //  cout<<a<<endl;

93       //  cout<<b<<endl;

94         if (flag) printf("%08d\n",solve(b));

95         else printf("%08d\n",(solve(b) - solve(a) + Mod) % Mod);

96     }

97     return 0;

98 }
View Code

 

cf 55D http://codeforces.com/problemset/problem/55/D

题意:求[L,R]之间能整除自己每一位的数的个数;

分析:1~9的最小公倍数为2520,同时记录下那些数出现过因为0,1不许要1<<8,cf内存大

dfs(pos,sta,mod,limit);

 1 #include<cstdio>

 2 #include<cstring>

 3 #include<iostream>

 4 #include<algorithm>

 5 #include<cmath>

 6 using namespace std;

 7 typedef long long LL;

 8 const int Mod = 2520;

 9 LL dp[20][1<<8][2520];

10 

11 LL a,b;

12 int dig[20];

13 int check(int sta,int mod) {

14     for (int i = 2; i < 10; i++) {

15         if (sta & (1<<(i-2))) {

16             if (mod % i) return 0;

17         }

18     }

19     return 1;

20 }

21 LL dfs(int pos,int sta,int mod,int limit) {

22     if (pos < 0) return check(sta,mod);

23     if (!limit && dp[pos][sta][mod] != -1) return dp[pos][sta][mod];

24     int last = limit ? dig[pos] : 9;

25     LL ret = 0;

26     for (int i = 0; i <= last; i++) {

27         int t = sta;

28         if (i >= 2) t |= 1<<(i-2);

29         ret += dfs(pos-1,t,(mod * 10 + i) % Mod,limit && (i == last));

30     }

31     if (!limit) {

32         dp[pos][sta][mod] = ret;

33     }

34     return ret;

35 }

36 LL solve(LL n) {

37     int len = 0;

38     while (n) {

39         dig[len++] = n % 10;

40         n /= 10;

41     }

42     return dfs(len-1,0,0,1);

43 }

44 int main(){

45     memset(dp,-1,sizeof(dp));

46     int T; scanf("%d",&T);

47     while (T--) {

48         cin>>a>>b;

49         cout<<solve(b) - solve(a-1)<<endl;

50     }

51     return 0;

52 }
View Code

 

 Foj 2042 http://acm.fzu.edu.cn/problem.php?pid=2042

题意:求[a,b]与[c,d]之间 xor值大于e的 sum += i ^ j;

分析:

dfs(pos,limita,limitb,limitc);

如果只是记录dp[pos]的话会TLE,所以记录dp[pos][limita][limitb][limitc];

这样的话每次solve()都要初始化;

递归版本的数位Dp记录真的很灵活,全看题目;

 1 #include<cstdio>

 2 #include<cstring>

 3 #include<algorithm>

 4 #include<iostream>

 5 #define MP make_pair

 6 using namespace std;

 7 const int N = 66;

 8 typedef long long LL;

 9 const LL Mod = 1000000007;

10 typedef pair<LL,LL> pLL;

11 pLL dp[N][2][2][2];

12 LL a,b,c,d,e;

13 int diga[N],digb[N],digc[N];

14 void change(int dig[],LL n,int &len) {

15     len = 0;

16     while (n) {

17         dig[len++] = n % 2;

18         n /= 2;

19     }

20 }

21 pLL dfs(int pos,int limita,int limitb,int limitc) {

22     if (pos < 0) {

23           return !limitc ? MP(1,0) : MP(0,0);

24     }

25     if (dp[pos][limita][limitb][limitc].first != -1 && dp[pos][limita][limitb][limitc].second != -1) return dp[pos][limita][limitb][limitc];

26     int lasta = limita ? diga[pos] : 1;

27     int lastb = limitb ? digb[pos] : 1;

28     int lastc = limitc ? digc[pos] : -1;

29     pLL ret = MP(0,0);

30     for (int i = 0; i <= lasta; i++) {

31         for (int j = 0; j <= lastb; j++) {

32             int t = (i ^ j);

33             if (t >= lastc) {

34                pLL cnt = dfs(pos-1,limita && (i == lasta),limitb && (j == lastb),limitc && (t == lastc));

35                ret.first = (ret.first + cnt.first) % Mod;

36                ret.second = ((ret.second + (1ll<<pos) * (i^j)  % Mod * cnt.first % Mod) % Mod + cnt.second) % Mod;

37             }

38         }

39     }

40     dp[pos][limita][limitb][limitc] = ret;

41 

42     return ret;

43 }

44 

45 LL solve(LL a,LL b,LL c) {

46         for (int i = 0; i < N; i++)

47               for (int j = 0; j < 2; j++)

48                   for (int k = 0; k < 2; k++)

49                   for (int x = 0; x < 2; x++) dp[i][j][k][x] = MP(-1,-1);

50 

51     int lena,lenb,lenc;

52     change(diga,a,lena);

53     change(digb,b,lenb);

54     change(digc,c,lenc);

55     int len = max(lena,max(lenb,lenc));

56     while (lena < len) diga[lena++] = 0;

57     while (lenb < len) digb[lenb++] = 0;

58     while (lenc < len) digc[lenc++] = 0;

59  

60     return (dfs(len-1,1,1,1).second + Mod) % Mod;

61 }

62 int main(){

63     int T,cas = 0; scanf("%d",&T);

64     while (T--) {

65 

66         printf("Case %d: ",++cas);

67         cin>>a>>b>>c>>d>>e;

68 

69         cout<<((solve(b,d,e) - solve(b,c-1,e) + Mod) % Mod  - solve(a-1,d,e) + solve(a-1,c-1,e) + Mod) % Mod<<endl;

70     }

71     return 0;

72 }
View Code

 

 

 

你可能感兴趣的:(dp)