ZJUT 地下迷宫 (高斯求期望)

http://cpp.zjut.edu.cn/ShowProblem.aspx?ShowID=1423


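设 dp[i] 表示从格子 i 出发到达终点 E 的期望步数。对非终点格子有 $dp[i] = \sum_{j} \frac{1}{k_i}\, dp[j] + 1$,其中 j 取遍与 i 相邻的可走格子,$k_i$ 是这样的格子个数(这里用 $k_i$ 而不用 m,以免与地图列数 m 混淆);终点满足 $dp[E] = 0$。对每个格子列出一个方程,用高斯消元求解。重要的一点是先判断 DK 到达不了的点:从起点 bfs 预处理,把能到达的格子重新编号(离散化),只对这些格子建立方程组;若终点不可达,则期望不存在,直接输出 tragedy!。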


#include <stdio.h>
#include <iostream>
#include <map>
#include <set>
#include <list>
#include <stack>
#include <vector>
#include <math.h>
#include <string.h>
#include <queue>
#include <string>
#include <stdlib.h>
#include <algorithm>
// 64-bit integer alias. The original used `__int64`, which only exists on
// MSVC/MinGW; `long long` is the portable standard spelling.
// (LL is not used in the visible code.)
#define LL long long
#define eps 1e-9                 // tolerance for floating-point "is zero" tests
#define PI acos(-1.0)            // not used in the visible code
using namespace std;
const int INF = 0x3f3f3f3f;      // not used in the visible code
const int mod = 10000007;        // not used in the visible code

int dir[4][2] = {{-1,0},{1,0},{0,-1},{0,1}};  // 4-neighbour (row, col) offsets
int n,m;                  // grid rows / columns; the grid is 1-based
int cnt;                  // number of cells reachable from D (set by bfs)
char g[15][15];           // maze, g[i][j] for 1 <= i <= n, 1 <= j <= m
int equ,var;              // number of equations / unknowns used by Gauss()
double a[110][110];       // augmented matrix; column `var` holds the RHS
double x[110];            // solution vector written by Gauss()
int num[15][15];          // cell -> unknown index, -1 if wall or unreachable
int sx,sy,ex,ey;          // coordinates of 'D' (start) and 'E' (exit)

// A grid position (row x, column y), used as the element type of the BFS queue.
struct node
{
    int x;  // row
    int y;  // column
};

/*
 * Solves the linear system stored in the global augmented matrix `a`
 * (`equ` equations, `var` unknowns, right-hand side in column `var`) by
 * Gaussian elimination with partial pivoting. On success the unique
 * solution is written to the global array `x`.
 *
 * Returns false when the system is inconsistent, or when it does not have
 * a unique solution (rank < var). In both cases `x` must not be used.
 */
bool Gauss()
{
    int row = 0;
    int col = 0;
    while(row < equ && col < var)
    {
        // Partial pivoting: bring the row with the largest |a[.][col]| up.
        int max_r = row;
        for(int i = row+1; i < equ; i++)
            if(fabs(a[i][col]) > fabs(a[max_r][col]))
                max_r = i;
        if(max_r != row)
        {
            for(int j = col; j <= var; j++)
                swap(a[row][j],a[max_r][j]);
        }
        if(fabs(a[row][col]) < eps)
        {
            // No usable pivot in this column: it is a free variable.
            col++;
            continue;
        }
        for(int i = row+1; i < equ; i++)
        {
            if(fabs(a[i][col]) < eps) continue;
            double t = a[i][col] / a[row][col];
            a[i][col] = 0;
            for(int j = col+1; j <= var; j++)
                a[i][j] -= a[row][j]*t;
        }
        row++;
        col++;
    }
    // Rows from `row` on have all-zero coefficients; a nonzero RHS there
    // means "0 = c", i.e. the system is inconsistent.
    for(int i = row; i < equ; i++)
    {
        if(fabs(a[i][var]) > eps)
            return false;
    }
    // A skipped column means a free variable, so there is no unique solution.
    // A skip also breaks the "pivot of row i is in column i" layout that the
    // back-substitution below relies on, which previously yielded garbage.
    if(row < var)
        return false;
    // Full rank: row i has its pivot in column i, so solve bottom-up.
    for(int i = var-1; i >= 0; i--)
    {
        double t = a[i][var];
        for(int j = i+1; j < var; j++)
            t -= a[i][j]*x[j];
        x[i] = t/a[i][i];
    }
    return true;
}

void bfs()
{
    cnt = 0;
    memset(num,-1,sizeof(num));
    queue <struct node> que;
    que.push((struct node){sx,sy});
    num[sx][sy] = cnt++;
    while(!que.empty())
    {
        struct node u = que.front();
        que.pop();
        for(int d = 0; d < 4; d++)
        {
            int x = u.x + dir[d][0];
            int y = u.y + dir[d][1];
            if(x >= 1 && x <= n && y >= 1 && y <= m && g[x][y] != 'X' && num[x][y] == -1)
            {
                que.push( (struct node){x,y} );
                num[x][y] = cnt++;
            }
        }
    }
}

/*
 * Reads test cases until EOF. Each case is "n m" followed by n rows of m
 * characters: 'D' marks the start, 'E' the exit and 'X' a wall; any other
 * character is passable. From every non-exit cell the walker moves to one
 * of its passable 4-neighbours chosen uniformly at random. Prints the
 * expected number of steps from D to E with two decimals, or "tragedy!"
 * when E cannot be reached.
 */
int main()
{
    while(~scanf("%d %d",&n,&m))
    {
        // Track whether D/E were seen, so a case without them cannot reuse
        // coordinates left over from the previous case.
        bool has_start = false;
        bool has_exit = false;
        for(int i = 1; i <= n; i++)
        {
            scanf("%s",g[i]+1);
            for(int j = 1; j <= m; j++)
            {
                if(g[i][j] == 'D')
                {
                    sx = i;
                    sy = j;
                    has_start = true;
                }
                if(g[i][j] == 'E')
                {
                    ex = i;
                    ey = j;
                    has_exit = true;
                }
            }
        }
        if(!has_start || !has_exit)
        {
            printf("tragedy!\n");
            continue;
        }
        bfs();
        // If E is not reachable from D the expectation does not exist.
        // Checking this directly also covers a D enclosed by walls: its only
        // equation would degenerate to "E[D] = 1", which the solver accepts
        // and the old code printed as 1.00.
        if(num[ex][ey] == -1)
        {
            printf("tragedy!\n");
            continue;
        }

        // One unknown per reachable cell: x[t] = expected steps from cell t
        // to the exit. Column `cnt` of `a` holds the right-hand side.
        equ = var = cnt;
        memset(a,0,sizeof(a));
        memset(x,0,sizeof(x));
        for(int i = 1; i <= n; i++)
        {
            for(int j = 1; j <= m; j++)
            {
                int t = num[i][j];
                if(t == -1) continue;   // wall, or not reachable from D
                a[t][t] = 1;
                if(g[i][j] == 'E')
                {
                    a[t][cnt] = 0;      // absorbing exit: E[exit] = 0
                    continue;
                }
                // E[t] - (1/c) * sum of E[neighbour] = 1
                a[t][cnt] = 1;
                int nb[4];
                int c = 0;
                for(int d = 0; d < 4; d++)
                {
                    int ii = i + dir[d][0];
                    int jj = j + dir[d][1];
                    if(ii >= 1 && ii <= n && jj >= 1 && jj <= m && g[ii][jj] != 'X' && num[ii][jj] != -1)
                        nb[c++] = num[ii][jj];
                }
                for(int k = 0; k < c; k++)
                    a[t][nb[k]] = -1.0/c;
            }
        }

        const int start = num[sx][sy];
        if(!Gauss())
            printf("tragedy!\n");
        // NOTE(review): kept from the original. Treating an expectation of
        // exactly 1e6 as "tragedy!" looks like a leftover sentinel; confirm
        // against the problem statement before removing.
        else if(fabs(x[start]-1000000)<eps)
            printf("tragedy!\n");
        else
            printf("%.2lf\n",x[start]);
    }
    return 0;
}


你可能感兴趣的:(概率DP)