数位dp:统计1到n中有多少数包含2018的子串(2018可以不连续)

数据范围:10^10

  • 分析:dp[pos][pre] = Node { ll x, ll y, ll z,ll k, ll none}
    • x, y, z, k, none分别代表包含8, 18, 018, 2018子串的数量, 除前边的剩下的数量.
    • pos代表第几位,pre代表pos+1的数
    • 注意:当一个数被加了之后,后边就不能加了(即数只能被加一次)
#include
using namespace std;
typedef long long ll;
const int maxn = 100 + 10;

int bits[maxn];
struct Node {
    ll x;
    ll y;
    ll z;
    ll k;
    ll none;
};
Node dp[maxn][maxn];

Node dfs(int pos, int pre, bool flag) {
    if(pos < 0) {
        return (Node){pre == 8 ? 1 : 0, 0, 0, 0, pre == 8? 0 : 1};
    }
    if(pre != -1 && !flag && dp[pos][pre].x != -1) {
        return dp[pos][pre];
    }
    int dep = flag ? bits[pos] : 9;
    Node ret = (Node){0, 0, 0, 0};
    for(int i = 0; i <= dep; i ++) {
        Node t = dfs(pos - 1, i, flag && i == dep);
        if(i == 8) {
            ret.x += t.none;
        } else if(i == 1) {
            ret.y += t.x;
        } else if(i == 0) {
            ret.z += t.y;
        } else if(i == 2) {
            ret.k += t.z;
        }
        if(i != 1) {
            ret.x += t.x;
        }
        if(i !=  0) {
            ret.y += t.y;
        }
        if(i != 2) {
            ret.z += t.z;
        }
        ret.k += t.k;
        if(i != 8) {
            ret.none += t.none;
        }
        
    }
    // cout << pos << " " << pre << " " << flag  << ": " << ret.x  << " " << ret.y << " " << ret.z << " " << ret.k  << " " << ret.none << endl;
    if(!flag && pre != -1) {
        dp[pos][pre] = ret;
    }
    return ret;
}
ll solve(ll n) {
    int len = 0;
    while(n) {
        bits[len ++] = n % 10;
        n /= 10;
    }
    return dfs(len - 1, -1, true).k;
}

int main()  {
    for(int i = 0;i < maxn; i ++) {
        for(int j = 0; j < 10;j ++) {
            dp[i][j].x = -1;
        }
    }
    ll n;
    while(cin >> n) {
        cout << solve(n) << endl;
    }
    return 0;
}

你可能感兴趣的:(数位dp:统计1到n中有多少数包含2018的子串(2018可以不连续))