LightOJ 1274 - Beating the Dataset (probability DP)

You are in a contest and, unfortunately, short on time. You have one problem in hand. Glancing at the sample output, you notice that it only ever asks for ‘YES’ or ‘NO’. Since you know the judging system very well, you come up with another plan instead of solving the problem.

For this problem, every test case is stored in a separate file. When a submission arrives, the system runs the solution on every test of the problem in turn, and each test is checked as follows. The input is copied to the file input.txt. Then the solution is launched. It reads the input from input.txt and writes the result to output.txt. When it finishes, the correct answer is copied to the file answer.txt. If the contents of answer.txt and output.txt match, the test is considered passed; otherwise, it is failed.

So you decide to write a program that works as follows. If the folder containing the program does not contain the file answer.txt (i.e. the program is running on the first test), the program outputs “YES”. Otherwise, it outputs the contents of answer.txt, which at that moment still holds the correct answer to the previous test. You are also told the sizes of the answer files before the contest.

Clearly, a file containing the answer “YES” is 3 bytes and a file containing “NO” is 2 bytes, and all orderings of the tests are equally likely. You want to calculate the expected (average) number of tests your solution will fail.
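The following small brute-force sketch in C++ makes the rule concrete. It is not part of the original solution, and `countFailures` and `bruteForce` are hypothetical names. It simulates the trick on one ordering of answers. Test k fails exactly when its answer differs from the answer of test k - 1, with an implicit “YES” before the first test. The sketch then averages over every distinct ordering. Each distinct ordering corresponds to the same number of test permutations, so this plain average equals the expected value.

```cpp
#include <algorithm>
#include <cassert>
#include <string>

// Hypothetical helper: number of failed tests for one ordering,
// where 'Y' stands for a "YES" answer file and 'N' for a "NO" one.
int countFailures(const std::string& answers) {
    char printed = 'Y';  // first test: answer.txt is absent, so print "YES"
    int failures = 0;
    for (char ans : answers) {
        if (printed != ans) ++failures;  // output.txt differs from answer.txt
        printed = ans;                   // the next run copies this answer
    }
    return failures;
}

// Hypothetical helper: average failures over every distinct ordering of
// `yes` YES answers and `no` NO answers. Every distinct ordering stands for
// yes! * no! equally likely test permutations, so the plain average is the
// expected value.
double bruteForce(int yes, int no) {
    assert(yes >= 0 && no >= 0);
    // 'N' < 'Y', so this string starts sorted and next_permutation
    // visits each distinct ordering exactly once.
    std::string answers = std::string(no, 'N') + std::string(yes, 'Y');
    long long total = 0, count = 0;
    do {
        total += countFailures(answers);
        ++count;
    } while (std::next_permutation(answers.begin(), answers.end()));
    return static_cast<double>(total) / count;
}
```

For example, `bruteForce(1, 2)` gives 2 and `bruteForce(2, 2)` gives 2.5, which match sample cases 1 and 4. The number of orderings grows combinatorially, so this approach is only useful for checking small cases.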
Input

Input starts with an integer T (≤ 10), denoting the number of test cases.

Each case consists of a line with two integers n (1 ≤ n ≤ 5000) and s (2n ≤ s ≤ 3n), where n is the number of tests and s is the total size of the answer files in bytes.
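Every answer file is either 3 bytes (YES) or 2 bytes (NO), so n and s fix how many answers of each kind exist. Write m_1 for the number of YES files and m_2 for the number of NO files, which are the names used in the code. Then:

$$
m_1 + m_2 = n,\qquad 3m_1 + 2m_2 = s \;\Longrightarrow\; m_1 = s - 2n,\quad m_2 = 3n - s.
$$

The constraint 2n ≤ s ≤ 3n is exactly what keeps both counts between 0 and n.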
Output

For each case, print the case number and the average number of tests your solution won’t pass. Errors less than 10^-6 will be ignored.
Sample Input

4
3 7
1 2
1 3
4 10

Output for Sample Input

Case 1: 2
Case 2: 1
Case 3: 0
Case 4: 2.5000000000
Note

For the first case, one of the three answers is “YES” and two answers are “NO”. If the order of tests is “YES-NO-NO”, then your solution won’t pass the second test only; if the order is “NO-YES-NO”, then it will pass none of the tests; if the order is “NO-NO-YES”, the solution won’t pass the first and the third tests.

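My first design used the state dp[a][b][flag]. It stores the expected number of failures given that a YES and b NO answers have already been placed and the previous answer was YES or NO (flag), and I computed it with memoized recursion (the commented-out dfs below). With a and b both up to 5000, that table holds about 5001 × 5001 × 2 doubles, roughly 400 MB, so it exceeded the memory limit.
So I reworked the state slightly, switched to a bottom-up recurrence, and used a rolling array:
dp[i][j][flag] is the expected number of failures among tests i through n - 1. It assumes that j of the first i answers are YES (so i - j are NO) and that the previous answer, which is what the program prints on test i, was YES (flag = 0) or NO (flag = 1). The answer is dp[0][0][0], because the first run prints YES. Row i depends only on row i + 1, so only two rows are kept, indexed by i mod 2.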
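The transition implemented in the code below can be written out as follows. Here p_1 and p_2 are the probabilities that test i's answer is YES or NO, given the answers already placed. The +1 terms count a failure whenever test i's answer differs from the previous one:

$$
\begin{aligned}
p_1 &= \frac{m_1 - j}{n - i}, \qquad p_2 = \frac{m_2 - (i - j)}{n - i},\\
\mathrm{dp}[i][j][0] &= p_1\,\mathrm{dp}[i+1][j+1][0] + p_2\,\bigl(\mathrm{dp}[i+1][j][1] + 1\bigr),\\
\mathrm{dp}[i][j][1] &= p_1\,\bigl(\mathrm{dp}[i+1][j+1][0] + 1\bigr) + p_2\,\mathrm{dp}[i+1][j][1],\\
\mathrm{dp}[n][m_1][0] &= \mathrm{dp}[n][m_1][1] = 0, \qquad \text{answer} = \mathrm{dp}[0][0][0].
\end{aligned}
$$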

/*************************************************************************
    > File Name: I.cpp
    > Author: ALex
    > Created Time: Mon 18 May 2015 14:45:18
 ************************************************************************/

#include <functional>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <queue>
#include <stack>
#include <map>
#include <bitset>
#include <set>
#include <vector>

using namespace std;

const double pi = acos(-1.0);
const int inf = 0x3f3f3f3f;
const double eps = 1e-15;
typedef long long LL;
typedef pair <int, int> PLL;

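// dp[i % 2][j][flag]: expected failures among tests i..n-1, given j YES answers
// among tests 0..i-1 and a previous answer of YES (flag 0) or NO (flag 1)
double dp[2][5010][2];
int m1, m2;  // m1 = number of YES answers, m2 = number of NO answers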

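// Original memoized version (exceeded the memory limit). It needs its own table
// dp[a][b][flag] with a <= m1, b <= m2, every entry initialized to -1.0.
//double dfs(int a, int b, int flag) {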
// double &res = dp[a][b][flag];
// if (res != -1.0) {
// return res;
// }
// if (a + b == m1 + m2) {
// return res = 0;
// }
// res = 0;
// double p1 = (m1 - a) * 1.0 / (m1 + m2 - a - b);
// double p2 = (m2 - b) * 1.0 / (m1 + m2 - a - b);
// if (!flag) {
// res = p1 * dfs(a + 1, b, 0) + p2 * (dfs(a, b + 1, 1) + 1);
// }
// else {
// res = p1 * (dfs(a + 1, b, 0) + 1) + p2 * dfs(a, b + 1, 1);
// }
// return res;
//}

int main() {
    int t, icase = 1;
    scanf("%d", &t);
    while (t--) {
        int n, s;
        scanf("%d%d", &n, &s);
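        m1 = s - 2 * n;  // number of YES answers (3 bytes each)
        m2 = 3 * n - s;  // number of NO answers (2 bytes each)
        // base case: after all n tests (so j == m1), no failures remain
        dp[n % 2][m1][0] = dp[n % 2][m1][1] = 0;
        for (int i = n - 1; i >= 0; --i) {
            // j = YES answers among tests 0..i-1, so i - j of them are NO
            for (int j = min(m1, i); j >= 0 && i - j <= m2; --j) {
                // probability that test i's answer is YES (p1) or NO (p2)
                double p1 = (m1 - j) * 1.0 / (n - i);
                double p2 = (m2 - (i - j)) * 1.0 / (n - i);
                if (j + 1 <= m1) {
                    // flag 0: the program prints YES, so test i fails iff its answer is NO
                    dp[i % 2][j][0] = dp[(i + 1) % 2][j + 1][0] * p1 + (dp[(i + 1) % 2][j][1] + 1) * p2;
                    // flag 1: the program prints NO, so test i fails iff its answer is YES
                    dp[i % 2][j][1] = (dp[(i + 1) % 2][j + 1][0] + 1) * p1 + dp[(i + 1) % 2][j][1] * p2;
                }
                else {
                    // every YES is already placed (p1 == 0), so test i's answer is NO
                    dp[i % 2][j][0] = (dp[(i + 1) % 2][j][1] + 1) * p2;
                    dp[i % 2][j][1] = dp[(i + 1) % 2][j][1] * p2;
                }
            }
        }
        // start state: no tests placed, and the first run prints YES (flag 0)
        printf("Case %d: %.12f\n", icase++, dp[0][0][0]);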
    }
    return 0;
}
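The DP can be cross-checked with a closed form from linearity of expectation. This is not the method used above. Test 1 fails exactly when its answer is NO, with probability m2 / n. Each of the n - 1 later tests fails when its answer differs from the previous one. For a uniformly random ordering, that happens with probability 2 · m1 · m2 / (n(n - 1)). Summing gives E = m2 / n + 2 · m1 · m2 / n = m2 (1 + 2 m1) / n, which reproduces all four sample outputs. A minimal sketch follows; `expectedFailures` is a hypothetical name:

```cpp
#include <cassert>

// Hypothetical closed-form cross-check via linearity of expectation,
// not the DP from the post. yes/no correspond to the post's m1/m2.
double expectedFailures(int n, int s) {
    assert(n >= 1 && 2 * n <= s && s <= 3 * n);
    long long yes = s - 2 * n;  // number of "YES" answer files
    long long no = 3 * n - s;   // number of "NO" answer files
    // First test fails iff its answer is NO: contributes no / n.
    // Each of the n - 1 adjacent pairs differs with probability
    // 2 * yes * no / (n * (n - 1)), contributing 2 * yes * no / n in total.
    return static_cast<double>(no * (1 + 2 * yes)) / n;
}
```

It costs O(1) per case, so it is a convenient way to spot-check the DP for larger n. The two should agree within the 10^-6 tolerance.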
