gdfzoj #1440 Max(状压DP)

原题链接
gdfzoj #1440 Max(状压DP)_第1张图片
gdfzoj #1440 Max(状压DP)_第2张图片


注意到这题求的的是最大值的期望,考虑使用状压dp维护最大值求解
发现对于每个n行c+1列的矩阵,发生 Aj=Aj+k 的事件概率和为1,即它们是互斥的,考虑将是否发生 Aj=Aj+k 的事件压入状态
我们可以记 f[i][mask][k] 表示第i个数经过mask的操作(mask为状态压缩值),值为k的概率
则易得转移方程 f[i][mask][k]=maskf[i][mask][kj]Pj[0,c]
得到f后,我们再设g[i][mask][k]表示前i个数使用过mask的操作,最大值为k的概率
有转移方程 g[i][mask][k]=smaskmax(p,q)=kg[i1][s][p]f[i][masks][q]
其中 smask 用枚举子集的方法处理,而 max(p,q)=k 则维护f,g数组的前缀和做到在O(1)时间计算完

#include
#include
#include
#define mod 1000000007
using namespace std;
typedef long long ll;
ll f[42][1050][32],g[42][1050][32],sf[42][1050][32],sg[42][1050][32];
ll p[12][42][5],n,m,c,ans=0,tmp,C; 
int main()
{
    scanf("%lld%lld%lld",&n,&m,&c),C=m*c;
    for (int _=1;_<=m;_++) for (int i=1;i<=n;i++) for (int j=0;j<=c;j++) scanf("%lld",&p[_][i][j]);
    for (int i=1;i<=n;i++) //被动转移计算f 
    {
        f[i][0][0]=1;
        for (int mask=1;mask<(1<<m);mask++)
            for (int j=1;j<=m;j++) if (mask & (1 << (j-1))) 
            {
                for (int k=0;k<=C;k++) for (int _=0;_<=c&&_<=k;_++) 
                    (f[i][mask][k]+=f[i][mask - (1<<(j-1))][k-_]*p[j][i][_])%=mod;
                break;
            }
    }
    for (int i=1;i<=n;i++) for (int mask=0;mask<(1<<m);mask++)
    {
        sf[i][mask][0]=f[i][mask][0];
        for (int _=1;_<=C;_++) sf[i][mask][_]=(f[i][mask][_]+sf[i][mask][_-1])%mod;
    }
    g[0][0][0]=1;
    for (int i=1;i<=n;i++) 
    {
        for (int mask=0;mask<(1<<m);mask++)
        {
            sg[i-1][mask][0]=g[i-1][mask][0];
            for (int _=1;_<=C;_++) sg[i-1][mask][_]=(sg[i-1][mask][_-1]+g[i-1][mask][_])%mod;
        }
        for (int mask=0;mask<(1<<m);mask++)
            for (int j=mask;;j=(j-1)&mask) //枚举子集
            {
                tmp=mask-j;
                for (int _=0;_<=C;_++) 
                {
                    (g[i][mask][_]+=((sg[i-1][tmp][_]*f[i][j][_] + sf[i][j][_]*g[i-1][tmp][_] - g[i-1][j][_]*f[i][tmp][_])%mod+mod)%mod)%=mod;
                }
                if (!j) break;
            } 
    }
    for (int _=1;_<=C;_++) ans=(ans+_*g[n][(1<<m)-1][_])%mod;
    printf("%lld\n",ans);
} 

你可能感兴趣的:(dp)