题目大意:一棵n个节点的树,每个节点有一个权值val
操作1:修改点x的权值
操作2:查询与x的距离小于等于d的节点的权值和。
如果修改的话应该有很多种做法的。
首先建立重心树,对于每个点维护两棵权值线段树,一棵表示u(作为重心)的子树中到u距离为x的点的权值和,一棵表示到u的父重心距离为x的点的权值和。
那么每次查询的时候就是u的子树中距离为[0,d]的权值和+u在重心树中所有祖先子树中的答案。
第一部分直接从线段树中查,第二部分设u到祖先的距离为D,先得到祖先重心子树中所有[0,d-D]的权值和,再减去u所属的祖先的子重心子树中的答案,就是我们维护的第二棵线段树中的信息。
对于修改,每次只会影响 logn 个点的信息。
所以时间复杂度是 O(nlog2n)
#include
#include
#include
#include
#include
#define N 200003
#define inf 1000000000
using namespace std;
int n,m,tot,nxt[N],point[N],v[N],deep[N],fa[N][20],mi[20],belong[N];
int rt[N],rtc[N],root,size[N],f[N],vis[N],sum,sz,ans,val[N];
struct data{
int ls,rs,sum;
}tr[N*60],c[N*60];
void add(int x,int y)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs(int x,int father)
{
deep[x]=deep[father]+1;
for (int i=1;i<=19;i++) {
if (deep[x]-mi[i]<0) break;
fa[x][i]=fa[fa[x][i-1]][i-1];
}
for (int i=point[x];i;i=nxt[i]){
if (v[i]==father) continue;
fa[v[i]][0]=x;
dfs(v[i],x);
}
}
int lca(int x,int y)
{
if (deep[x]int k=deep[x]-deep[y];
for (int i=0;i<=19;i++)
if ((k>>i)&1) x=fa[x][i];
if (x==y) return x;
for (int i=19;i>=0;i--)
if (fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void getroot(int x,int father)
{
f[x]=0; size[x]=1;
for (int i=point[x];i;i=nxt[i]){
if (v[i]==father||vis[v[i]]) continue;
getroot(v[i],x);
size[x]+=size[v[i]];
f[x]=max(f[x],size[v[i]]);
}
f[x]=max(f[x],sum-size[x]);
if (f[x]int dis(int x,int y)
{
return deep[x]+deep[y]-2*deep[lca(x,y)];
}
void divi(int x,int father)
{
belong[x]=father; vis[x]=1;
for (int i=point[x];i;i=nxt[i]){
if (vis[v[i]]) continue;
root=0; sum=size[v[i]];
getroot(v[i],x);
divi(root,x);
}
}
void update(int now)
{
int l=tr[now].ls; int r=tr[now].rs;
tr[now].sum=tr[l].sum+tr[r].sum;
}
void insert(int &i,int l,int r,int x,int val)
{
if (!i) i=++sz,tr[i].ls=tr[i].rs=tr[i].sum=0;
if (l==r) {
tr[i].sum+=val;
return;
}
int mid=(l+r)/2;
if (x<=mid) insert(tr[i].ls,l,mid,x,val);
else insert(tr[i].rs,mid+1,r,x,val);
update(i);
}
int qjsum(int i,int l,int r,int ll,int rr)
{
if (ll>rr) return 0;
if (ll<=l&&r<=rr) return tr[i].sum;
int mid=(l+r)/2; int ans=0;
if (ll<=mid) ans+=qjsum(tr[i].ls,l,mid,ll,rr);
if (rr>mid) ans+=qjsum(tr[i].rs,mid+1,r,ll,rr);
return ans;
}
void change(int u,int v,int val)
{
int D=dis(u,v);
insert(rt[u],0,n,D,val);
if (!belong[u]) return;
int f=belong[u]; D=dis(f,v);
insert(rtc[u],0,n,D,val);
change(f,v,val);
}
void calc(int u,int son,int v,int d)
{
if (!u) return;
if (u==son) ans+=qjsum(rt[u],0,n,0,d);
else {
int D=dis(u,v);
ans+=qjsum(rt[u],0,n,0,d-D);
ans-=qjsum(rtc[son],0,n,0,d-D);
}
calc(belong[u],u,v,d);
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
mi[0]=1;
for (int i=1;i<=19;i++) mi[i]=mi[i-1]*2;
while (scanf("%d%d",&n,&m)!=EOF) {
tot=0;
memset(point,0,sizeof(point)); sz=0;
memset(vis,0,sizeof(vis));
memset(fa,0,sizeof(fa));
memset(deep,0,sizeof(deep));
memset(rt,0,sizeof(rt));
memset(rtc,0,sizeof(rtc));
for (int i=1;i<=n;i++) scanf("%d",&val[i]);
for (int i=1;iint x,y; scanf("%d%d",&x,&y);
add(x,y);
}
dfs(1,0);
sum=n; f[0]=inf; root=0;
getroot(1,0); divi(root,0);
for (int i=1;i<=n;i++) change(i,i,val[i]);
for (int i=1;i<=m;i++) {
char s[10]; int x,v1; ans=0;
scanf("%s%d%d",s+1,&x,&v1);
if (s[1]=='!') {
change(x,x,v1-val[x]);
val[x]=v1;
}
if (s[1]=='?') calc(x,x,x,v1),printf("%d\n",ans);
}
}
}