bzoj3572_虚树的构建+lca

世界树......冰封王座肯定不是临时议事处。

网上神犇的题解都说要用虚树, 我试了各种办法都没搜到跟虚树有关系的东西QAQ。 终于, 在贴吧大神和DG的帮助下搞懂了这道题。 虽然神犇们不说什么是虚树, 我在这里介绍这一题中的应用。(当然, 有些大神其实根本不知道什么是虚树也乱发题解了)

对于每次询问的m个点, 在原树上将它们连接起来形成一个子图, 同时把这m个点的lca加入子图, 对非询问点而度数为2的点只要缩掉, 就得到了与当前询问对应的一棵虚树。(个人感觉不用在意这个名字, 只是重新构造了一个图而已)

这时我们会发现对于虚树中的每一条边, 其两端的节点都有自己的归属(这是句废话), 那么讨论这条链上的点的归属情况, 最后统计即可得到答案。

建虚树的过程请看代码......反正我当时也是看代码总结的, 总之就是用栈维护一条最右链, 每次用还未加入的元素以及该元素和栈顶的lca更新就行了。

ps: 见到奇奇怪怪的变量名称不要被吓到, 其实是有意义的。

#include 
#include 
#include 
#include 
#define N 300000 + 10
#define H 20
#define INF 1000000000

using namespace std;

struct node
{
    int dist, bel;
    node() { }
    node(int x, int y)
    {
        dist = x;
        bel = y;
    }
}g[N];
struct edge
{
    int to, next;
}e[2*N];
int n, m, q, num, top, ind, cont, p[N], flag[N];
int fa[N][H], fa_v[N], r[N], d[N], h[N], s[N], st[N];
int sum[N], dfn[N], ans[N], val[N], delta[N];
inline bool cmp(int x, int y)
{ return dfn[x] < dfn[y]; }
node min(node a, node b)
{
    if (a.dist == b.dist) return a.bel < b.bel ? a : b;
    return a.dist < b.dist ? a : b;
}
void read(int &x)
{
    x = 0;
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9')
    {
        x = 10*x + c - '0';
        c = getchar();
    }
}
void add(int x, int y)
{
    e[++num].to = y;
    e[num].next = p[x];
    p[x] = num;
}
void init()
{
    int x, y;
    read(n);
    for (int i = 1; i < n; ++i)
    {
        read(x), read(y);
        add(x, y);
        add(y, x);
    }
}
void bfs()
{
    queueq;
    q.push(1);
    fa[1][0] = 1;
    flag[1] = 1;
    while(!q.empty())
    {
        int x = q.front();
        q.pop();
        for (int i = 1; i < H; ++i)
        fa[x][i] = fa[fa[x][i-1]][i-1];
        for (int i = p[x]; i; i = e[i].next)
        {
            int k = e[i].to;
            if (!flag[k])
            {
                d[k] = d[x] + 1;
                fa[k][0] = x;
                flag[k] = 1;
                q.push(k);
            }
        }
    }
}
void dfs(int x)
{
    sum[x] = 1;
    dfn[x] = ++ind;
    for (int i = p[x]; i; i = e[i].next)
    {
        int k = e[i].to;
        if (k != fa[x][0])
        {
            dfs(k);
            sum[x] += sum[k];
        }
    }
}
int lca(int x, int y)
{
    if (d[x] > d[y]) swap(x, y);
    int l = x, r = y;
    for (int mid = d[r] - d[l], i = 0; mid; ++i, mid >>= 1)
    if (mid & 1) r = fa[r][i];
    if (l == r) return r;
    for (int i = H - 1; i >= 0; i--)
    {
        if (fa[l][i] == fa[r][i]) continue;
        l = fa[l][i], r = fa[r][i];
    }
    return fa[r][0];
}
int find(int x, int h)
{
    for (int i = H - 1; i >= 0; i--)
    if (d[fa[x][i]] >= h) x = fa[x][i];
    return x;
}
void solve()
{
    top = cont = 0;
    read(m);
    for (int i = 1; i <= m; ++i)
    {
        read(h[i]);
        s[++cont] = r[i] = h[i];
        g[h[i]] = node(0, h[i]);
        ans[h[i]] = 0;
    }
    sort(h+1, h+m+1, cmp);
    for (int i = 1; i <= m; ++i)
    {
        int x = h[i];
        if (!top)
        {
            st[++top] = x;
            fa_v[x] = 0;
        }
        else
        {
            int anc = lca(x, st[top]);
            while(d[st[top]] > d[anc])
            {
                if (d[st[top-1]] <= d[anc])
                fa_v[st[top]] = anc;
                top--;
            }
            if (st[top] != anc)
            {
                s[++cont] = anc;
                g[anc] = node(INF, 0);
                fa_v[anc] = st[top];
                st[++top] = anc;
            }
            fa_v[x] = anc;
            st[++top] = x;
        }
    }
    sort(s+1, s+cont+1, cmp);
    for (int i = 1; i <= cont; ++i)
    {
        int x = s[i];
        val[x] = sum[x];
        if (i > 1) delta[x] = d[x] - d[fa_v[x]];
    }
    for (int i = cont; i > 1; i--)
    {
        int x = s[i];
        g[fa_v[x]] = min(g[fa_v[x]], node(g[x].dist+delta[x], g[x].bel));
    }
    for (int i = 2; i <= cont; ++i)
    {
        int x = s[i];
        g[x] = min(g[x], node(g[fa_v[x]].dist+delta[x], g[fa_v[x]].bel));
    }
    for (int i = 1; i <= cont; ++i)
    {
        int x = s[i], anc = fa_v[x];
        if (i == 1) ans[g[x].bel] += n - sum[x];
        else
        {
            int k = find(x, d[anc]+1);
            int del = sum[k] - sum[x];
            val[anc] -= sum[k];
            if (g[anc].bel == g[x].bel) ans[g[x].bel] += del;
            else
            {
                int mid = d[x] - ((g[x].dist+g[fa_v[x]].dist+delta[x])/2-g[x].dist);
                if ((g[x].dist+g[fa_v[x]].dist+delta[x]) % 2 == 0 && g[x].bel > g[fa_v[x]].bel) ++mid;
                int tmp = sum[find(x, mid)] - sum[x];
                ans[g[x].bel] += tmp;
                ans[g[fa_v[x]].bel] += del - tmp;
            }
        }
    }
    for (int i = 1; i <= cont; ++i)
    ans[g[s[i]].bel] += val[s[i]];
    for (int i = 1; i <= m; ++i)
    printf("%d ", ans[r[i]]);
    putchar('\n');
}
void deal()
{
    bfs();
    dfs(1);
    read(q);
    while(q--) solve();
}
int main()
{
    init();
    deal();
    return 0;
}


你可能感兴趣的:(bzoj,数据结构,图论)