【BZOJ2738】矩阵乘法【整体二分】

然而和矩阵乘法并没有什么关系。 将矩阵里的数当做添加操作,二分答案。 对于添加操作,遇到小于等于mid的数,在二维树状数组里的相应坐标加上1,这样可以查询一个矩阵里面有多少的数小于等于mid。 对于查询操作,直接在二维树状数组里查询小于等于mid的数的个数。如果个数大于k,说明第k小比mid小,反则同理。 对于查询操作还需要记录一个cur值,表示当前有多少数小于k。 
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <iostream>
#include <algorithm>
using namespace std;

const int maxn = 505, maxm = 60005 + maxn * maxn;

int n, m, ans[maxm], tmp[maxm], tr[maxn][maxn];

struct _opt {
	int tp, id, x1, y1, x2, y2, k, cur;
} q[maxm], q1[maxm], q2[maxm];

inline int iread() {
	int f = 1, x = 0; char ch = getchar();
	for(; ch < '0' || ch > '9'; ch = getchar()) f = ch == '-' ? -1 : 1;
	for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
	return f * x;
}

inline void add(int x, int y, int c) {
	for(int i = x; i <= n; i += i & -i) for(int j = y; j <= n; j += j & -j) tr[i][j] += c;
}

inline int sum(int x, int y) {
	int ans = 0;
	for(int i = x; i; i -= i & -i) for(int j = y; j; j -= j & -j) ans += tr[i][j];
	return ans;
}

void solve(int h, int t, int l, int r) {
	if(h > t) return;
	if(l == r) {
		for(int i = h; i <= t; i++) if(q[i].tp == 2) ans[q[i].id] = l;
		return;
	}

	int mid = l + r >> 1;
	for(int i = h; i <= t; i++)
		if(q[i].tp == 1 && q[i].k <= mid) add(q[i].x1, q[i].y1, 1);
		else if(q[i].tp == 2) tmp[i] = sum(q[i].x2, q[i].y2) - sum(q[i].x1 - 1, q[i].y2) - sum(q[i].x2, q[i].y1 - 1) + sum(q[i].x1 - 1, q[i].y1 - 1);
	
	for(int i = h; i <= t; i++) if(q[i].tp == 1 && q[i].k <= mid) add(q[i].x1, q[i].y1, -1);

	int h1 = 0, h2 = 0;
	for(int i = h; i <= t; i++)
		if(q[i].tp == 2) {
			if(q[i].cur + tmp[i] >= q[i].k) q1[h1++] = q[i];
			else q[i].cur += tmp[i], q2[h2++] = q[i];
		} else 
			q[i].k <= mid ? q1[h1++] = q[i] : q2[h2++] = q[i];

	for(int i = 0; i < h1; i++) q[h + i] = q1[i];
	for(int i = 0; i < h2; i++) q[h + h1 + i] = q2[i];
	solve(h, h + h1 - 1, l, mid); solve(h + h1, t, mid + 1, r);
}

int main() {
	n = iread(); m = iread();

	int __max = -1, cnt = 0;
	for(int i = 1; i <= n; i++) for(int j = 1; j <= n; j++) {
		int k = iread();
		__max = max(__max, k);
		q[++cnt] = (_opt) {1, 0, i, j, 0, 0, k, 0};
	}
	for(int i = 1; i <= m; i++) {
		int x1 = iread(), y1 = iread(), x2 = iread(), y2 = iread(), k = iread();
		q[++cnt] = (_opt) {2, i, x1, y1, x2, y2, k, 0};
	}

	solve(1, cnt, 0, __max);

	for(int i = 1; i <= m; i++) printf("%d\n", ans[i]);

	return 0;
}


你可能感兴趣的:(cdq分治,整体二分)