poj 1322
取出 n 种巧克力,巧克力有 c 种颜色,当取出的巧克力有偶数个的时候他会把它吃掉,问剩余 m 种巧克力的时候的概率。
首先我们用概率dp来解决这个问题,
dp[i][j]=dp[i−1][j−1]∗(c−j+1)/c+dp[i−1][j+1]∗(j+1)/c
递推式很简单可是得注意复杂度太高了 O(108)
因为要保留3位小数,所以可以发现当 n≥103 可以发现就不会影响了。
这样可以卡过去
#include
#include
#include
#include
#include
#include
#include
#define Debug(x) cout<<(x)<
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int,int > PII;
typedef map<int,int >::iterator MIT;
const int maxn = 1e6+10;
double dp[2][maxn];
int main()
{
//freopen("H:\\c++\\file\\stdin.txt","r",stdin);
int c,n,m;
while(scanf("%d%d%d",&c,&n,&m)&& c)
{
memset(dp,0,sizeof(dp));
dp[0][0] = 1.0;
if(m>c|| m>n || (n+m) & 1)
{
printf("0.000\n");continue;
}
if(n>1000)n = 1000+n%2;
for(int i =1 ; i<=n ; ++i)
for(int j = 0 ; j<=i&& j<=c ; ++j)
{
dp[i&1][j] = 0;
if((i+j)&1)continue;
if(j>0)dp[i&1][j] +=dp[1-(i&1)][j-1]*(c-j+1)/c;
if(j+1<=i-1)dp[i&1][j]+=dp[1-(i&1)][j+1]*(j+1)/c;
}
printf("%.3f\n",dp[n%2][m]);
}
return 0;
}
这个的效率会显著加快,可是编码有点难度。。
#include
#include
#include
#include
#include
#include
#include
#define Debug(x) cout<<(x)<
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int,int > PII;
typedef map<int,int >::iterator MIT;
double po[111], ne[111], pp[111], nn[111];
double powmod(double x, int n) {
double ret = 1;
while(n) {
if(n&1) ret *= x;
x *= x;
n /= 2;
}
return ret;
}
int c, n, m;
// 由于C(n, k)可能会很大,不能直接预处理出组合数
double cal(double ret, int n, int k) {
if(n-k < k) k = n-k;
for(int i = n;i > n-k; i--)
ret *= i;
for(int i = 1;i <= k; i++)
ret /= i;
return ret;
}
void solve() {
int i, j;
for(i = 0;i <= c; i++) {
po[i] = ne[i] = pp[i] = nn[i] = 0;
}
double chu = powmod(1.0/2, m);
for(i = 0;i <= m; i++) {
int now = i-m+i;
int flag = 1;
if((m-i)&1) flag = -1;
if(now >= 0) po[now] += cal(chu*flag, m, i); // 保存e^(kx)的系数
else ne[-now] += cal(chu*flag, m, i); // 保存e^(-kx)的系数
}
chu = powmod(1.0/2, c-m);
for(i = 0;i <= c-m; i++) {
double cur = cal(chu, c-m, i);
for(j = 0;j <= m; j++) {
int now = j+i-(c-m)+i;
if(now >= 0) pp[now] += po[j]*cur; // 直接合并系数
else nn[-now] += po[j]*cur;
}
for(j = 0;j <= m; j++) {
int now = -j + i-(c-m)+i;
if(now >= 0) pp[now] += ne[j]*cur;
else nn[-now] += ne[j]*cur;
}
}
double ans = 0;
for(i = 1;i <= c; i++) {
ans += cal( pp[i]*powmod((double)i/c, n), c, m);
}
for(i = 1;i <= c; i++) {//e^-kx//(-k)^n
if(n&1) nn[i] = -nn[i];
ans += cal( nn[i]*powmod((double)i/c, n), c, m);
}
printf("%.3f\n", ans);
}
int main() {
while(scanf("%d", &c) != -1 && c) {
scanf("%d%d", &n, &m);
if(m > n || m > c || (n-m)%2==1) {
puts("0.000"); continue;
}
// 尤其要注意n等于0 && m等于0 要特判
if(n == 0 && m == 0) {
puts("1.000"); continue;
}
solve();
}
return 0;
}