牛客ACM多校赛第5场 subsequence 1

链接:https://ac.nowcoder.com/acm/contest/885/G
来源:牛客网

You are given two strings s and t composed by digits (characters '0' ∼\sim∼ '9'). The length of s is n and the length of t is m. The first character of both s and t aren't '0'.
 

Please calculate the number of valid subsequences of s that are larger than t if viewed as positive integers. A subsequence is valid if and only if its first character is not '0'.

Two subsequences are different if they are composed of different locations in the original string. For example, string "1223" has 2 different subsequences "23".


Because the answer may be huge, please output the answer modulo 998244353.

题意:给你两个由数字组成字符串s和t,以及它们的长度n和m,问s中有多少大于t的子序列

输入描述

The first line contains one integer T, indicating that there are T tests.

Each test consists of 3 lines.

The first line of each test contains two integers n and m, denoting the length of strings s and t.

The second line of each test contains the string s.

The third line of each test contains the string t.

* 1≤m≤n≤30001 \le m \le n \le 30001≤m≤n≤3000.
 

* sum of n in all tests ≤3000\le 3000≤3000.

 

* the first character of both s and t aren't '0'.
 

输出描述:

For each test, output one integer in a line representing the answer modulo 998244353.

示例1

输入

复制

3
4 2
1234
13
4 2
1034
13
4 1
1111
2

输出

复制

9
6
11

说明

For the last test, there are 6 subsequences "11", 4 subsequcnes "111" and 1 subsequence "1111" that are valid, so the answer is 11.

思路:

dp+组合数C数组预处理

 

1.一开始我是dp直接莽,把s>t的所有长度大于等于m(t的长度)的结果直接求出来,然后TLE了

2.后来我只用dp找出s>t的长度等于m的数量,其余数量用组合数求出来,但是没有预处理C,只是去别人那找了个快速求C的函数,结果TLE

#include
using namespace std;
typedef long long ll;
const int MAX = 1e5+50;
const ll mod = 998244353;
ll dp[3005][3005][3],tot=0;//[3],分为小于,等于,大于三种情况
char s[3005],t[3005];
bool vis[MAX];
int p[MAX];
void make_p(){//挖素数
    vis[0]=vis[1]=1;
    for(int i=2;i<=MAX;i++)
    if(!vis[i]){
        p[++tot]=i;
        for(int j=i*2;j<=MAX;j+=i) vis[j]=1;
    }
}
ll qsm(ll a,ll b){//快速幂
    ll ans=1,w=a;
    for(;b;b>>=1,w=(w*w)%mod) if(b&1) ans=(ans*w)%mod;
    return ans;
}
ll Get(ll x,ll y){
    ll sum=0;
    for(;x;x/=y) sum+=x/y;
    return sum;
}
ll C(int x,int y){
    ll ans=1;
    for(int i=1;i<=tot&&p[i]<=x;i++){
        int T=Get(x,p[i])-Get(x-y,p[i])-Get(y,p[i]);
        ans=(ans*qsm(p[i],T))%mod;
    }
    return ans;
}
int main(){
    memset(vis,0,sizeof(vis));
    make_p();
    int T;
    scanf("%d",&T);
    while(T--){
        int n,m;
        scanf("%d%d",&n,&m);
        scanf("%s",s+1);
        scanf("%s",t+1);
        memset(dp,0,sizeof(dp));
        dp[0][0][1]=1;
        ll ans=0;
        for(int i=1;i<=n;i++){//目前在第i位数
            for(int j=0;jm){
                    dp[i][j+1][2]=(dp[i][j+1][2]+dp[i-1][j][1])%mod;
                    dp[i][j+1][2]=(dp[i][j+1][2]+dp[i-1][j][0])%mod;
                }
                else{
                    dp[i][j+1][0]=(dp[i][j+1][0]+dp[i-1][j][0])%mod;
                    if(s[i]>t[j+1]){
                        dp[i][j+1][2]=(dp[i][j+1][2]+dp[i-1][j][1])%mod;
                    }
                    else if(s[i]==t[j+1]){
                        dp[i][j+1][1]=(dp[i][j+1][1]+dp[i-1][j][1])%mod;
                    }
                    else{
                        dp[i][j+1][0]=(dp[i][j+1][0]+dp[i-1][j][1])%mod;
                    }
                }
                if(j+1>=m) ans=(ans+dp[i][j+1][2])%mod;
            }
        }
        for(int i=1;i<=n;i++){
            if(s[i]=='0') continue;
            int num=n-i+1;
            if(num

3.赛后看了别人的代码(别人的dp只处理了到(i,j)为止s等于t的情况,发现的确是这样的,因为我们这里只找s>t长度为m的情况,小于的情况只会在长度大于m的时候转变为s>t,这里在外面直接可以求出,所以不用处理),发现别人都预处理了C,我也先处理一下,但是依然超时,心态炸裂,仔细对比试验,居然是我的memset(dp,0,sizeof(dp))疯狂T,虽然是1e9的内存处理,但是以后都没碰到过这种情况,唉

#include
using namespace std;
typedef long long ll;
const int MAX = 1e5+50;
const ll mod = 998244353;
ll dp[3005][3005][2];//[2],分为等于,大于两种情况
char s[3005],t[3005];
ll C[3005][3005];
void init(){//先把3000以内的组合数先搞出来
    for(int i=0;i<=3000;i++){
        C[i][0]=1;
        for(int j=1;j<=i;j++){
            C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
        }
    }
}
int main(){
    init();
    int T;
    scanf("%d",&T);
    while(T--){
        int n,m;
        scanf("%d%d",&n,&m);
        scanf("%s",s+1);
        scanf("%s",t+1);
        memset(dp,0,sizeof(dp));
        dp[0][0][0]=1;
        ll ans=0;
        //先算长度为m的大于t的子序列
        for(int i=1;i<=n;i++){//目前在第i位数
            for(int j=0;jt[j+1]){//现在大于
                    dp[i][j+1][1]=(dp[i][j+1][1]+dp[i-1][j][0])%mod;
                }
                else if(s[i]==t[j+1]){//现在依旧等于
                    dp[i][j+1][0]=(dp[i][j+1][0]+dp[i-1][j][0])%mod;
                }
                if(j+1==m) ans=(ans+dp[i][j+1][1])%mod;
            }
        }
        for(int i=m+1;i<=n;i++){//再算长度大于m的子序列的数量
            ans=(ans+C[n][i])%mod;
        }
        for(int i=1;i<=n;i++){
            if(s[i]=='0'){
                for(int j=m;j<=n-i;j++){
                    ans=(ans-C[n-i][j]+mod)%mod;
                }
            }
        }
//      for(int i=1;i<=n;i++){
//          for(int j=0;j<=i;j++){
//              printf("i=%d,j=%d,(%d,%d)  ",i,j,dp[i][j][0],dp[i][j][1]);
//          }
//          printf("\n");
//      }
        printf("%lld\n",ans);
    }
    return 0;
}

4.后期手动处理dp,才没有超时,而且看了别人的dp,发现会更简单,二位数组dp[i][j],表示到第i个数字,已经取了j个数字,s==t的数量,如果取的s[i]大于t[j],说明s已经大于t,直接进行处理就行了,的确是快了很多,比我的三维快了20%左右的样例。

#include
using namespace std;
typedef long long ll;
const int MAX = 1e5+50;
const ll mod = 998244353;
ll dp[3005][3005];//[2],分为等于,大于两种情况
char s[3005],t[3005];
ll C[3005][3005];
void init(){//先把3000以内的组合数先搞出来 C[n][m]
    for(int i=0;i<=3000;i++){
        C[i][0]=1;
        for(int j=1;j<=i;j++){
            C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
        }
    }
}
int main(){
    init();
    int T;
    scanf("%d",&T);
    while(T--){
        int n,m;
        scanf("%d%d",&n,&m);
        scanf("%s",s+1);
        scanf("%s",t+1);
        ll ans=0;
        for(int i=0;i<=n;++i) for(int j=0;j<=m;++j) dp[i][j]=0;
        dp[0][0]=1;//到(i,j)为止,s==j的情况有几种
        //因为处理的是s==t的情况,开头是0的情况不会被算进去
        for(int i=1;i<=n;i++){//是否取第i个数字
            dp[i][0]=1;//不取的情况s==t
            for(int j=1;j<=m&&j<=i;j++){//已经取了j个数字
                dp[i][j]=dp[i-1][j];//假如i不取
                //假如i取
                if(s[i]==t[j]){
                    dp[i][j]=(dp[i][j]+dp[i-1][j-1])%mod;
                }
                else if(s[i]>t[j]){
                    //如果在i处s大于t,第j个字母固定,则前面dp[i-1][j-1]种相等
                    //的情况*在剩下n-i个字母中取m-j个的组合数,就是答案之一
                    ans=(ans+(dp[i-1][j-1]*C[n-i][m-j]))%mod;
                }
            }
        }
        for(int i=1;i<=n;i++){//加上长度大于m的情况
            if(s[i]!='0'){
                for(int j=m;j<=n-i;j++){
                    ans=(ans+C[n-i][j])%mod;
                }
            }
        }
//      for(int i=1;i<=n;i++){
//          for(int j=0;j<=i;j++){
//              printf("i=%d,j=%d,(%d,%d)  ",i,j,dp[i][j][0],dp[i][j][1]);
//          }
//          printf("\n");
//      }
        printf("%lld\n",ans);
    }
    return 0;
}

 

你可能感兴趣的:(dp)