ICPC 2023 网络赛 j (线性dp

#include
using namespace std;
using VI = vector;
using ll = long long;
const int mod = 998244353;

//? 63
//@ 64
//a 97
//z 122
//Z 90
//A 65
int n;
string s;
//da  xiao shu
ll dp[2][2][2][2][100];
ll sum[2][2][2];
int change(char x){
    if(x == '?') return 1;
    else if(x >= 'A' && x <= 'Z'){
        return (int)x - 'A' + 2;
    }else if(x >= 'a' && x <= 'z'){
        return (int)(x - 'a') + 28;
    }else{
        return (int) x - '0' + 54;
    }
}
// ? 1
// A  - Z   2  27
// a - z    28  53
// 0 - 9   54  63


int main(){
    cin>>n;
    cin>>s;
    s = " " + s;
    //1dp[0][0][0][0][1] = 1;

    int p = change(s[1]);
    if(p == 1){
        for(int k = 2 ; k <= 63 ; k++){
            if( k >= 2 && k <= 27){//daxie
                dp[1][1][0][0][k] = 1;

            }else if( k >= 28 && k <= 53){
               dp[1][0][1][0][k] = 1;

            }else if( k >= 54 && k <=63){
                dp[1][0][0][1][k] = 1;
            }
        }

    }else if(p >= 28 && p <= 53){
        dp[1][0][1][0][p] = 1;
        dp[1][1][0][0][p - 26] = 1;

    }else if(p >= 54 && p <= 63){
        dp[1][0][0][1][p] = 1;

    }else{
        dp[1][1][0][0][p] = 1;
    }

    for(int i = 2 ; i <= n; i++){
        int cur = i % 2;
        int pre = (i+1) % 2;
        int t = change(s[i]);
        memset(sum , 0 , sizeof sum);
        for(int j = 0 ; j <= 63 ; j++){
            for(int st1 = 0 ;st1 <= 1 ;st1++){
                for(int st2 = 0 ; st2 <= 1 ; st2++){
                    for(int st3 = 0 ;st3 <= 1 ; st3++){
                        dp[cur][st1][st2][st3][j] = 0 ;
                        sum[st1][st2][st3] += dp[pre][st1][st2][st3][j];
                        sum[st1][st2][st3] %= mod;
                    }
                }
            }
        }

        if(t == 1){

            for(int k = 2 ; k <= 63 ;k++){

                if( k >= 2 && k <= 27){//daxie
                    for(int st1 = 0 ;st1 <= 1; st1++){
                        for(int st2 = 0;st2 <= 1; st2++){
                            dp[cur][1][st1][st2][k] = dp[cur][1][st1][st2][k]
                                                      + ((sum[0][st1][st2] - dp[pre][0][st1][st2][k] + mod) %  mod + (sum[1][st1][st2] - dp[pre][1][st1][st2][k] + mod) % mod) % mod;
                            dp[cur][1][st1][st2][k] %= mod;

                        }
                    }

                }else if( k >= 28 && k <= 53){
                    for(int st1 = 0 ;st1 <= 1; st1++){
                        for(int st2 = 0;st2 <= 1; st2++){
                            dp[cur][st1][1][st2][k] = dp[cur][st1][1][st2][k]
                                                      + ((sum[st1][0][st2] - dp[pre][st1][0][st2][k]) %  mod + (sum[st1][1][st2] - dp[pre][st1][1][st2][k] ) % mod) % mod;
                            dp[cur][st1][1][st2][k] %= mod;

                        }
                    }

                }else if( k >= 54 && k <=63){
                    for(int st1 = 0 ;st1 <= 1; st1++){
                        for(int st2 = 0;st2<=1;st2++){
                            dp[cur][st1][st2][1][k] = dp[cur][st1][st2][1][k]
                                                      + ((sum[st1][st2][0] - dp[pre][st1][st2][0][k]) %  mod
                                                         + (sum[st1][st2][1] - dp[pre][st1][st2][1][k]) % mod) % mod;
                            dp[cur][st1][st2][1][k] %= mod;

                        }
                    }
                }
            }

        }else if( (t >= 2 && t<=27)){//大写字母的情况
            for(int st1 = 0 ; st1 <= 1; st1++){
                for(int st2 = 0; st2 <= 1; st2++){
                    dp[cur][1][st1][st2][t] = dp[cur][1][st1][st2][t]
                                              + ((sum[0][st1][st2] - dp[pre][0][st1][st2][t]) %  mod
                                                 + (sum[1][st1][st2] - dp[pre][1][st1][st2][t]) % mod) % mod;
                    dp[cur][1][st1][st2][t] %= mod;

                }
            }

        }else if(t >= 28 && t<= 53){//小写字母的情况
            for(int st1 = 0 ;st1 <= 1; st1++){
                for(int st2 = 0;st2 <= 1; st2++){
                    dp[cur][st1][1][st2][t] = dp[cur][st1][1][st2][t]
                                              + ((sum[st1][0][st2] - dp[pre][st1][0][st2][t]) %  mod
                                                 + (sum[st1][1][st2] - dp[pre][st1][1][st2][t]) % mod) % mod;
                    dp[cur][st1][1][st2][t] %= mod;

                }
            }
            t -= 26;
            for(int st1 = 0 ;st1 <= 1; st1++){
                for(int st2 = 0;st2 <= 1; st2++){
                    dp[cur][1][st1][st2][t] = dp[cur][1][st1][st2][t]
                                              + ((sum[0][st1][st2] - dp[pre][0][st1][st2][t]) %  mod
                                                 + (sum[1][st1][st2] - dp[pre][1][st1][st2][t]) % mod) % mod;
                    dp[cur][1][st1][st2][t] %= mod;

                }
            }

        }else if(t >= 54 && t<=63){//数字的情况
            for(int st1 = 0 ;st1 <= 1; st1++){
                for(int st2 = 0;st2<=1;st2++){
                    dp[cur][st1][st2][1][t] = dp[cur][st1][st2][1][t]
                                              + ((sum[st1][st2][0] - dp[pre][st1][st2][0][t]) %  mod
                                                 + (sum[st1][st2][1] - dp[pre][st1][st2][1][t]) % mod) % mod;
                    dp[cur][st1][st2][1][t] %= mod;

                }
            }
        }



    }

    ll res = 0 ;
    for(int i = 2 ; i <= 63 ; i++){
        res = (res + dp[n%2][1][1][1][i])%mod;
    }
    cout<< (res + mod) % mod;


}

因为输出的时候没有先加mod , 而wa,破大防

因为转移的时候只关心上一个和  当前字母不同的  ,

可以把上一个状态的所有和 求出来,然后把这个字母的状态去掉。

你可能感兴趣的:(dp,网络,算法,图论)