题意:
给三个骰子,骰子有k1,k2,k3个面。每次掷骰子,得到的点数和为前进的步数,如果掷骰子的结果为c1,c2,c3(有序)则回到原点。从0点出发,到达大于N的点,则游戏结束。求掷骰子次数的期望。
转移方程为:
dp[i] = dp[i+l1+l2+l3]*(1/k1/k2/k3)+dp[0]*(1/k1/k2/k3)+1
其中枚举l1,l2,l3为掷骰子得到情况。转移方程的i从N开始到0结束。表示从第i点到结束,掷骰子的期望。
有很多做法。
1.可以构造矩阵求解。
2. 展开多项式,每项都由dp[0]和1组成,最后求出dp[0]的表达式,最后算出dp[0].
3.没看懂~~~
4. 迭代法---扩展到二分-----(其实我数学不好,以下的看着感觉行就可以了)
迭代法的相关内容http://wenku.baidu.com/link?url=1hVfu2fzddmSlZkn0El-be6SNLDWK1mLhe9_RvZJ2eeoyHB35u6hwExfY7ixX74xTW4GsTtBcaeb3mXIXM3sCOgQX5LH-WCjINZ1Lbc_s67 。。
这个方法是根据数值分析里迭代法想到的。最后是用二分完成的。
根据转移方程,假设一个dp[0],根据转移方程,计算得到的每个dp值必然是比实际要小的。最后算 dp[0] = K + dp[0]/k1/k2/k3 + 1(K如转移方程计算得到的累加和)
计算dp[0] = (K+1)*k1*k2*k3/(k1*k2*k3-1)
可知: 每次代入dp[0]的值为res,由res计算得到的新的dp[0]是递增的,且收敛
理由--归纳法:1)res = 0代入计算后,dp[0] > res成立 (第一步成立)
2)每次迭代后,由于res < dp[0] 令res = dp[0],那么新一轮迭代后的每个dp值都比原来大(显然成立),那么K是增大的,那么dp[0]增大
因此迭代得到的dp[0]是递增的
3)由于res < dp[0] 那么,本次迭代得到的所有dp值都小于实际值,那么显然K比实际值也要小,那么dp[0]比实际值小成立。
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
但是真的代入dp[0] =0一直迭代到最后结果是会超时的。
能不能更快找到收敛值呢?
1. 显然代入dp[0] = res 得到的新的dp[0],dp[0] - res > 0 (以下都是a了题以后再总结的)
2 当dp[0] - res = 0的时候res就是解了
3那么当res > 实际的dp[0]时,迭代得到的dp[0]会大于res吗?
不会!!,得到的dp[0] < res,同样是收敛的,递减的。
4.因此,代入任意res到转移方程中,无数次迭代都会取到解。
5. 为什么?-----猜的,根据方法2,dp[0] = Adp[0]+B ============> x = ax+b ------------>a是小于1的常数(迭代收敛成立),所以用迭代法是正确的。(一句话就讲完了,上面的证明都是废话)。a为什么会小于1?转移方程可以看出来的 对于dp[i], 转移时有 令P = 1/k1/k2/k3,dp[i] = P*dp[0]+(1-P)*(a*dp[0]) a是展开方程得到的,用归纳法,dp[n]的时候a已经小于1了,以后得到的所有ai都会小于1.结束。。。贴代码----------------------
最后本方法效率不高,方法3的效率高,是n*(k1+k2+k3)的。本方法是n*(k1+k2+k3)*log(len)
#include
#include
#include
#include
#include
using namespace std;
#define ld double
ld p[20];
ld ans[600];
int a,b,c,k1,k2,k3,m;
ld del;
ld getans(ld res){
res = res*del+1;
for(int i = m; i >= 0; i--){
if(i != 0)
ans[i] = res;
else
ans[i] = 1;
for(int j = 3;j <= k1+k2+k3; j++)
ans[i] += ans[i+j]*p[j];
}
ans[0] = ans[0]/(1-del);
return ans[0];
}
int main(){
int n;
scanf("%d",&n);
ld low = 0, high = 100000000000.0;
while(n--){
scanf("%d%d%d%d%d%d%d",&m,&k1,&k2,&k3,&a,&b,&c);
del = 1.0/(k1*k2*k3);
for(int i = 0;i <= k1+k2+k3;i++)
p[i] = 0;
for(int l1 = 1; l1 <= k1; l1++){
for(int l2 = 1; l2 <= k2; l2++)
for(int l3 = 1; l3 <= k3; l3++){
if(l1 == a && l2 == b && l3 == c){
}
else p[l1+l2+l3] += 1;
}
}
memset(ans,0,sizeof(ans));
for(int i = 0;i <= k1+k2+k3; i++)
p[i] /= k1*k2*k3;
low = 0, high = 100000000.0;
int tt =70;
while(tt-- ){
ld mid1 = (high+low)/2 ;
ld res1 = getans(mid1) - mid1;
if(res1 < 0 ) high = mid1;
else low = mid1;
}
cout<< setprecision(10) << fixed << low <