题目来源于牛客竞赛:https://ac.nowcoder.com/acm/contest/discuss
题目描述:
Given a maze with $N$ rows and $M$ columns, where $b_{i,j}$ represents the cell in the $i$-th row and $j$-th column. If $b_{i,j} = 1$, it is a wall and cannot be passed. If you are on the cell $b_{i,j}$, you can go to $b_{i+1,j}$, $b_{i,j-1}$, or $b_{i,j+1}$ as long as it is not a wall.
Sometimes a cell may be changed into a wall, or vice versa. You need to find the number of ways to pass through the maze, starting at a given cell and finishing at a given cell.
If the starting cell or finishing cell is a wall, there’s clearly no way to pass through the maze.
Note that you cannot go back to the cell you just came from.
输入描述:
The first line of input contains three space-separated integers N, M, Q.
The following $N$ lines each contain $M$ characters $b_{i,j}$ representing the maze.
The following $Q$ lines each contain three space-separated integers $q_i$, $a_i$, $b_i$.
If $q_i = 1$, the state of cell $b_{a_i, b_i}$ is toggled (a wall becomes open, or vice versa).
If $q_i = 2$, you need to find the number of ways to start at cell $b_{1, a_i}$ and finish at cell $b_{N, b_i}$.
$1 \le N, Q \le 50000$
$1 \le M \le 10$
$b_{i,j} \in \{0, 1\}$
$1 \le q_i \le 2$
If $q_i = 1$: $1 \le a_i \le N$ and $1 \le b_i \le M$.
If $q_i = 2$: $1 \le a_i, b_i \le M$.
输出描述:
For each query with $q_i = 2$, output one line containing an integer: the answer modulo $10^9+7$ (1000000007).
示例1:
输入
2 2 3
00
00
2 1 2
1 1 2
2 1 2
输出
2
1
题解:
• Data Structure(Segment tree), Dynamic Programming
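We can construct a DP as:
$dp[i][j] = \sum_{k \le j,\ b_{i,k} = b_{i,k+1} = \dots = b_{i,j} = 0} dp[i-1][k] + \sum_{k > j,\ b_{i,j} = b_{i,j+1} = \dots = b_{i,k} = 0} dp[i-1][k]$
(the $k = j$ term counts entering row $i$ at column $j$ and not moving sideways).
The transition from $dp[i-1][\cdot]$ to $dp[i][\cdot]$ is linear, so each row can be represented as an $M \times M$ matrix, and moving down one row is a matrix multiplication.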
• Data Structure(Segment tree), Dynamic Programming
$dp[i][j] = \sum_{k \le j,\ b_{i,k} = \dots = b_{i,j} = 0} dp[i-1][k] + \sum_{k > j,\ b_{i,j} = \dots = b_{i,k} = 0} dp[i-1][k]$
Build a segment tree over the rows; each node stores the product of the transition matrices of the rows it covers.
The final answer is read from the root, whose matrix is the product of all $N$ row matrices.
Each modification then takes $O(M^3 \log N)$ time, and each query takes $O(1)$ time.
代码:
// Author: Yen-Jen Wang
#pragma GCC optimize("O3")
#include <cstdio>
#include <cstring>
using namespace std;
// Shared integer type and problem limits (1-based indexing, hence the slack).
typedef long long ll;          // (MOD-1)^2 < 2^63, so one product of two residues fits
const int MAX_N = 50007;       // up to 50000 rows, plus slack
const int MAX_M = 11;          // up to 10 columns, plus index 0 which is unused
const ll MOD = 1000000007LL;   // answers are reported modulo 1e9+7
// A segment-tree node covering the inclusive row range [l, r].
// dp[i][j] is the number of ways (mod MOD) to enter row l at column i and
// leave row r at column j; indices are 1-based, so index 0 is unused.
struct Node {
    Node *lc, *rc;          // children; left null for a leaf
    int l, r;               // inclusive row range of this node
    ll dp[MAX_M][MAX_M];    // transfer matrix for rows l..r

    Node(int _l = 0, int _r = 0) : l(_l), r(_r) {
        lc = rc = 0;
        memset(dp, 0, sizeof dp);
    }
};
// Grid dimensions read from the input: N rows, M columns.
int N, M;
// mp[i][j] (1-based) is the state of cell (i, j); main stores 1 for an open
// cell and 0 for a wall. Each row needs M + 2 bytes: index 0 is unused,
// indices 1..M hold the cells, and scanf("%s", mp[i] + 1) also writes a
// terminating NUL at index M + 1. With MAX_M = 11 and M = 10 a row of only
// MAX_M bytes would let that NUL land one past the end of the row, hence +1.
char mp[MAX_N][MAX_M + 1];
// Recompute o->dp as the matrix product lc->dp * rc->dp (mod MOD).
// Leaving the left child's rows at column j means entering the right child's
// rows at the same column j, so
//     o->dp[i][k] = sum_j lc->dp[i][j] * rc->dp[j][k].
// Requires both children to be non-null and already up to date. O(M^3).
void pull(Node *o) {
    const Node *a = o->lc;
    const Node *b = o->rc;
    for (int i = 1; i <= M; ++i) {
        for (int k = 1; k <= M; ++k) {
            o->dp[i][k] = 0;
        }
        for (int j = 1; j <= M; ++j) {
            const ll left = a->dp[i][j];
            if (left == 0) {
                continue;  // column j unreachable from i: the whole term row is zero
            }
            for (int k = 1; k <= M; ++k) {
                // left, b->dp[j][k] and o->dp[i][k] are all < MOD, so the sum
                // is below ~1e18 + 1e9 and fits in a signed 64-bit ll.
                o->dp[i][k] = (o->dp[i][k] + left * b->dp[j][k]) % MOD;
            }
        }
    }
}
Node* build(int l, int r) {
Node *o = new Node(l, r);
if (l == r) {
for (int i = 1; i <= M; i++) {
for (int j = 1; j <= M; j++) {
o->dp[i][j] = 0;
}
for (int j = i; j <= M; ++j) {
if (mp[l][j] == 0) {
break;
}
o->dp[i][j] = 1;
}
for (int j = i; j >= 1; --j) {
if (mp[l][j] == 0) {
break;
}
o->dp[i][j] = 1;
}
}
}
else {
int m = (l + r) >> 1;
o->lc = build(l, m);
o->rc = build(m + 1, r);
pull(o);
}
return o;
}
void maintain(Node *o, int p) {
int l = o->l, r = o->r;
if (l == r) {
for (int i = 1; i <= M; i++) {
for (int j = 1; j <= M; j++) {
o->dp[i][j] = 0;
}
for (int j = i; j <= M; ++j) {
if (mp[l][j] == 0) {
break;
}
o->dp[i][j] = 1;
}
for (int j = i; j >= 1; --j) {
if (mp[l][j] == 0) {
break;
}
o->dp[i][j] = 1;
}
}
}
else {
int m = (l + r) >> 1;
if (p <= m)
maintain(o->lc, p);
else
maintain(o->rc, p);
pull(o);
}
}
int main() {
int Q;
scanf("%d%d%d", &N, &M, &Q);
for (int i = 1; i <= N; i++) {
scanf("%s", mp[i] + 1);
for (int j = 1; j <= M; j++){
mp[i][j] -= '0';
mp[i][j] = 1 - mp[i][j];
}
}
Node *tr = build(1, N);
while (Q--) {
int q, a, b;
scanf("%d%d%d", &q, &a, &b);
if (q == 1) {
mp[a][b] ^= 1;
maintain(tr, a);
}
else
printf("%lld\n", tr->dp[a][b]);
}
return 0;
}
更多问题,更详细题解可关注牛客竞赛区,一个刷题、比赛、分享的社区。
传送门:https://ac.nowcoder.com/acm/contest/discuss