题解 | Maze-2019牛客暑期多校训练营第二场E题

Given a maze with N rows and M columns, where bij represents the cell on the i-th row, j-th column. If bij=“1”, it’s a wall and can’t be passed. If you are on the cell bi,j, you can go to b(i+1),j, bi,(j-1), or bi,(j+1) as long as it’s not a wall.

Sometimes, a cell may be changed into a wall, or vice versa. You need to find out the number of ways to pass through the maze starting at some given cell and finishing at some given cell.

If the starting cell or finishing cell is a wall, there’s clearly no way to pass through the maze.

Note that you can’t go back to the cell you just came from.

The first line of input contains three space-separated integers N, M, Q.
Following N lines each contains M characters bij representing the maze.
Following Q lines each contains three space-separated integers qi, ai, bi.

If qi=1, the state of cell bai,bi is changed.
If qi=2, you need to find out the number of ways to start at cell b1,ai and finish at cell bN,bi.

bi,j∈ “01”
If qi=1, 1≤ai≤N and 1≤bi≤M.
If qi=2, 1≤ai,bi≤M.

For each qi=2, output one line containing an integer representing the answer modulo 10^9+7 (1000000007).

2 2 3
2 1 2
1 1 2
2 1 2


• Data Structure(Segment tree), Dynamic Programming
We can construct a DP as:
dp[i][j] = sum(dp[i-1][k] for (k <= j and b_ik=b_i{k+1}=…=b_ij=0)) + sum(dp[i-1][k] for (k > j and b_ik=b_i{k-1}=…=b_ij=0))
It can be represented as matrix multiplication from dp[i][ * ] to dp[i+1][ * ].

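Construct a segment tree over the rows; each node stores the M×M transition matrix of its row range (the product of the per-row matrices, in row order).
The answer to a query (a, b) is entry [a][b] of the product over all rows, i.e. of the root's matrix.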
Then, modify time will be O(m^3 \lg N), query time will be O(1).


// Author: Yen-Jen Wang
#pragma GCC optimize("O3")

#include <cstdio>
#include <cstring>

using namespace std;

// 64-bit signed integer used for every path count and modular product.
typedef long long ll;

// Capacity of the row dimension of the grid array.
const int MAX_N = 50007;
// Capacity of the column dimension of the grid and of each node's matrix.
const int MAX_M = 11;

// Modulus that every path count is reduced by (10^9 + 7).
const ll MOD = 1000000007LL;

/**
 * Segment-tree node covering the maze rows l..r (inclusive).
 *
 * dp[i][j] (1-based, index 0 unused) holds, modulo MOD, the number of ways
 * to enter the node's first row at column i and leave its last row at
 * column j.
 */
struct Node {
    Node *lc, *rc;        // left / right child; both stay null for a leaf
    int l, r;             // inclusive range of rows covered by this node
    ll dp[MAX_M][MAX_M];  // transition matrix for rows l..r

    Node(int _l = 0, int _r = 0) : lc(0), rc(0), l(_l), r(_r) {
        memset(dp, 0, sizeof(dp));
    }
};

// Number of maze rows, read from input.
int N;
// Number of maze columns, read from input.
int M;
// Maze cells, indexed from 1 in both dimensions. After main() has converted
// the input, a value of 1 marks an open cell and 0 marks a wall.
char mp[MAX_N][MAX_M];

/**
 * Recomputes o->dp from the two children of o.
 *
 * The result is the matrix product (left child) * (right child) over the
 * index range 1..M, with every entry reduced modulo MOD. Both children must
 * be non-null and already up to date.
 *
 * @param o internal node whose matrix is rebuilt in place
 */
void pull(Node *o) {
    const Node *a = o->lc;
    const Node *b = o->rc;

    for (int i = 1; i <= M; ++i) {
        for (int j = 1; j <= M; ++j) {
            o->dp[i][j] = 0;
        }
    }

    for (int i = 1; i <= M; ++i) {
        for (int j = 1; j <= M; ++j) {
            if (a->dp[i][j] == 0) {
                continue;  // a zero entry contributes nothing to row i
            }
            for (int k = 1; k <= M; ++k) {
                // Both factors are already below MOD, so the product fits in ll.
                const ll way = (a->dp[i][j] * b->dp[j][k]) % MOD;
                o->dp[i][k] = (o->dp[i][k] + way) % MOD;
            }
        }
    }
}

/**
 * Builds the segment tree for the maze rows l..r (inclusive) from mp.
 *
 * A leaf for row l gets dp[i][j] = 1 exactly when every cell between
 * columns i and j of that row (both ends included) is open, i.e. when one
 * can enter the row at column i, walk horizontally and leave at column j.
 * An internal node gets the product of its children's matrices via pull().
 *
 * @param l first row of the range
 * @param r last row of the range; must satisfy l <= r
 * @return newly allocated root of the subtree for rows l..r
 */
Node* build(int l, int r) {
    Node *o = new Node(l, r);
    if (l == r) {
        for (int i = 1; i <= M; i++) {
            for (int j = 1; j <= M; j++) {
                o->dp[i][j] = 0;
            }
            // Walk right from column i until the first wall.
            for (int j = i; j <= M; ++j) {
                if (mp[l][j] == 0) {
                    break;
                }
                o->dp[i][j] = 1;
            }
            // Walk left from column i until the first wall.
            for (int j = i; j >= 1; --j) {
                if (mp[l][j] == 0) {
                    break;
                }
                o->dp[i][j] = 1;
            }
        }
    }
    else {
        int m = (l + r) >> 1;
        o->lc = build(l, m);
        o->rc = build(m + 1, r);
        pull(o);  // combine the two halves into this node's matrix
    }
    return o;
}

/**
 * Refreshes the tree after row p of mp has changed.
 *
 * Descends to the leaf for row p, rebuilds that leaf's matrix from the
 * current contents of mp[p], and recomputes every ancestor on the way back
 * up with pull(). Only the child whose range contains p is visited.
 *
 * @param o root of the subtree to refresh; its range must contain p
 * @param p index of the row that changed
 */
void maintain(Node *o, int p) {
    int l = o->l, r = o->r;
    if (l == r) {
        for (int i = 1; i <= M; i++) {
            for (int j = 1; j <= M; j++) {
                o->dp[i][j] = 0;
            }
            // Walk right from column i until the first wall.
            for (int j = i; j <= M; ++j) {
                if (mp[l][j] == 0) {
                    break;
                }
                o->dp[i][j] = 1;
            }
            // Walk left from column i until the first wall.
            for (int j = i; j >= 1; --j) {
                if (mp[l][j] == 0) {
                    break;
                }
                o->dp[i][j] = 1;
            }
        }
    }
    else {
        int m = (l + r) >> 1;
        if (p <= m)
            maintain(o->lc, p);
        else
            maintain(o->rc, p);
        pull(o);  // the changed child invalidates this node's product
    }
}

int main() {
    int Q;
    scanf("%d%d%d", &N, &M, &Q); 
    for (int i = 1; i <= N; i++) {
        scanf("%s", mp[i] + 1);
        for (int j = 1; j <= M; j++){
            mp[i][j] -= '0';
            mp[i][j] = 1 - mp[i][j];
    Node *tr = build(1, N);
    while (Q--) {
        int q, a, b;
        scanf("%d%d%d", &q, &a, &b);
        if (q == 1) {
            mp[a][b] ^= 1;
            maintain(tr, a);
            printf("%lld\n", tr->dp[a][b]);
    return 0;

