一开始有n个非负整数hi,接下来会进行m次操作,第i次操作给出一个数c[i],要求你选出c[i]个大于零的数并将它们减去1。
问最多可以进行多少轮操作后无法操作(即没有c[i]个大于零的数)
1<=n,m<=1000000
对于一个数据结构学傻的人来说,一上来就会上什么平衡树之类的东西。
每次去c个,显然取当前最大的c个会优。
但是你发现如果对最大的c个数-1,也许相对顺序会变,那怎么做呢?
假设现在是[2,2,3,3,4,4],c=3,对前三大的减1,会变成:
[2,2,3,2,4,4]顺序就变了,然而你发现会改变的一定是和第c大的值一样的,对右边减1,可以转成对左边减1,至于确定区间可以用平衡树。
当然10^6平衡树怎么说也很难过,可以差分后用单点修改线段树维护,实际上数据水爆了。
Code:
#include
#include
#include
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define min(a, b) ((a) < (b) ? (a) : (b))
#define max(a, b) ((a) > (b) ? (a) : (b))
using namespace std;
const int N = 1e6 + 50;
int n, m, c[N], h;
int t[N * 8], st[N * 8], en[N * 8];
void Build(int i, int x, int y) {
if(x == y) {
t[i] = c[x] - c[x - 1];
st[i] = en[i] = (t[i] == 0);
return;
}
int m = x + y >> 1;
Build(i + i, x, m); Build(i + i + 1, m + 1, y);
t[i] = t[i + i] + t[i + i + 1];
st[i] = st[i + i] == m - x + 1 ? st[i + i] + st[i + i + 1] : st[i] = st[i + i];
en[i] = en[i + i + 1] == y - m ? en[i + i + 1] + en[i + i] : en[i] = en[i + i + 1];
}
int find(int i, int x, int y, int r) {
if(y == r) return t[i];
int m = x + y >> 1;
if(r <= m) return find(i + i, x, m, r);
return t[i + i] + find(i + i + 1, m + 1, y, r);
}
int findst(int i, int x, int y, int l) {
if(x == l) return st[i];
int m = x + y >> 1;
if(l > m) return findst(i + i + 1, m + 1, y, l);
int v = findst(i + i, x, m, l);
return v == m - l + 1 ? st[i + i + 1] + v : v;
}
int finden(int i, int x, int y, int r) {
if(y == r) return en[i];
int m = x + y >> 1;
if(r <= m) return finden(i + i, x, m, r);
int v = finden(i + i + 1, m + 1, y, r);
return v == r - m ? en[i + i] + v : v;
}
void change(int i, int x, int y, int l, int c) {
if(x == y) {
t[i] += c;
st[i] = en[i] = (t[i] == 0);
return;
}
int m = x + y >> 1;
if(l <= m) change(i + i, x, m, l, c); else change(i + i + 1, m + 1, y, l, c);
t[i] = t[i + i] + t[i + i + 1];
st[i] = st[i + i] == m - x + 1 ? st[i + i] + st[i + i + 1] : st[i] = st[i + i];
en[i] = en[i + i + 1] == y - m ? en[i + i + 1] + en[i + i] : en[i] = en[i + i + 1];
}
void read(int &x) {
char c = ' '; for(; c < '0' || c > '9'; c= getchar());
x = 0; for(; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + c - 48;
}
int main() {
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
scanf("%d %d", &n, &m);
fo(i, 1, n) read(c[i]);
sort(c + 1, c + n + 1);
Build(1, 1, n);
fo(i, 1, m) {
read(h);
int sum = find(1, 1, n, n - h + 1);
if(sum <= 0) {
printf("%d\n", i - 1);
return 0;
}
int en0 = h == 1 ? 0 : findst(1, 1, n, n - h + 2);
int st0 = h == n ? 0 : finden(1, 1, n, n - h + 1);
change(1, 1, n, n - h + 1 - st0, -1);
if(n - h + 1 - st0 + en0 + 1 <= n) change(1, 1, n, n - h + 1 - st0 + en0 + 1, 1);
if(n - h + en0 + 2 <= n) change(1, 1, n, n - h + en0 + 2, -1);
}
printf("%d\n", m);
}