POJ 2104题目大意如下:
给定N(N <= 100000) 个整数,进行M(M <= 5000)次询问,每次询问的方式以一个三元组的形式呈现,比如(i,j,k)表示在从第i个数和第j个数之间,升序状态下的第k个数。
如果单纯的使用排序,对于5000次询问,这样花费的时间比较多。这里采用平方分割的方法,那么可以在O(n^0.5)的时间内完成。
大方向是使用另外一个N容量的数组存储对原来数组的升序情况,然后采用二分法,每次选取一个数,根据这个数在(i,j)区间的位置情况,进而选出真正的第k个数。
1.假设X是第k个数,那么一定有:
1)在询问区间内,不超过X的个数一定不少于k个
2)在询问区间内,小于X的个数一定少于k个
那么以上就是二分搜索的一个判定条件。接下来是:如果不对之前的数组进行预处理,那么久只能通过遍历一边的方式来查询所有元素,但是我们可以采用对区间进行有序化处理,这样就可以通过二分搜索高效地求出不超过X的数的个数了。
但是有没有必要对每个查询都进行一次排序,这样的话对于降低复杂度没有半点功用。所以可以考虑采用平方分割的方法进行求解。
主要步骤分为:
1)对于完全包含在查询区间内的桶,可以用二分搜索计算小于X的数的数量
2)对于在不被查询区间包含,而只是部分重合,可以使用原来的数组直接将其找出,逐个检查
代码如下:(第一次AC 1932K/12000MS,就觉得这时间有点不对,所以后来又交一遍结果TE了。。。所以还是建议用线段树最保险)
#include <iostream> #include <algorithm> #include <cstdio> #include <vector> using namespace std; const int maxn = 100002; const int maxn2 = 5002; const int B = 1000; int A[maxn], I[maxn2], J[maxn2], K[maxn2], infer[maxn]; vector<int> bucket[maxn / B]; int N, M; void solve() { for (int i = 0; i < N; i++) { bucket[i / B].push_back(A[i]); infer[i] = A[i]; } sort(infer, infer + N); for (int i = 0; i < N / B; i++) sort(bucket[i].begin(), bucket[i].end()); for (int i = 0; i < M; i++) { int left = -1, right = N - 1; int k = K[i]; //二分筛选 while (right - left > 1) { int mid = (right + left) / 2; int tl = I[i] - 1, tr = J[i], v = infer[mid], am = 0; while (tl < tr && tl % B) if (A[tl++] <= v) am++; while (tr > tl && tr % B) if (A[--tr] <= v) am++; while (tl < tr) { int j = tl / B; am += upper_bound(bucket[j].begin(), bucket[j].end(), v) - bucket[j].begin(); tl += B; } //am >= k说明当前的数比实际数要大,所以往下搜索,否则往上搜索 if (am >= k) right = mid; else left = mid; } printf("%d\n", infer[right]); } } int main() { scanf("%d %d", &N, &M); for (int i = 0; i < N; i++) scanf("%d", &A[i]); for (int i = 0; i < M; i++) scanf("%d %d %d", &I[i], &J[i], &K[i]); solve(); return 0; }
接下来使用线段树的方法,每个节点管理的是一个升序数组,该升序数组是由子数组组成的,其实感觉用线段树实现比平方分割好多了,写起来思路更流畅(虽然参看了书),事实证明也是如此,跑的结果是这样的:Accept 16400K / 6250MS 节省了一半时间:
#include <iostream> #include <algorithm> #include <cstdio> #include <cstring> #include <vector> using namespace std; const int ST_SIZE = (1 << 18) - 1; const int maxn1 = 100001; const int maxn2 = 5005; int N, M; int A[maxn1]; int I[maxn2], J[maxn2], K[maxn2]; vector<int> dat[ST_SIZE]; void init(int k, int l, int r) { if (r - 1 == l) dat[k].push_back(A[l]); else { int mid = (l + r) / 2; init(k*2 + 1, l, mid); init(k*2 + 2, mid, r); dat[k].resize(r - l); merge(dat[k*2 + 1].begin(), dat[k*2 + 1].end(), dat[k*2 + 2].begin(), dat[k*2 + 2].end(), dat[k].begin()); } } int query(int x, int k, int cl, int cr, int ql, int qr) { if (cl >= qr || cr <= ql) return 0; if (qr >= cr && cl >= ql) return upper_bound(dat[k].begin(), dat[k].end(), x) - dat[k].begin(); int mid = (cl + cr) / 2; return query(x, k*2 + 1, cl, mid, ql, qr) + query(x, k*2 + 2, mid, cr, ql, qr); } void solve() { init(0, 0, N); sort(A, A + N); for (int i = 0; i < M; i++) { int l = I[i] - 1, r = J[i], k = K[i]; int lb = -1, ub = N - 1; while (ub - lb > 1) { int mid = (ub + lb) / 2; int c = query(A[mid], 0, 0, N, l, r); if (c >= k) ub = mid; else lb = mid; } printf("%d\n", A[ub]); } } int main() { scanf("%d %d", &N, &M); for (int i = 0; i < N; i++) scanf("%d", &A[i]); for (int i = 0; i < M; i++) scanf("%d %d %d", &I[i], &J[i], &K[i]); solve(); }