作为提高组 d 2 t 1 d2t1 d2t1,比去年难
所以这道题我打的特别的差
这道题我们很显然可以看到可以打一个暴力
复杂度 o ( n ∗ n ! ) o(n*n!) o(n∗n!)
我考场上就达到了这里——我太菜了
void dfs(int u,ll plus){
if(u==n+1){
if(!tot)return;
Rep(i,1,m)if(cnt[i]>tot/2)return;
ans=(ans+plus)%mod;
return;
}
dfs(u+1,plus);
Rep(i,1,m)
if(a[u][i]){
cnt[i]++;
tot++;
dfs(u+1,plus*a[u][i]%mod);
tot--;
cnt[i]--;
}
}
int main()
{
read(n),read(m);
Rep(i,1,n)Rep(j,1,m)read(a[i][j]);
dfs(1,1);
printf("%lld\n",ans);
return 0;
}
因为每一行只能选一个,所以我们可以每行爆搜找选哪个(或者不选),随后判断行不行
这道题一拿到题就觉得是 d p dp dp,但是没想出来转移方程,想要拿到64分,我们需要把 m m m=2,3的情况拆成两种来讨论,当 m = 2 m=2 m=2时,用 f [ i ] [ j ] [ k ] f[i][j][k] f[i][j][k]表示第 i i i行,第一种食材用了 j j j次,第二种用了 k k k次,可以显然得到一个转移方程 f [ i ] [ j ] [ k ] = s u m { f [ i − 1 ] [ j ] [ k ] , f [ i − 1 ] [ j − 1 ] [ k ] ∗ a [ i ] [ 1 ] , f [ i − 1 ] [ j ] [ k − 1 ] ∗ a [ i ] [ 2 ] } f[i][j][k]=sum\{f[i-1][j][k],f[i-1][j-1][k]*a[i][1],f[i-1][j][k-1]*a[i][2]\} f[i][j][k]=sum{f[i−1][j][k],f[i−1][j−1][k]∗a[i][1],f[i−1][j][k−1]∗a[i][2]}
当 m = 3 m=3 m=3时就是四位 d p dp dp
if(m==2){
f[0][0][0]=1;
Rep(i,1,n)
Rep(j,0,i)
Rep(k,0,i){
if(j+k>i)break;
if(j>=1)f[i][j][k]+=f[i-1][j-1][k]*a[i][1],f[i][j][k]%=mod;
if(k>=1)f[i][j][k]+=f[i-1][j][k-1]*a[i][2],f[i][j][k]%=mod;
f[i][j][k]+=f[i-1][j][k],f[i][j][k]%=mod;
}
Rep(i,1,n)
ans=(ans+f[n][i][i])%mod;
printf("%lld\n",ans);
}
if(m==3){
g[0][0][0][0]=1;
Rep(i,1,n)
Rep(j,0,n)
Rep(k,0,n)
Rep(l,0,n){
if(j+k+l>i)break;
if(j>=1)g[i][j][k][l]+=g[i-1][j-1][k][l]*a[i][1],g[i][j][k][l]%=mod;
if(k>=1)g[i][j][k][l]+=g[i-1][j][k-1][l]*a[i][2],g[i][j][k][l]%=mod;
if(l>=1)g[i][j][k][l]+=g[i-1][j][k][l-1]*a[i][3],g[i][j][k][l]%=mod;
g[i][j][k][l]+=g[i-1][j][k][l],g[i][j][k][l]%=mod;
}
Rep(i,0,n)
Rep(j,0,n)
Rep(k,0,n){
int tot=i+j+k;
if(i>tot/2||j>tot/2|k>tot/2||tot==0)continue;
ans+=g[n][i][j][k];
ans%=mod;
}
printf("%lld\n",ans);
}
我们换一种想法,我们把所有可能的情况先算出来,然后再把不符合条件的都减去,怎么算减去的呢,我们可以想到超过一半的只可能有一种菜,我们先循环那种菜再第几列(col),其他的看成一种菜,那么我们设计状态 f [ i ] [ j ] [ k ] f[i][j][k] f[i][j][k]表示前 i i i道菜,超过一半的菜有 j j j道,剩下的选了 k k k道,那么转移方程就是 f [ i ] [ j ] [ k ] = s u m { f [ i − 1 ] [ j ] [ k ] , f [ i − 1 ] [ j − 1 ] [ k ] ∗ a [ i ] [ c o l ] , f [ i − 1 ] [ j ] [ k − 1 ] ∗ ( s [ i ] − a [ i ] [ c o l ] ) } f[i][j][k]=sum\{f[i-1][j][k],f[i-1][j-1][k]*a[i][col],f[i-1][j][k-1]*(s[i]-a[i][col])\} f[i][j][k]=sum{f[i−1][j][k],f[i−1][j−1][k]∗a[i][col],f[i−1][j][k−1]∗(s[i]−a[i][col])}
ans=1;
Rep(i,1,n){
Rep(j,1,m)s[i]=(s[i]+a[i][j])%mod;
ans=ans*(s[i]+1)%mod;
}
Rep(col,1,m){
memset(h,0,sizeof(h));
h[0][0][0]=1;
Rep(i,1,n)
Rep(j,0,n)
Rep(k,0,n){
if(j+k>i)break;
if(j>=1)h[i][j][k]+=h[i-1][j-1][k]*a[i][col],h[i][j][k]%=mod;
if(k>=1)h[i][j][k]+=h[i-1][j][k-1]*(s[i]-a[i][col]),h[i][j][k]%=mod;
h[i][j][k]+=h[i-1][j][k],h[i][j][k]%=mod;
}
Rep(i,0,n)
Rep(j,0,n){
int tot=i+j;
if(tot>n)continue;
if(i<=tot/2)continue;
ans=(ans-h[n][i][j]+mod)%mod;
}
}
printf("%lld\n",(ans-1+mod)%mod);
我们发现,我们最后统计的时候是和 j j j, k k k没有关系的,我们只要知道他们的差就可以了,所以我们可以优化掉一维,复杂度变成了 O ( n 2 m ) O(n^2m) O(n2m),可以过啦!
转移就变成了 f [ i ] [ j ] = s u m { f [ i − 1 ] [ j ] , f [ i − 1 ] [ j − 1 ] ∗ a [ i ] [ c o l ] , f [ i − 1 ] [ j + 1 ] ∗ ( s [ i ] − s [ i ] [ c o l ] ) } f[i][j]=sum\{f[i-1][j],f[i-1][j-1]*a[i][col],f[i-1][j+1]*(s[i]-s[i][col])\} f[i][j]=sum{f[i−1][j],f[i−1][j−1]∗a[i][col],f[i−1][j+1]∗(s[i]−s[i][col])}但是第二维有可能是负的,所以要整体平移
# include
# include
# include
# include
# include
# include
# include
# include
# include
# include
# include
# include
# include
# define Rep(i,a,b) for(int i=a;i<=b;i++)
# define _Rep(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
const int N=5e3+5;
const int ZERO=110;
const int mod=998244353;
typedef long long ll;
template <typename T>void read(T &x){
x=0;int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+c-'0';
x*=f;
}
int n,m;
int a[N][N];
ll f[45][45][45],g[45][45][45][45],s[N],h[45][45][45],dp[ZERO][2*ZERO+10];
int cnt[N],tot;
ll ans;
void dfs(int u,ll plus){
if(u==n+1){
if(!tot)return;
Rep(i,1,m)if(cnt[i]>tot/2)return;
ans=(ans+plus)%mod;
return;
}
dfs(u+1,plus);
Rep(i,1,m)
if(a[u][i]){
cnt[i]++;
tot++;
dfs(u+1,plus*a[u][i]%mod);
tot--;
cnt[i]--;
}
}
int main()
{
read(n),read(m);
Rep(i,1,n)Rep(j,1,m)read(a[i][j]),a[i][j]%=mod;
ans=1;
Rep(i,1,n){
Rep(j,1,m)s[i]=(s[i]+a[i][j])%mod;
ans=ans*(s[i]+1)%mod;
}
Rep(col,1,m){
memset(dp,0,sizeof(dp));
dp[0][ZERO]=1;
Rep(i,1,n)
for(int j=-1*i;j<=i;j++){
if(j-1+ZERO>=0)dp[i][j+ZERO]+=dp[i-1][j-1+ZERO]*a[i][col],dp[i][j+ZERO]%=mod;
dp[i][j+ZERO]+=dp[i-1][j+1+ZERO]*((s[i]-a[i][col]+mod)%mod),dp[i][j+ZERO]%=mod;
dp[i][j+ZERO]+=dp[i-1][j+ZERO],dp[i][j+ZERO]%=mod;
}
Rep(i,1,n)ans=(ans-dp[n][i+ZERO]+mod)%mod;
}
printf("%lld\n",(ans-1+mod)%mod);
return 0;
}
这道题其实64分不难,但是我特别的菜只打了暴力,所以只有32。。。
所以要好好练一练 d p dp dp