poj 1322 Chocolate(生成函数 or 概率dp)

Problem Link

poj 1322
取出 n 种巧克力,巧克力有 c 种颜色,当取出的巧克力有偶数个的时候他会把它吃掉,问剩余 m 种巧克力的时候的概率。

Analysis

概率dp

首先我们用概率dp来解决这个问题,

dp[i][j]=dp[i1][j1](cj+1)/c+dp[i1][j+1](j+1)/c

递推式很简单可是得注意复杂度太高了 O(108)
因为要保留3位小数,所以可以发现当 n103 可以发现就不会影响了。
这样可以卡过去

#include 
#include
#include
#include
#include
#include
#include
#define Debug(x) cout<<(x)<
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int,int > PII;
typedef map<int,int >::iterator MIT;

const int maxn = 1e6+10;

double dp[2][maxn];

int main()
{
    //freopen("H:\\c++\\file\\stdin.txt","r",stdin);


    int c,n,m;
   while(scanf("%d%d%d",&c,&n,&m)&& c)
   {
       memset(dp,0,sizeof(dp));
       dp[0][0] = 1.0;
       if(m>c|| m>n || (n+m) & 1)
       {
           printf("0.000\n");continue;
       }
       if(n>1000)n = 1000+n%2;
       for(int i =1 ; i<=n ; ++i)
        for(int j = 0 ; j<=i&& j<=c ; ++j)
       {
           dp[i&1][j] = 0;
           if((i+j)&1)continue;
           if(j>0)dp[i&1][j] +=dp[1-(i&1)][j-1]*(c-j+1)/c;
           if(j+1<=i-1)dp[i&1][j]+=dp[1-(i&1)][j+1]*(j+1)/c;
       }

       printf("%.3f\n",dp[n%2][m]);

   }


    return 0;
}

生成函数

这里写图片描述
这里写图片描述
这个的效率会显著加快,可是编码有点难度。。

#include 
#include
#include
#include
#include
#include
#include
#define Debug(x) cout<<(x)<
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int,int > PII;
typedef map<int,int >::iterator MIT;


double po[111], ne[111], pp[111], nn[111];

double powmod(double x, int n) {
    double ret = 1;
    while(n) {
        if(n&1) ret *= x;
        x *= x;
        n /= 2;
    }
    return ret;
}

int c, n, m;
// 由于C(n, k)可能会很大,不能直接预处理出组合数
double cal(double ret, int n, int k) {
    if(n-k < k) k = n-k;
    for(int i = n;i > n-k; i--)
        ret *= i;
    for(int i = 1;i <= k; i++)
        ret /= i;
    return ret;
}

void solve() {
    int i, j;
    for(i = 0;i <= c; i++)  {
        po[i] = ne[i] = pp[i] = nn[i] = 0;
    }
    double chu = powmod(1.0/2, m);
    for(i = 0;i <= m; i++) {
        int now = i-m+i;
        int flag = 1;
        if((m-i)&1) flag = -1;
        if(now >= 0) po[now] += cal(chu*flag, m, i); // 保存e^(kx)的系数
        else    ne[-now] += cal(chu*flag, m, i);  // 保存e^(-kx)的系数
    }
    chu = powmod(1.0/2, c-m);
    for(i = 0;i <= c-m; i++) {
        double cur = cal(chu, c-m, i);
        for(j = 0;j <= m; j++) {
            int now = j+i-(c-m)+i;
            if(now >= 0)    pp[now] += po[j]*cur;  // 直接合并系数
            else    nn[-now] += po[j]*cur;
        }
        for(j = 0;j <= m; j++) {
            int now = -j + i-(c-m)+i;
            if(now >= 0)    pp[now] += ne[j]*cur;
            else    nn[-now] += ne[j]*cur;
        }
    }
    double ans = 0;
    for(i = 1;i <= c; i++) {
        ans += cal( pp[i]*powmod((double)i/c, n),  c, m);
    }
    for(i = 1;i <= c; i++) {//e^-kx//(-k)^n
        if(n&1) nn[i] = -nn[i];
        ans += cal( nn[i]*powmod((double)i/c, n),  c, m);
    }
    printf("%.3f\n", ans);
}

int main() {
    while(scanf("%d", &c) != -1 && c) {
        scanf("%d%d", &n, &m);
        if(m > n || m > c || (n-m)%2==1) {
            puts("0.000");  continue;
        }
        // 尤其要注意n等于0 && m等于0 要特判
        if(n == 0 && m == 0) {
            puts("1.000");  continue;
        }
        solve();
    }
    return 0;
}

你可能感兴趣的:(算法&数据结构)