BZOJ 1146 网络管理Network(树链剖分+BST)

题目链接:http://61.187.179.132/JudgeOnline/problem.php?id=1146

题意:给出一棵树,每个节点有一个权值。两种操作:(1) 修改某个节点的权值;(2)询问a到b路径上第K大的权值。

思路:首先,DFS两次得到树链(重新标号,一个树链上的节点的标号连续)。建立线段树,再为线段树每个节点建立一棵BST。(1)修改时,在线段树上依次向下修改,并且到达每个节点时要修改相应的BST,即删掉原来的权值,插入新的权值,复杂度(logn)^2;(2)查询:转化成区间第K小的。计算出(a,b)的最近公共祖先lca。二分答案x,在路径(a,lca)和(b,lca)上分别计算小于等于x的节点个数。比如计算(a,lca),就是顺着a所在树链依次向上,最多有logn个树链,然后在每个树链上查找,其实就是在线段树上查找,整个复杂度是(logn)^3。

 




vector<int> g[N];
int leaf[N],f[N][20],size[N],a[N],id[N],wh[N];
int visit[N],dep[N],head[N];
int T;


void dfs(int u)
{
    visit[u]=1; size[u]=1;
    int i,v;
    FOR0(i,SZ(g[u]))
    {
        v=g[u][i];
        if(visit[v]) continue;
        f[v][0]=u;
        dep[v]=dep[u]+1;
        dfs(v);
        size[u]+=size[v];
    }
    if(size[u]==1) leaf[u]=1;
}


void DFS(int u,int top)
{
    head[u]=top; id[++T]=u; wh[u]=T;
    if(leaf[u]) return;
    
    int maxSon,Max=-1;
    int i,v;
    FOR0(i,SZ(g[u]))
    {
        v=g[u][i];
        if(v==f[u][0]) continue;
        if(size[v]>Max) Max=size[v],maxSon=v;
    }
    DFS(maxSon,top);
    FOR0(i,SZ(g[u]))
    {
        v=g[u][i];
        if(v!=maxSon&&v!=f[u][0]) DFS(v,v);
    }
}


int n,m;


void init()
{
    RD(n,m);
    int i,x,y;
    FOR1(i,n) RD(a[i]);
    FOR1(i,n-1)
    {
        RD(x,y);
        g[x].pb(y);
        g[y].pb(x);
    }
    dfs(1); DFS(1,1);
    int j;
    for(i=1;(1<<i)<=n;i++)
    {
        for(j=1;j<=n;j++) f[j][i]=f[f[j][i-1]][i-1];
    }
}


int getLca(int x,int y)
{
    if(x==y) return x;
    if(dep[x]>dep[y]) swap(x,y);
    int t=dep[y]-dep[x],i;
    for(i=0;i<20;i++) if(t&(1<<i))
    {
        y=f[y][i];
    }
    if(x==y) return x;
    for(i=19;i>=0;i--)
    {
        if(f[x][i]!=f[y][i]&&f[x][i]&&f[y][i])
        {
            x=f[x][i];
            y=f[y][i];
        }
    }
    return f[x][0];
}


struct node
{
    int L,R,size,cnt,key;
};


struct Node
{
    int L,R,root;
};


node tree[N<<5];
Node segTree[N<<2];
int e;


int newNode(int key)
{
    e++;
    tree[e].key=key;
    tree[e].L=tree[e].R=0;
    tree[e].cnt=1;
    tree[e].size=1;
    return e;
}


void insert(int x,int key)
{
    int p=x;
    while(1)
    {
        tree[p].size++;
        if(tree[p].key==key)
        {
            tree[p].cnt++;
            return;
        }
        if(tree[p].key<key)
        {
            if(tree[p].R==0)
            {
                tree[p].R=newNode(key);
                return;
            }
            else p=tree[p].R;
        }
        else
        {
            if(tree[p].L==0)
            {
                tree[p].L=newNode(key);
                return;
            }
            else p=tree[p].L;
        }
    }
}


void build(int t,int L,int R)
{
    segTree[t].L=L;
    segTree[t].R=R;
    segTree[t].root=newNode(a[id[L]]);
    if(L==R) return;
    int mid=(L+R)>>1;
    build(t*2,L,mid);
    build(t*2+1,mid+1,R);
    int i;
    for(i=L+1;i<=R;i++) insert(segTree[t].root,a[id[i]]);
}




void del(int x,int key)
{
    int p=x;
    while(1)
    {
        tree[p].size--;
        if(tree[p].key==key)
        {
            tree[p].cnt--;
            return;
        }
        if(key<tree[p].key) p=tree[p].L;
        else p=tree[p].R;
    }
}




void modify(int t,int pos,int x,int y)
{
    del(segTree[t].root,x);
    insert(segTree[t].root,y);
    if(segTree[t].L==segTree[t].R) return;
    int mid=(segTree[t].L+segTree[t].R)>>1;
    if(pos<=mid) modify(t*2,pos,x,y);
    else modify(t*2+1,pos,x,y);
}


int cal(int x,int key)
{
    int ans=0,p=x;
    while(p)
    {
        if(key==tree[p].key)
        {
            ans+=tree[p].size-tree[tree[p].R].size;
            return ans;
        }
        if(key<tree[p].key) p=tree[p].L;
        else
        {
            ans+=tree[p].size-tree[tree[p].R].size;
            p=tree[p].R;
        }
    }
    return ans;
}


int query(int t,int L,int R,int key)
{
    if(segTree[t].L==L&&segTree[t].R==R)
    {
        return cal(segTree[t].root,key);
    }
    int mid=(segTree[t].L+segTree[t].R)>>1;
    if(R<=mid) return query(t*2,L,R,key);
    if(L>mid) return query(t*2+1,L,R,key);
    return query(t*2,L,mid,key)+query(t*2+1,mid+1,R,key);
}


int query(int L,int R,int key)
{
    int ans=0;
    while(head[L]!=head[R])
    {
        ans+=query(1,wh[head[L]],wh[L],key);
        L=f[head[L]][0];
    }
    ans+=query(1,wh[R],wh[L],key);
    return ans;
}


int calKth(int x,int y,int lca,int K)
{
    int low=0,high=INF,mid,ans,cnt;
    while(low<=high)
    {
        mid=(low+high)>>1;
        cnt=query(x,lca,mid)+query(y,lca,mid);
        if(a[lca]<=mid) cnt--;
        if(cnt>=K) ans=mid,high=mid-1;
        else low=mid+1;
    }
    return ans;
}


int main()
{
    init(); build(1,1,n);
    int op,x,y,lca,cnt;
    while(m--)
    {
        RD(op); RD(x,y);
        if(op==0) modify(1,wh[x],a[x],y),a[x]=y;
        else 
        {
            lca=getLca(x,y);
            cnt=dep[x]+dep[y]-dep[lca]*2+1;
            if(cnt<op) puts("invalid request!");
            else PR(calKth(x,y,lca,cnt-op+1));
        }
    }
}



你可能感兴趣的:(NetWork)