【树状数组+二分,贪心】CF1227-D2. Optimal Subsequences (Hard Version)

点击跳转到题目

目录

  • 题目描述
  • 思路
    • 每次记录最优点的代码
    • 记录询问状态的代码
    • 查询pos的代码
  • 代码

题目描述

给定一个长度为n的子序列,有m次询问,每次询问给定一个 k 和 pos,表示求长度为 k 的最大子序列之和的第 pos 位是什么数字。同时每个最大子序列之和要求满足字典序最小

思路

题目给定范围 1 ≤ n , m ≤ 2 e 5 1 \leq n,m \leq 2e5 1n,m2e5,很容易想到对于询问长度为 k 的最大求和子序列时,只要贪心的把每个最大的数且位置更靠前的数取出来就行,然后每次去查询第 pos 个数字是什么再输出。因为这个插入状态是动态的,所以在查询过程中也是动态的,先要将所有查询结果存到一个优先队列里去
这个优先队列以长度 k 为优先级,同时记录查询的位置 pos 和第几个答案 id

每次记录最优点的代码

struct node {
	int num, pos;
	friend bool operator < (node a, node b) {
		if(a.num == b.num) return a.pos < b.pos;
		return a.num > b.num;
	}
}a[N];

记录询问状态的代码

struct query {
	int k, pos, id;
	friend bool operator < (query a, query b) {
		return a.k > b.k;
	}
};
priority_queue<query> q;

每次先查询长度短的,同时插入这个点的位置。这一部分用树状数组维护,这个点也是我比赛过程中一直没有想到的点,导致卡了好久。用树状数组维护可以保证每次可以有序的插入这个点的复杂度在logn,每次查询通过二分去查询所要求的点,可以将整体时间复杂度降到mlogn的复杂度。

查询pos的代码

int query(int x) {
	int res = 0;
	for(int i = x; i; i -= lowbit(i)) res += tr[i];
	return res;
}

int find(int pos) {
	int l = 1, r = n;
	int res = 0;
	while(l <= r) {
		int mid = l + r >> 1;
		if(query(mid) >= pos) {
			res = mid;
			r = mid - 1;
		} else l = mid + 1;
	}
	return res;
}

代码

struct node {
	int num, pos;
	friend bool operator < (node a, node b) {
		if(a.num == b.num) return a.pos < b.pos;
		return a.num > b.num;
	}
}a[N];
struct query {
	int k, pos, id;
	friend bool operator < (query a, query b) {
		return a.k > b.k;
	}
};
priority_queue<query> q;
int tr[N], num[N], ans[N];
int x, n;

int lowbit(int x) {
	return x & -x;
}

void update(int x, int c) {
	for(int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}

int query(int x) {
	int res = 0;
	for(int i = x; i; i -= lowbit(i)) res += tr[i];
	return res;
}

int find(int pos) {
	int l = 1, r = n;
	int res = 0;
	while(l <= r) {
		int mid = l + r >> 1;
		if(query(mid) >= pos) {
			res = mid;
			r = mid - 1;
		} else l = mid + 1;
	}
	return res;
}

void solve() {
	scanf("%d", &n);
	for(int i = 1; i <= n; i++) {
		scanf("%d", &num[i]);
		a[i] = {num[i], i};
	}
	sort(a + 1, a + 1 + n);
	int m;
	scanf("%d", &m);
	for(int i = 1; i <= m; i++) {
		int k, pos;
		scanf("%d%d", &k, &pos);
		q.push({k, pos, i});
	}
	x = 0;
	while(!q.empty()) {
		int cnt = q.top().k, pos = q.top().pos, id = q.top().id;
		q.pop();
		while(x < cnt) {
			++x;
			update(a[x].pos, 1);
		}
		int p = find(pos);
		ans[id] = num[p];
	}
	for(int i = 1; i <= m; i++) printf("%d\n", ans[i]);
}

你可能感兴趣的:(CodeForces)