典型的树链剖分题,只要找到所有重链的权值然后贪心找前m个的和就行了。
具体解法就是先dfs找到所有叶子节点从根走下来得到的总权值,排序之后将有重复路径的节点权值减去,这就是一个找重链的过程,对于每个节点,他的所有儿子中能够找到最大权值链的那一条就是和这个节点在同一条重链上的,其他儿子节点作为其他重链的新起点,最后结构造出了一个包含重链的树。每个节点都一定且仅在一条重链中,并且每条重链都包含一个叶子节点,所以只要找到权值最大的m条重链就是最大总权值了。
一开始并没有想到写树链剖分,用线段树写的。
前面的处理一样,找到每个叶子节点的总权值,然后以叶子节点建树。在dfs找叶子节点总权值时回溯处理出来每个节点在这个线段树中包含的区间。然后线段树求m次最大值就行了。每次query之后从该叶子节点向上搜索找到他的所有父节点对这些父节点在线段树中包含的叶子节点都进行区间更新,就是都减去这个节点的权值,保证已经更新过的节点就不更新了,所以整个更新的复杂度为O(n*logn)。找最大的m个节点的复杂度是O(m*logn)。
所以树链剖分的复杂度是O(2*n),线段树叶子节点是O(n*logn+m*logn),实际表现差不多。
树链剖分:
#include
#define INF 0x3f3f3f3f
#define MOD 1000000007
#define EPS 1e-6
#define N 112345
using namespace std;
struct node
{
long long val,id;
friend bool operator < (node a, node b)
{
return a.val > b.val;
}
}p[N];
long long n,m,res,flag,tot;
vectorzi[N];
long long fa[N],val[N],ans[N];
bool vis[N];
void init()
{
for(int i=0;i<=n;i++)zi[i].clear();
memset(vis,0,sizeof(vis));
res=0;tot=0;
}
void dfs(long long now, long long vall)
{
int num=zi[now].size();
if(num==0)
p[tot].id=now, p[tot++].val=val[now]+vall;
else
for(int i=0;i=0&&j
#include
#define INF 0x3f3f3f3f
#define MOD 1000000007
#define EPS 1e-6
#define N 112345
using namespace std;
struct node
{
long long sum,side;
}sum[N<<2],ttt;
long long n,res,flag,tot;
long long a[N],b[N],hehe[N],xixi[N];
#define root 1 , tot , 1
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
vectorzi[N];
long long fa[N];
bool vis[N];
long long val[N];
long long add[N<<2];
void pushUp(long long rt)
{
if(sum[rt<<1].sum>=sum[rt<<1|1].sum)
{
sum[rt]=sum[rt<<1];
}
else
{
sum[rt]=sum[rt<<1|1];
}
}
void pushDown(long long l,long long r,long long rt)
{
if(add[rt])
{
long long m = (l+r)>>1;
add[rt<<1] += add[rt];
add[rt<<1|1] += add[rt];
sum[rt<<1].sum += add[rt];
sum[rt<<1|1].sum += add[rt];
add[rt] = 0;
}
}
void update(long long l,long long r,long long rt,long long ql,long long qr,long long val)
{
if(l>qr||ql>r)return;
if(l>=ql&&r<=qr)
{
sum[rt].sum += val;
add[rt] += val;
return;
}
pushDown(l,r,rt);
long long m = (l+r)>>1;
if(ql<=m)update(lson,ql,qr,val);
if(qr>m)update(rson,ql,qr,val);
pushUp(rt);
}
void build(long long l,long long r,long long rt)
{
add[rt]=0;
if(l == r)
{
sum[rt].sum=hehe[res++];
sum[rt].side=xixi[res-1];
return;
}
long long m = (l+r)>>1;
build(lson);
build(rson);
pushUp(rt);
}
node query(long long l,long long r,long long rt,long long ql,long long qr)
{
if(l>qr||ql>r)
return ttt;
if(l>=ql&&r<=qr)
return sum[rt];
pushDown(l,r,rt);
long long m = l+r>>1;
node x=query(l,m,rt<<1,ql,qr);
node y=query(m+1,r,rt<<1|1,ql,qr);
return x.sum>=y.sum?x:y;
}
void init()
{
for(int i=0;i<=n;i++)zi[i].clear();
memset(vis,0,sizeof(vis));
ttt.sum=-1;
}
void dfs(int now,long long vall)
{
long long len=zi[now].size();
if(len==0)
{
hehe[tot++]=val[now]+vall;
xixi[tot-1]=now;
a[now]=b[now]=tot-1;
return ;
}
for(int i=0;i