一看就会但是不敢打的题…
这题让人非常难受,因为很难找到一种简便的打法,我的做法用了2800bytes,你要知道这是一道数学题。
我的做法有两个核心的函数:
1.求一个点的子树大小。
2.求一个点的子树内的点到它的代价。
注意这两个函数不用考虑删除,就是原来的树,删除的影响可以直接用hash表搞掉。
首先我们设n到1的路径上的点为特殊点,特殊点的子树大小需要特殊算,这个自下而上走一遍就可以算出。
对于一个点,要算它的子树大小,先求出它的层数,找到它这一层的特殊点。
它在特殊点的左边或右边,它的子树是一棵满二叉树,深度取决于在左边还是右边。
它是特殊点,之前已经算过,直接返回即可。
求层数用cmath库里的log2函数,这个东西是O(1)的。
求代价的话,可以分层讨论每一层点的贡献:
想象一个递归dg(x,y),设求z的子树大小,一开始是dg(z*2,z*2+1),dg(x*2,y*2+1)
x,y表示的是当前这一层的边界的两个点,每一层的点的序号一定是一个等差数列。
设c为这一层的特殊点,
于是分类讨论:
1.x<=c<=y
2.c < < x
3.c>y
第二、三种情况每个点的子树大小是一样的,于是套个自然数幂和公式就可以O(1)算出。
同理,第一种情况分三段算就可以了。
至此我们搞定了两个关键函数,接下来,考虑删除。
一个删除,其实会对它到根节点的路径上的点产生影响,由于个数少,用个hash表存一下就好了。
存的是当前它这个子树少了多少个点,到它这里少了多少代价。
最后一步找答案也很简单:
一开始在根节点,可以直接算出设在根节点的代价,
由于序号要小,优先考虑左节点。
要往一个子节点走的条件是:
子节点的子树个数大于总节点个数/2
一直走,直到不优了,走过一条边要减去多余的代价。
注意mo运算,InFleaKing就是被这个坑成了50分,不然这个比较水的题当时也不会只有一个人AC了。
#include
#include
#include
#define ll long long
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
using namespace std;
const ll mo = 998244353;
const ll n2 = 499122177, n6 = 166374059;
const int M = 11021051;
ll n, m, a, b, c, x, w, a2[65], ss[65];
struct HA {
ll x, a, b, de;
} h[M], d[65];
int fid(ll n) {return (int) log2(n) + 1;}
void Bl() {
w = fid(n); d[w].x = n; d[w].a = 1;
fd(i, w, 2) {
ll x = n / a2[w - i], y = x / 2;
d[i - 1].x = y, d[i - 1].a = d[i].a + 1 + ss[w - i + (x & 1)];
}
}
ll siz(ll x) {
if(x > n) return 0;
ll p = fid(x);
if(x == d[p].x) return d[p].a;
return ss[w - p + (x < d[p].x)];
}
ll calc(ll x) {x %= mo; return (x * (x + 1) % mo * (2 * x + 1) % mo * n6 % mo * a % mo + x * (x + 1) % mo * n2 % mo * b % mo + x * c % mo) % mo;}
ll ca(ll x) {x %= mo; return (x * x % mo * a % mo + x * b % mo + c) % mo;}
ll sv(ll x) {
ll ans = 0;
x <<= 1; ll y = x + 1, p = fid(x);
while(x <= n) {
y = y < n ? y : n;
if(d[p].x >= x && d[p].x <= y)
ans += (calc(d[p].x - 1) - calc(x - 1)) * (siz(x) % mo) % mo + (calc(y) - calc(d[p].x)) * (siz(y) % mo) % mo + ca(d[p].x) * (d[p].a % mo)% mo;
else ans += (calc(y) - calc(x - 1)) * (siz(x) % mo) % mo;
x <<= 1; y = y * 2 + 1; p ++;
}
return (ans % mo + mo) % mo;
}
ll hash(ll n) {
ll y = n % M;
while(h[y].x != 0 && h[y].x != n) y = (y + 1) % M;
return y;
}
void det(ll n) {
ll x = n;
for(; x; x /= 2) if(h[hash(x)].de) return;
ll y = hash(n);
h[y].x = n; h[y].de = 1;
ll s1 = siz(n) - h[y].a, ji = (sv(n) - h[y].b) % mo;
h[y].a += s1, h[y].b = (h[y].b + ji) % mo;
for(x = n; x > 1; x /= 2) {
ll fa = x / 2, yy = hash(fa);
h[yy].x = fa; ji = (ji + ca(x) * (s1 % mo)) % mo;
h[yy].a += s1; h[yy].b = (h[yy].b + ji) % mo;
}
}
void find() {
x = 1;
ll y = hash(1), ans = sv(1) - h[y].b, sum = n - h[y].a;
while(1) {
ll u = x * 2, v = u + 1, t1 = hash(u), t2 = hash(v);
if((u > n || h[t1].de) && (v > n || h[t2].de)) break;
ll s1 = siz(u) - h[t1].a, s2 = siz(v) - h[t2].a;
if((v > n || h[t2].de || s1 >= s2) && s1 > sum - s1) {
ans -= (s1 * 2 - sum) % mo * ca(u) % mo;
x = u; continue;
}
if((u > n || h[t1].de || s1 < s2) && s2 > sum - s2) {
ans -= (s2 * 2 - sum) % mo * ca(v) % mo;
x = v; continue;
}
break;
}
printf("%lld %lld\n", x, (ans % mo + mo) % mo);
}
int main() {
freopen("base.in", "r", stdin);
freopen("base.out", "w", stdout);
a2[0] = 1; fo(i, 1, 62) a2[i] = a2[i - 1] * 2, ss[i] = a2[i] - 1;
scanf("%lld %lld %lld %lld %lld", &n, &m, &a, &b, &c);
Bl();
find();
fo(ii, 1, m) {
scanf("%lld", &x);
det(x);
find();
}
}