【gdsoi2018 day3】基地

题目大意:

一看就会但是不敢打的题…

题解:

这题让人非常难受,因为很难找到一种简便的打法,我的做法用了2800bytes,你要知道这是一道数学题。

我的做法有两个核心的函数:

1.求一个点的子树大小。
2.求一个点的子树内的点到它的代价。

注意这两个函数不用考虑删除,就是原来的树,删除的影响可以直接用hash表搞掉。

首先我们设n到1的路径上的点为特殊点,特殊点的子树大小需要特殊算,这个自下而上走一遍就可以算出。

对于一个点,要算它的子树大小,先求出它的层数,找到它这一层的特殊点。

它在特殊点的左边或右边,它的子树是一棵满二叉树,深度取决于在左边还是右边。
它是特殊点,之前已经算过,直接返回即可。

求层数用cmath库里的log2函数,这个东西是O(1)的。

求代价的话,可以分层讨论每一层点的贡献:
想象一个递归dg(x,y),设求z的子树大小,一开始是dg(z*2,z*2+1),dg(x*2,y*2+1)

x,y表示的是当前这一层的边界的两个点,每一层的点的序号一定是一个等差数列。

设c为这一层的特殊点,

于是分类讨论:
1.x<=c<=y
2.c < < x
3.c>y

第二、三种情况每个点的子树大小是一样的,于是套个自然数幂和公式就可以O(1)算出。

同理,第一种情况分三段算就可以了。

至此我们搞定了两个关键函数,接下来,考虑删除。

一个删除,其实会对它到根节点的路径上的点产生影响,由于个数少,用个hash表存一下就好了。

存的是当前它这个子树少了多少个点,到它这里少了多少代价。

最后一步找答案也很简单:
一开始在根节点,可以直接算出设在根节点的代价,
由于序号要小,优先考虑左节点。

要往一个子节点走的条件是:
子节点的子树个数大于总节点个数/2

一直走,直到不优了,走过一条边要减去多余的代价。

注意mo运算,InFleaKing就是被这个坑成了50分,不然这个比较水的题当时也不会只有一个人AC了。

#include
#include
#include
#define ll long long
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
using namespace std;

const ll mo = 998244353;

const ll n2 = 499122177, n6 = 166374059;

const int M = 11021051;

ll n, m, a, b, c, x, w, a2[65], ss[65];

struct HA {
    ll x, a, b, de;
} h[M], d[65];

int fid(ll n) {return (int) log2(n) + 1;}

void Bl() {
    w = fid(n); d[w].x = n; d[w].a = 1;
    fd(i, w, 2) {
        ll x = n / a2[w - i], y = x / 2;
        d[i - 1].x = y, d[i - 1].a = d[i].a + 1 + ss[w - i + (x & 1)];
    }
}
ll siz(ll x) {
    if(x > n) return 0;
    ll p = fid(x);
    if(x == d[p].x) return d[p].a;
    return ss[w - p + (x < d[p].x)];
}
ll calc(ll x) {x %= mo; return (x * (x + 1) % mo * (2 * x + 1) % mo * n6 % mo * a % mo + x * (x + 1) % mo * n2 % mo * b % mo + x * c % mo) % mo;}
ll ca(ll x) {x %= mo; return (x * x % mo * a % mo + x * b % mo + c) % mo;}
ll sv(ll x)  {
    ll ans = 0;
    x <<= 1; ll y = x + 1, p = fid(x); 
    while(x <= n) {
        y = y < n ? y : n;
        if(d[p].x >= x && d[p].x <= y)
            ans += (calc(d[p].x - 1) - calc(x - 1)) * (siz(x) % mo) % mo + (calc(y) - calc(d[p].x)) * (siz(y) % mo) % mo + ca(d[p].x) * (d[p].a % mo)% mo;
            else ans += (calc(y) - calc(x - 1)) * (siz(x) % mo) % mo;
        x <<= 1; y = y * 2 + 1; p ++;
    }
    return (ans % mo + mo) % mo;
}


ll hash(ll n) {
    ll y = n % M;
    while(h[y].x != 0 && h[y].x != n) y = (y + 1) % M;
    return y;
}

void det(ll n) {
    ll x = n;
    for(; x; x /= 2) if(h[hash(x)].de) return;
    ll y = hash(n);
    h[y].x = n; h[y].de = 1;
    ll s1 = siz(n) - h[y].a, ji = (sv(n) - h[y].b) % mo;
    h[y].a += s1, h[y].b = (h[y].b + ji) % mo;
    for(x = n; x > 1; x /= 2) {
        ll fa = x / 2, yy = hash(fa);
        h[yy].x = fa; ji = (ji + ca(x) * (s1 % mo)) % mo;
        h[yy].a += s1; h[yy].b = (h[yy].b + ji) % mo;
    }
}

void find() {
    x = 1;
    ll y = hash(1), ans = sv(1) - h[y].b, sum = n - h[y].a;
    while(1) {
        ll u = x * 2, v = u + 1, t1 = hash(u), t2 = hash(v);
        if((u > n || h[t1].de) && (v > n || h[t2].de)) break;
        ll s1 = siz(u) - h[t1].a, s2 = siz(v) - h[t2].a;
        if((v > n || h[t2].de || s1 >= s2) && s1 > sum - s1) {
            ans -= (s1 * 2 - sum) % mo * ca(u) % mo;
            x = u; continue;
        }
        if((u > n || h[t1].de || s1 < s2) && s2 > sum - s2) {
            ans -= (s2 * 2 - sum) % mo * ca(v) % mo;
            x = v; continue;
        }
        break;
    }
    printf("%lld %lld\n", x, (ans % mo + mo) % mo);
}

int main() {
    freopen("base.in", "r", stdin);
    freopen("base.out", "w", stdout);
    a2[0] = 1; fo(i, 1, 62) a2[i] = a2[i - 1] * 2, ss[i] = a2[i] - 1;
    scanf("%lld %lld %lld %lld %lld", &n, &m, &a, &b, &c);
    Bl();
    find();
    fo(ii, 1, m) {
        scanf("%lld", &x);
        det(x);
        find();
    }
}

你可能感兴趣的:(杂题)