CF 1073E. Segment Sum(digit DP)

题目链接

E. Segment Sum

分析

可以说很经典了,数位 dp,关于数位dp我也才学

数位dp

这个题目与 仅统计个数有点不同的地方在于,它要求值的和,而对于整数来说每个位是可以独立相加的,可是如果仅仅用一个状态 dp[st][pos][0/1] 表示吃状态下的最终结果的话,可能不行,因为比如

1232**
1233** 这两种情况其实是在一个 dp 状态下的 st = 111,pos = 3
所以会遗漏一些位的计算,所以我们采用 计算每个位贡献和(单独的位可相加)的方式将后面的值加起来

所以计算两个量一个是个数,一个是当前值

code

#include
using namespace std;

#define MAX_VAL         110
#define MAX_ARRAY_SIZE  20
#define ms(x,v)         memset((x),(v),sizeof(x))
#define pb              push_back
#define fi              first
#define se              second
#define mp              make_pair
#define INF             0x3f3f3f3f
#define MOD             998244353
#define CTZ(x)          __builtin_ctz(x)        //the number of zeros at the end of number
#define CLZ(x)          __builtin_clz(x)        //the number of zeros at the begin of number
#define POPCNT(x)       __builtin_popcount(x)   //the number of ones in the number
#define PARITY(x)       __builtin_parity(x)     //the parity(odd or even) of the number

typedef long long LL;
typedef pair<LL,LL> Pair;
int k;
LL a,b;
int num[MAX_ARRAY_SIZE],cnt;
LL dp[1<<10][MAX_ARRAY_SIZE][2];
LL ans[1<<10][MAX_ARRAY_SIZE][2];
LL pw[MAX_ARRAY_SIZE];

pair<LL,LL> calc(const int st,const int pos,const int f){
    if(POPCNT(st)>k)return mp(0,0);
    if(pos == cnt){
        return mp(0,1);
    }
    pair<LL,LL> ret = mp(ans[st][pos][f],dp[st][pos][f]);
    if(ret.se !=-1)return ret;
    ret = mp(0,0);
    int lmt = f? 9: num[pos];
    for(int i=0 ; i<=lmt ; ++i){
        Pair tmp=mp(0,0);
        if(st ==0){
            tmp = calc(i?1<<i : 0,pos+1,f||i<lmt);
        }
        else{
             tmp = calc(st|1<<i,pos+1,f||i<lmt);
        }
        ret.fi += i*pw[cnt-pos-1]%MOD * tmp.se % MOD+ tmp.fi;
        ret.se += tmp.se;
    }
    return mp(ans[st][pos][f] = ret.fi % MOD, dp[st][pos][f] = ret.se %MOD);
}

LL solve(const LL n){
    cnt =0;
    ms(dp,-1);
    ms(ans,-1);
     LL debug_n = n;
    while (debug_n) {
        num[cnt++] = debug_n % 10;
        debug_n/=10;
    }
    reverse(num,num+cnt);
    Pair ret = calc(0,0,0);
    return ret.fi;
}


int main (int argc, char const *argv[]) {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    pw[0] =1;
    for(int i=1 ; i<MAX_ARRAY_SIZE ; ++i)pw[i] = pw[i-1]*10 %MOD;

    cin >> a >> b >> k;
    std::cout << (solve(b) - solve(a -1) + MOD) % MOD << '\n';
    return 0;
}

你可能感兴趣的:(算法&数据结构)