原题请看这里
国家可以表示为 n n n 个节点 n − 1 n-1 n−1 条边的图。 F ( x ) F(x) F(x) 表示节点 x x x 的疫情严重性。有以下三种修改/查询:
有多个测试用例。 输入的第一行包含一个整数 T ( 1 ≤ T ≤ 5 ) T(1 \leq T \leq 5) T(1≤T≤5),表示测试用例的数量。
对于每个测试用例,第一行包含两个整数 n , m ( 1 ≤ n , m ≤ 5 × 1 0 4 ) n,m(1 \leq n,m \leq 5 \times 10 ^ 4) n,m(1≤n,m≤5×104),代表城市的数量以及事件和查询的数量。 以下 n − 1 n-1 n−1行描述了该国家/地区的所有路径,每条路径均包含两个整数 x , y ( 1 ≤ x , y ≤ n ) x,y(1 \leq x,y \leq n) x,y(1≤x,y≤n),代表城市 x x x和 y y y之间的道路。 以下 m m m行描述了所有事件,每个事件均以整数 o p t ( 1 ≤ o p t ≤ 3 ) \mathit {opt}(1 \leq \mathit {opt} \leq 3) opt(1≤opt≤3)开始,并且如果 o p t \mathit{opt} opt为
每个查询输出一个整数。
1
5 6
1 2
1 3
2 4
2 5
1 1 5
3 4
2 1
1 2 7
3 3
3 1
3
9
6
首先,我们对每一个操作进行分析:
o p t = 1 : opt=1: opt=1:
在树上求距离可以用 l c a lca lca,于是我们把 d i s t dist dist这个函数化开,设 x , y , l c a x,y,lca x,y,lca的深度分别为 d e p [ x ] , d e p [ y ] , d e p [ l c a ] : dep[x],dep[y],dep[lca]: dep[x],dep[y],dep[lca]:
w − d i s t ( x , y ) w-dist(x,y) w−dist(x,y)
= w − ( d e p [ x ] + d e p [ y ] − 2 d e p [ l c a ] ) =w-(dep[x]+dep[y]-2dep[lca]) =w−(dep[x]+dep[y]−2dep[lca])
= w − d e p [ x ] − d e p [ y ] + 2 d e p [ l c a ] =w-dep[x]-dep[y]+2dep[lca] =w−dep[x]−dep[y]+2dep[lca]
由此我们发现:当 o p t = = 1 opt==1 opt==1时 w − d e p [ x ] w-dep[x] w−dep[x]是固定的,而对于每一个节点, d e p [ y ] dep[y] dep[y]也是固定的,所以我们设 A + = w − d e p [ x ] , B + + A+=w-dep[x],B++ A+=w−dep[x],B++
A A A表示所有关于 x x x的结果, B B B表示 d e p [ y ] dep[y] dep[y]的次数
所以对于每次操作,我们都可以查询一下之前所有的结果,即:
f ( y ) = A − B ∗ d e p [ y ] + 2 ∑ ( d e p [ l c a ] ) f(y)=A-B*dep[y]+2\sum(dep[lca]) f(y)=A−B∗dep[y]+2∑(dep[lca])
o p t = 2 : opt=2: opt=2:
对于这个操作,我们需要考虑正负,所以我们需要储存一下 m i n ( 0 , f ( y ) ) min(0,f(y)) min(0,f(y)).
所以我们开一个数组存一下每次y的消除结果:
f f y + = m i n ( 0 , f ( y ) ) − f ( y ) ff_y+=min(0,f(y))-f(y) ffy+=min(0,f(y))−f(y)
之后我们算答案只要加上 f f y ff_y ffy就行了:
f ( y ) = A − B ∗ d e p [ y ] + 2 ∑ ( d e p [ l c a ] ) + f f y f(y)=A-B*dep[y]+2\sum(dep[lca])+ff_y f(y)=A−B∗dep[y]+2∑(dep[lca])+ffy
o p t = = 3 opt==3 opt==3
输出结果即可
代码是我队友写的,好丑好丑呜呜呜呜
#include
#include
#include
#define ll long long
#define I1 i<<1
#define I2 i<<1|1
using namespace std;
const int MAXN=1e5+5;
int v[MAXN],h[MAXN],fa[MAXN],id[MAXN],ld[MAXN],rd[MAXN],siz[MAXN],dep[MAXN],son[MAXN],top[MAXN],vis[MAXN],ff[MAXN],n,m,t,x,y,op;
ll cnt,tot,A,B;
struct node{int to,next;}a[MAXN<<1];
struct Tree{int l,r,lz;ll sum;}sgt[MAXN<<2],e[MAXN<<2];
void add(int x,int y){
a[++cnt].to=y;
a[cnt].next=h[x];
h[x]=cnt;
}
void p(int x){
siz[x]=1;
son[x]=0;
for(int i=h[x];~i;i=a[i].next){
int nxt=a[i].to;
if(nxt==fa[x])continue;
fa[nxt]=x;
dep[nxt]=dep[x]+1;
p(nxt);
siz[x]+=siz[nxt];
if(siz[nxt]>siz[son[x]])
son[x]=nxt;
}
}
void dfs(int x,int y){
top[x]=y;
ld[x]=++tot;
id[tot]=x;
if(son[x]) dfs(son[x],y);
for(int i=h[x],nxt;~i;i=a[i].next){
nxt=a[i].to;
if(nxt^son[x]&&nxt^fa[x])
dfs(nxt,nxt);
}
}
void build(int i,int l,int r){
sgt[i].l=l;
sgt[i].r=r;
sgt[i].lz=0;
if(l==r){
sgt[i].sum=v[id[l]];
return;
}
int mid=(l+r)>>1;
build(I1,l,mid);
build(I2,mid+1,r);
sgt[i].sum=sgt[I1].sum+sgt[I2].sum;
}
void down(int x){
if(sgt[x].lz){
sgt[x].sum+=sgt[x].lz*(sgt[x].r-sgt[x].l+1);
sgt[x<<1].lz+=sgt[x].lz;
sgt[x<<1|1].lz+=sgt[x].lz;
sgt[x].lz=0;
}
}
ll q(int x,int l,int r){
if(sgt[x].l==l&&sgt[x].r==r) return sgt[x].sum+sgt[x].lz*(r-l+1);
down(x);
int m=sgt[x].l+sgt[x].r>>1;
if(r<=m) return q(x<<1,l,r);
else if(l>m) return q(x<<1|1,l,r);
else return q(x<<1,l,m)+q(x<<1|1,m+1,r);
}
void ud(int i,int l,int r,ll v){
if(sgt[i].l==l&&sgt[i].r==r){
sgt[i].sum+=(r-l+1)*v;
sgt[i].lz+=v;
return;
}
if(sgt[i].lz) down(i);
int mid=sgt[i].l+sgt[i].r>>1;
if(r<=mid)ud(I1,l,r,v);
else if(l>mid)ud(I2,l,r,v);
else ud(I1,l,mid,v),ud(I2,mid+1,r,v);
sgt[i].sum=sgt[I1].sum+sgt[I2].sum;
}
void my(int x,int l,int r,ll val){
if(sgt[x].l==l&&sgt[x].r==r){
sgt[x].lz+=val;
return ;
}
sgt[x].sum+=(r-l+1)*val;
int m=sgt[x].l+sgt[x].r>>1;
if(r<=m) my(x<<1,l,r,val);
else if(l>m) my(x<<1|1,l,r,val);
else my(x<<1,l,m,val),my(x<<1|1,m+1,r,val);
}
void cy(int x,int y,ll val){
while(top[x]^top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
my(1,ld[top[y]],ld[y],val);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
my(1,ld[x],ld[y],val);
}
ll q1(int x,int y){
ll ret=A-B*dep[x]+ff[x];
while(top[x]^top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
ret+=q(1,ld[top[y]],ld[y]);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
return ret+q(1,ld[x],ld[y]);
}
int main(){
for(scanf("%d",&t);t--;cnt=tot=A=B=0){
scanf("%d%d",&n,&m);
memset(h,-1,sizeof(h));
memset(vis,0,sizeof(vis));
memset(dep,0,sizeof(dep));
memset(siz,0,sizeof(siz));
memset(top,0,sizeof(top));
memset(rd,0,sizeof(rd));
memset(ld,0,sizeof(ld));
memset(ff,0,sizeof(ff));
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dep[1]=1;
p(1);
dfs(1,1);
build(1,1,tot);
while(m--){
scanf("%d",&op);
if(op==1){
scanf("%d%d",&x,&y);
cy(1,x,2ll),
A+=y-dep[x];B++;
}
else{
scanf("%d",&x);
ll val=q1(x,1);
if(op==2)ff[x]+=min(0ll,val)-val;
if(op==3)printf("%lld\n",val);
}
}
}
}