先贴一个 prefix sum(按行前缀和)的做法。
时间复杂度:update 为 O(M)(M 为列数,需要更新该行 col 之后的所有前缀和),query 为 O(N)(N 为行数,逐行累加)。
/**
 * Mutable 2D range-sum structure backed by one prefix-sum array per row.
 *
 * <p>{@code prefixSum[r][c]} holds the sum of {@code matrix[r][0..c]}. A point
 * update rewrites the prefix sums of a single row from the updated column to
 * the end of the row; a region query adds one row segment per row in the
 * requested range.
 */
class NumMatrix {
    /** Private copy of the input values; the caller's array is never modified. */
    int[][] matrix;
    /** Per-row inclusive prefix sums of {@link #matrix}. */
    int[][] prefixSum;
    /** Row count (N) and column count (M) of the copied matrix. */
    int N, M;

    /**
     * Copies the given matrix and builds the per-row prefix sums.
     *
     * <p>A {@code null} or zero-row matrix leaves the structure empty
     * (both arrays stay {@code null}, {@code N} and {@code M} stay 0).
     *
     * @param matrix source values; every row is read up to the length of row 0
     */
    public NumMatrix(int[][] matrix) {
        // Treat null like an empty matrix instead of failing on matrix.length.
        if (matrix == null || matrix.length == 0) {
            return;
        }
        N = matrix.length;
        M = matrix[0].length;
        this.matrix = new int[N][M];
        this.prefixSum = new int[N][M];
        for (int r = 0; r < N; r++) {
            int running = 0;
            for (int c = 0; c < M; c++) {
                int value = matrix[r][c];
                this.matrix[r][c] = value;
                running += value;
                prefixSum[r][c] = running;
            }
        }
    }

    /**
     * Sets the cell at ({@code row}, {@code col}) to {@code val}.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @param val new cell value
     */
    public void update(int row, int col, int val) {
        int diff = val - matrix[row][col];
        // Every prefix sum of this row that includes column col shifts by diff.
        for (int c = col; c < M; c++) {
            prefixSum[row][c] += diff;
        }
        matrix[row][col] = val;
    }

    /**
     * Returns the sum of all cells in the rectangle with corners
     * ({@code row1}, {@code col1}) and ({@code row2}, {@code col2}), inclusive.
     *
     * @param row1 top row, inclusive
     * @param col1 left column, inclusive
     * @param row2 bottom row, inclusive
     * @param col2 right column, inclusive
     * @return sum of the cells inside the rectangle
     */
    public int sumRegion(int row1, int col1, int row2, int col2) {
        int sum = 0;
        for (int r = row1; r <= row2; r++) {
            // prefix(col2) - prefix(col1) drops column col1 itself, so add it back.
            sum += prefixSum[r][col2] - prefixSum[r][col1] + matrix[r][col1];
        }
        return sum;
    }
}
下面是 Binary Indexed Tree(树状数组)的解法。
update 和 query 的时间复杂度都是 O(logN * logM)。
比segment tree要差一些,不过短呀!!
这里很容易出的bug是在sweep的时候,每次进入内循环时 col 都要从起始值重新开始。
参考文献
https://leetcode.com/problems/range-sum-query-2d-mutable/discuss/75870/Java-2D-Binary-Indexed-Tree-Solution-clean-and-short-17ms
/**
 * Mutable 2D range-sum structure backed by a two-dimensional binary indexed
 * (Fenwick) tree.
 *
 * <p>The tree is 1-indexed: cell ({@code r}, {@code c}) of the matrix maps to
 * tree position ({@code r + 1}, {@code c + 1}).
 */
class NumMatrix {
    /** Row count (N) and column count (M) of the copied matrix. */
    int N, M;
    /** 1-indexed 2D binary indexed tree of size (N + 1) x (M + 1). */
    int[][] bitTree;
    /** Private copy of the current cell values. */
    int[][] matrix;

    /**
     * Copies the matrix and inserts every cell into the tree.
     * A null matrix, a matrix without rows, or one whose first row is empty
     * leaves the structure uninitialised.
     *
     * @param matrix source values
     */
    public NumMatrix(int[][] matrix) {
        boolean hasCells = matrix != null && matrix.length > 0 && matrix[0].length > 0;
        if (!hasCells) {
            return;
        }
        N = matrix.length;
        M = matrix[0].length;
        this.matrix = new int[N][M];
        this.bitTree = new int[N + 1][M + 1];
        for (int row = 0; row < N; row++) {
            int[] source = matrix[row];
            int[] target = this.matrix[row];
            for (int col = 0; col < M; col++) {
                target[col] = source[col];
                updateTree(target[col], row + 1, col + 1);
            }
        }
    }

    /**
     * Sets the cell at ({@code row}, {@code col}) to {@code val}.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @param val new cell value
     */
    public void update(int row, int col, int val) {
        int delta = val - matrix[row][col];
        matrix[row][col] = val;
        updateTree(delta, row + 1, col + 1);
    }

    /**
     * Returns the sum of the rectangle with inclusive corners
     * ({@code row1}, {@code col1}) and ({@code row2}, {@code col2}),
     * computed by inclusion-exclusion over four prefix rectangles.
     *
     * @return sum of the cells inside the rectangle
     */
    public int sumRegion(int row1, int col1, int row2, int col2) {
        int whole = getSum(row2 + 1, col2 + 1);
        int above = getSum(row1, col2 + 1);
        int left = getSum(row2 + 1, col1);
        int corner = getSum(row1, col1);
        return whole - above - left + corner;
    }

    /** Lowest set bit of {@code x}. */
    private static int lowBit(int x) {
        return x & (-x);
    }

    /**
     * Adds {@code diff} to every tree node covering 1-indexed position
     * ({@code i}, {@code j}).
     */
    private void updateTree(int diff, int i, int j) {
        int r = i;
        while (r <= N) {
            // The column walk restarts from j for every row visited.
            int c = j;
            while (c <= M) {
                bitTree[r][c] += diff;
                c += lowBit(c);
            }
            r += lowBit(r);
        }
    }

    /**
     * Returns the sum of the first {@code i} rows restricted to the first
     * {@code j} columns (1-indexed counts; 0 for either yields 0).
     *
     * @param i number of leading rows to include
     * @param j number of leading columns to include
     * @return prefix-rectangle sum
     */
    public int getSum(int i, int j) {
        int total = 0;
        int r = i;
        while (r > 0) {
            int c = j;
            while (c > 0) {
                total += bitTree[r][c];
                c -= lowBit(c);
            }
            r -= lowBit(r);
        }
        return total;
    }
}
用segment tree的做法
Segment tree其实不难,都有套路,建议掌握
时间复杂度:树高是 O(log4(N^2)),也就是 O(logN),所以单点 update 是 O(logN);sumRegion 在最坏情况下(例如查询一整行)需要访问 O(N) 个节点。
/**
 * Mutable 2D range-sum structure backed by a four-way segment tree.
 *
 * <p>Each node covers a rectangle of the matrix and stores the sum of its
 * cells; an internal node splits its rectangle at the middle row and middle
 * column into up to four child rectangles.
 */
class NumMatrix {
    /** Root of the tree; stays {@code null} when the input has no cells. */
    SegmentTreeNode root;

    /**
     * Builds the tree over the whole matrix. A null matrix, a matrix without
     * rows, or one whose first row is empty leaves {@link #root} null.
     *
     * @param matrix source values
     */
    public NumMatrix(int[][] matrix) {
        boolean hasCells = matrix != null && matrix.length > 0 && matrix[0].length > 0;
        if (!hasCells) {
            return;
        }
        int lastRow = matrix.length - 1;
        int lastCol = matrix[0].length - 1;
        root = buildTree(matrix, 0, 0, lastRow, lastCol);
    }

    /** Sum stored in {@code child}, or 0 when the child does not exist. */
    private static int sumOf(SegmentTreeNode child) {
        if (child == null) {
            return 0;
        }
        return child.sum;
    }

    /** Total of the sums stored in the four children of {@code node}. */
    private static int childrenSum(SegmentTreeNode node) {
        return sumOf(node.c1) + sumOf(node.c2) + sumOf(node.c3) + sumOf(node.c4);
    }

    /**
     * Recursively builds the node covering rows {@code row1..row2} and columns
     * {@code col1..col2} (inclusive).
     *
     * @return the new node, or {@code null} for an empty rectangle
     */
    private SegmentTreeNode buildTree(
            int[][] matrix, int row1, int col1, int row2, int col2) {
        if (row1 > row2 || col1 > col2) {
            return null;
        }
        boolean singleCell = row1 == row2 && col1 == col2;
        if (singleCell) {
            return new SegmentTreeNode(row1, col1, row2, col2, matrix[row1][col1]);
        }
        int splitRow = row1 + (row2 - row1) / 2;
        int splitCol = col1 + (col2 - col1) / 2;
        SegmentTreeNode parent = new SegmentTreeNode(row1, col1, row2, col2, 0);
        // Quadrants: top-left, top-right, bottom-left, bottom-right.
        parent.c1 = buildTree(matrix, row1, col1, splitRow, splitCol);
        parent.c2 = buildTree(matrix, row1, splitCol + 1, splitRow, col2);
        parent.c3 = buildTree(matrix, splitRow + 1, col1, row2, splitCol);
        parent.c4 = buildTree(matrix, splitRow + 1, splitCol + 1, row2, col2);
        parent.sum = childrenSum(parent);
        return parent;
    }

    /**
     * Writes {@code val} into the leaf for ({@code row}, {@code col}) and
     * recomputes the sums of the nodes on the way back up.
     */
    private void updateHelper(SegmentTreeNode node, int row, int col, int val) {
        if (node == null) {
            return;
        }
        boolean rowInside = node.row1 <= row && row <= node.row2;
        boolean colInside = node.col1 <= col && col <= node.col2;
        if (!rowInside || !colInside) {
            return;
        }
        boolean isTargetLeaf = node.row1 == row && node.row2 == row
                && node.col1 == col && node.col2 == col;
        if (isTargetLeaf) {
            node.sum = val;
            return;
        }
        // Children that do not contain the cell return immediately.
        updateHelper(node.c1, row, col, val);
        updateHelper(node.c2, row, col, val);
        updateHelper(node.c3, row, col, val);
        updateHelper(node.c4, row, col, val);
        node.sum = childrenSum(node);
    }

    /**
     * Returns the part of the query rectangle's sum that lies inside
     * {@code node}.
     */
    private int sumRegionHelper(SegmentTreeNode node,
            int row1, int col1, int row2, int col2) {
        if (node == null) {
            return 0;
        }
        boolean disjoint = row1 > node.row2 || row2 < node.row1
                || col1 > node.col2 || col2 < node.col1;
        if (disjoint) {
            return 0;
        }
        boolean fullyCovered = row1 <= node.row1 && node.row2 <= row2
                && col1 <= node.col1 && node.col2 <= col2;
        if (fullyCovered) {
            return node.sum;
        }
        int total = sumRegionHelper(node.c1, row1, col1, row2, col2);
        total += sumRegionHelper(node.c2, row1, col1, row2, col2);
        total += sumRegionHelper(node.c3, row1, col1, row2, col2);
        total += sumRegionHelper(node.c4, row1, col1, row2, col2);
        return total;
    }

    /**
     * Sets the cell at ({@code row}, {@code col}) to {@code val}.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @param val new cell value
     */
    public void update(int row, int col, int val) {
        updateHelper(root, row, col, val);
    }

    /**
     * Returns the sum of the rectangle with inclusive corners
     * ({@code row1}, {@code col1}) and ({@code row2}, {@code col2}).
     *
     * @return sum of the cells inside the rectangle
     */
    public int sumRegion(int row1, int col1, int row2, int col2) {
        return sumRegionHelper(root, row1, col1, row2, col2);
    }
}
/**
 * Node of a four-way segment tree: a rectangle given by two inclusive
 * corners, the sum associated with it, and up to four child nodes.
 */
class SegmentTreeNode {
    /** Inclusive corner coordinates of the covered rectangle. */
    int row1, row2, col1, col2;
    /** Sum associated with this node's rectangle. */
    int sum;
    /** Child nodes; all {@code null} until assigned by the owner. */
    SegmentTreeNode c1, c2, c3, c4;

    /**
     * Creates a node with no children.
     *
     * @param row1 first row, inclusive
     * @param col1 first column, inclusive
     * @param row2 last row, inclusive
     * @param col2 last column, inclusive
     * @param sum initial sum for the rectangle
     */
    public SegmentTreeNode(int row1, int col1, int row2, int col2, int sum) {
        this.row1 = row1;
        this.col1 = col1;
        this.row2 = row2;
        this.col2 = col2;
        this.sum = sum;
    }
}