link cut tree水题
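将树链看成序列 $a_1,\dots,a_n$，答案即为 $\dfrac{\sum_{i=1}^{n} a_i \cdot i \cdot (n-i+1)}{\binom{n+1}{2}}$。分母只与链长 $n$ 有关，无需维护；分子展开为 $(n+1)\sum_i a_i\, i - \sum_i a_i\, i^2$，因此在 splay 上维护 $\sum a_i$、$\sum a_i i$、$\sum a_i i^2$ 即可。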
（维护链上信息的 link cut tree，打标记时必须立即作用到当前节点本身（更新其 sum/a1/a2 并交换儿子），标记只表示"儿子尚未更新"，否则会 WA。）
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
#define rep(i,l,r) for(int i=l;i<=r;i++)
#define per(i,r,l) for(int i=r;i>=l;i--)
#define mmt(a,v) memset(a,v,sizeof(a))
#define tra(i,u) for(int i=head[u];i;i=e[i].next)
const int N=50000+5;
typedef long long ll;

// Greatest common divisor (Euclid); gcd(a, 0) == a.
ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}

// Link-cut tree storage. Node 0 is the null sentinel: its sz/sum/a1/a2 stay 0.
//   fa[x]    : splay parent, or path-parent when x is the root of its splay tree
//   ch[x][k] : splay children; in-order = path order (left = closer to tree root)
//   w[x]     : weight of node x
//   sz[x]    : number of nodes in x's splay subtree
//   sum[x]   : sum of w over the subtree
//   a1[x]    : sum of w_i * i   over the subtree (i = 1-based in-order position)
//   a2[x]    : sum of w_i * i^2 over the subtree
// Lazy tags follow the "already applied to self" convention: when rev[x] or
// add[x] is set, x's own fields (and, for rev, its child order) are updated
// immediately; the tag only records what still has to be pushed to children.
// NOTE(review): a2 grows roughly like w * n^3; this assumes the problem's
// weight/length limits keep it within long long -- confirm against the limits.
int fa[N],ch[N][2];
ll sz[N],a1[N],a2[N],add[N],sum[N],w[N];
bool rev[N];

ll sqr(ll x){return x*x;}
// 1 + 2 + ... + x
ll sum1(ll x){return x*(x+1)/2;}
// 1^2 + 2^2 + ... + x^2
ll sum2(ll x){return x*(x+1)*(2*x+1)/6;}
// C(x, 2)
ll c2(ll x){return x*(x-1)/2;}

// Numerator for the path held in splay subtree x:
// sum_i w_i * i * (sz - i + 1) = (sz + 1) * a1 - a2.
ll calc(int x){ return a1[x]*(sz[x]+1)-a2[x]; }

// Recompute x's aggregates from its children. The right subtree's positions
// are shifted by k = sz[l] + 1, so (j + k) and (j + k)^2 are expanded.
void pushup(int x){
    int l=ch[x][0],r=ch[x][1];
    ll k=sz[l]+1;
    sz[x]=sz[l]+sz[r]+1;
    sum[x]=sum[l]+sum[r]+w[x];
    a1[x]=a1[l]+w[x]*k+a1[r]+sum[r]*k;
    a2[x]=a2[l]+w[x]*sqr(k)+a2[r]+sqr(k)*sum[r]+2*k*a1[r];
}

// Add d to every weight in x's subtree and update x's aggregates in place.
void workadd(int x,ll d){
    w[x]+=d;
    sum[x]+=d*sz[x];
    a1[x]+=d*sum1(sz[x]);
    a2[x]+=d*sum2(sz[x]);
}

// Reverse x's subtree in place: position i becomes sz + 1 - i.
// a2 must be updated before a1 because it uses the old a1.
void workrev(int x){
    swap(ch[x][0],ch[x][1]);
    a2[x]=a2[x]+sqr(sz[x]+1)*sum[x]-2*(sz[x]+1)*a1[x];
    a1[x]=sum[x]*(sz[x]+1)-a1[x];
}

// Push x's pending tags into its (non-null) children.
void pushdown(int x){
    int l=ch[x][0],r=ch[x][1];
    if(rev[x]){
        if(l)rev[l]^=1,workrev(l);
        if(r)rev[r]^=1,workrev(r);
        rev[x]=0;
    }
    if(add[x]){
        if(l)add[l]+=add[x],workadd(l,add[x]);
        if(r)add[r]+=add[x],workadd(r,add[x]);
        add[x]=0;
    }
}

// Debug helper: prints the splay subtree of x in order as "node left right".
void debug(int x){
    if(!x)return;
    pushdown(x);
    debug(ch[x][0]);
    printf("%d %d %d\n",x,ch[x][0],ch[x][1]);
    debug(ch[x][1]);
}

// True if x is the root of its splay tree (fa[x] is then only a path-parent).
bool isroot(int x){return ch[fa[x]][0]!=x&&ch[fa[x]][1]!=x;}

// Single splay rotation of x over its parent.
void rotate(int x){
    int y=fa[x],z=fa[y],l=ch[y][1]==x,r=l^1;
    if(!isroot(y))ch[z][ch[z][1]==y]=x;   // must be tested before fa[] changes
    fa[x]=z;fa[y]=x;
    if(ch[x][r])fa[ch[x][r]]=y;           // don't scribble on the null sentinel
    ch[y][l]=ch[x][r];ch[x][r]=y;
    pushup(y);pushup(x);
}

int st[N],tp;

// Splay x to the root of its splay tree, flushing tags top-down first.
void splay(int x){
    st[tp=1]=x;
    for(int i=x;!isroot(i);i=fa[i])st[++tp]=fa[i];
    while(tp)pushdown(st[tp--]);
    while(!isroot(x)){
        int y=fa[x],z=fa[y];
        if(!isroot(y)){
            // zig-zag -> rotate x twice; zig-zig -> rotate y, then x
            if(ch[y][0]==x^ch[z][0]==y)rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
}

// Make the tree-root-to-x path preferred; x ends up deepest on that path.
void access(int x){
    for(int t=0;x;t=x,x=fa[x])
        splay(x),ch[x][1]=t,pushup(x);
}

// Re-root x's tree at x.
void makeroot(int x){ access(x);splay(x);rev[x]^=1;workrev(x); }

// Return the root of x's tree. Tags are pushed down while descending
// (otherwise a stale reversal sends the walk the wrong way), and the root is
// splayed afterwards to keep the amortized bound.
int find(int x){
    access(x);splay(x);
    while(ch[x][0])pushdown(x),x=ch[x][0];
    splay(x);
    return x;
}

// Add edge x-y unless x and y are already connected.
void link(int x,int y){
    if(find(x)==find(y))return;
    makeroot(x);fa[x]=y;
}

// Expose path x..y: afterwards y is the splay root holding exactly that path,
// with x as its leftmost node.
void split(int x,int y){ makeroot(x);access(y);splay(y); }

// In-order predecessor of x inside its splay subtree. Precondition: x's tags
// are already flushed (e.g. right after splay(x)); deeper tags are pushed here.
int pre(int x){
    x=ch[x][0];
    while(ch[x][1])pushdown(x),x=ch[x][1];
    return x;
}

// Remove edge x-y if it exists.
void cut(int x,int y){
    if(find(x)!=find(y))return;
    split(x,y);
    // The path x..y is exactly {x, y} iff x is y's left child and x has no
    // right child (x is leftmost, so it has no left child either). Both
    // child arrays are valid right after splay(y).
    if(ch[y][0]!=x||ch[x][1])return;
    ch[y][0]=fa[x]=0;pushup(y);
}

// Print the expected path value on x..y as a reduced fraction, or -1 if
// x and y are not connected.
void query(int x,int y){
    if(find(x)!=find(y)){ puts("-1"); return; }
    split(x,y);
    ll a=calc(y),b=c2(sz[y]+1),c=gcd(a,b);
    printf("%lld/%lld\n",a/c,b/c);
}

// Add d to every weight on path x..y if x and y are connected.
void update(int x,int y,int d){
    if(find(x)!=find(y))return;
    split(x,y);
    add[y]+=d;workadd(y,d);
}

int main(){
    //freopen("a.in","r",stdin);
    //freopen("a.out","w",stdout);
    int n,m;scanf("%d%d",&n,&m);
    rep(i,1,n){
        scanf("%lld",&w[i]);
        sum[i]=a1[i]=a2[i]=w[i];
        sz[i]=1;
    }
    rep(i,2,n){
        int u,v;scanf("%d%d",&u,&v);
        link(u,v);
    }
    while(m--){
        int opt,u,v;
        scanf("%d%d%d",&opt,&u,&v);
        if(opt==1)cut(u,v);
        else if(opt==2)link(u,v);
        else if(opt==3){
            int d;scanf("%d",&d);
            update(u,v,d);
        }else query(u,v);
    }
    return 0;
}