UOJ #211. 【UER #6】逃跑

链接:

link

题解:

E×all=(aiave)2×all=a2i×all(ai)2 E × a l l = ∑ ( a i − a v e ) 2 × a l l = ∑ a i 2 × a l l − ( ∑ a i ) 2

f(i,x,y) f ( i , x , y ) 表示走了 i i 步第一次走到 (x,y) ( x , y ) 的方案数, g(i,x,y) g ( i , x , y ) 表示走了 i i 步走到 (x,y) ( x , y ) 的方案数,容斥可得

f(i,x,y)=g(i,x,y)j<if(j,x,y)×g(ij,0,0) f ( i , x , y ) = g ( i , x , y ) − ∑ j < i f ( j , x , y ) × g ( i − j , 0 , 0 )

sumi=f(i,x,y) s u m i = ∑ f ( i , x , y ) ,那么 ai=sumi(w1+w2+w3+w4)ni ∑ a i = ∑ s u m i ( w 1 + w 2 + w 3 + w 4 ) n − i

考虑计算 a2i ∑ a i 2 ,类似管道取珠的方法,这里我们计算 p(i,x,y) p ( i , x , y ) 表示第一次某个位置之后又走了 (x,y) ( x , y ) i i 这个时刻第一次到达另一个位置的方案数, q(i,x,y) q ( i , x , y ) 表示走 i i 步到达 (0,0) ( 0 , 0 ) 中途到过 (x,y) ( x , y ) 的方案数,那么

q(i,x,y)=j<ig(j,x,y)×f(ij,x,y) q ( i , x , y ) = ∑ j < i g ( j , x , y ) × f ( i − j , − x , − y )

p(i,x,y)=j<isumj×g(ij,x,y)j<isumj×q(ij,x,y) p ( i , x , y ) = ∑ j < i s u m j × g ( i − j , x , y ) − ∑ j < i s u m j × q ( i − j , − x , − y ) −
j<ip(j,x,y)×(g(ij,0,0)q(ij,x,y)) ∑ j < i p ( j , x , y ) × ( g ( i − j , 0 , 0 ) − q ( i − j , − x , − y ) )

这里减掉的第一个东西是走了 (x,y) ( x , y ) 之后到达的位置不是第一次到达的方案数(即以前到达过),比如 (0,0)(0,1)(0,0) ( 0 , 0 ) → ( 0 , 1 ) → ( 0 , 0 ) ,那么可以在 (0,0) ( 0 , 0 ) 这个位置计算有多少种方法回来。减掉的第二个东西是走了 (x,y) ( x , y ) 后不是第一次到达的方案数(即以前已经算过答案)。

代码:

#include 

using namespace std;

#define X first
#define Y second
#define mp make_pair
#define pb push_back
#define Debug(...) fprintf(stderr, __VA_ARGS__)

typedef long long LL;
typedef long double LD;
typedef unsigned int uint;
typedef pair <int, int> pii;
typedef unsigned long long uLL;

template <typename T> inline void Read(T &x) {
  char c = getchar();
  bool f = false;
  for (x = 0; !isdigit(c); c = getchar()) {
    if (c == '-') {
      f = true;
    }
  }
  for (; isdigit(c); c = getchar()) {
    x = x * 10 + c - '0';
  }
  if (f) {
    x = -x;
  }
}

template <typename T> inline bool CheckMax(T &a, const T &b) {
  return a < b ? a = b, true : false;
}

template <typename T> inline bool CheckMin(T &a, const T &b) {
  return a > b ? a = b, true : false;
}

const int N = 205;
const int mod = 998244353;
const int dx[4] = {-1, 1, 0, 0};
const int dy[4] = {0, 0, -1, 1};

int n, m, w[4], pwd[N], sum[N], f[N][N][N], g[N][N][N], p[N][N][N], q[N][N][N];

#define f(i, x, y) f[i][x + n][y + n]
#define g(i, x, y) g[i][x + n][y + n]
#define p(i, x, y) p[i][x + n][y + n]
#define q(i, x, y) q[i][x + n][y + n]

int main() {
#ifdef wxh010910
  freopen("d.in", "r", stdin);
#endif
  Read(n);
  for (int i = 0; i < 4; ++i) {
    Read(w[i]), m += w[i];
  }
  pwd[0] = 1;
  for (int i = 1; i <= n; ++i) {
    pwd[i] = 1LL * pwd[i - 1] * m % mod;
  }
  g(0, 0, 0) = 1;
  for (int i = 0; i < n; ++i) {
    for (int x = -i; x <= i; ++x) {
      for (int y = abs(x) - i; y <= i - abs(x); ++y) {
        for (int j = 0; j < 4; ++j) {
          g(i + 1, x + dx[j], y + dy[j]) = (1LL * g(i, x, y) * w[j] + g(i + 1, x + dx[j], y + dy[j])) % mod;
        }
      }
    }
  }
  for (int i = 0; i <= n; ++i) {
    for (int x = -i; x <= i; ++x) {
      for (int y = abs(x) - i; y <= i - abs(x); ++y) {
        f(i, x, y) = g(i, x, y);
        for (int j = 0; j < i; ++j) {
          f(i, x, y) = (f(i, x, y) - 1LL * f(j, x, y) * g(i - j, 0, 0) % mod + mod) % mod;
        }
        sum[i] = (sum[i] + f(i, x, y)) % mod;
      }
    }
  }
  for (int i = 0; i <= n; ++i) {
    for (int x = -i; x <= i; ++x) {
      for (int y = abs(x) - i; y <= i - abs(x); ++y) {
        for (int j = 0; j < i; ++j) {
          q(i, x, y) = (1LL * g(j, x, y) * f(i - j, -x, -y) + q(i, x, y)) % mod;
        }
      }
    }
  }
  for (int i = 0; i <= n; ++i) {
    for (int x = -i; x <= i; ++x) {
      for (int y = abs(x) - i; y <= i - abs(x); ++y) {
        for (int j = 0; j < i; ++j) {
          p(i, x, y) = (1LL * sum[j] * g(i - j, x, y) + p(i, x, y)) % mod;
          p(i, x, y) = (p(i, x, y) - 1LL * sum[j] * q(i - j, -x, -y) % mod + mod) % mod;
          p(i, x, y) = (p(i, x, y) - 1LL * p(j, x, y) * (g(i - j, 0, 0) - q(i - j, -x, -y) + mod) % mod + mod) % mod;
        }
      }
    }
  }
  int u = 0, v = 0;
  for (int i = 0; i <= n; ++i) {
    u = (1LL * sum[i] * pwd[n - i] + u) % mod;
    for (int x = -i; x <= i; ++x) {
      for (int y = abs(x) - i; y <= i - abs(x); ++y) {
        if (x || y) {
          v = (1LL * p(i, x, y) * pwd[n - i] + v) % mod;
        }
      }
    }
  }
  v = (2LL * v + u) % mod, u = 1LL * u * u % mod;
  printf("%d\n", (1LL * v * pwd[n] - u + mod) % mod);
#ifdef wxh010910
  Debug("My Time: %.3lfms\n", (double)clock() / CLOCKS_PER_SEC);
#endif
  return 0;
}

你可能感兴趣的:(UOJ #211. 【UER #6】逃跑)