好虚....
这题的做法显然是LCT,然后就是裸的操作,对于这个乘法和加法,我们怎么做标记呢?
就是将乘法和加法合并起来算做一个标记, mul 数组和 plus 数组两个数组算做一个标记,设权值为 x 那么标记的意思就是把所有子节点的的 x 都变为:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<queue>
#include<set>
#include<map>
#include<vector>
#include<algorithm>
#include<iostream>
#define N 303333
#define R 51061
#define plus plu
using namespace std;
int sc()
{
int i=0,f=1; char c=getchar();
while(c>'9'||c<'0'){if(c=='-')f=-1; c=getchar();}
while(c>='0'&&c<='9') i=i*10+c-'0',c=getchar();
return i*f;
}
bool rev[N];
int fa[N],ch[N][2],st[N];
int n,m;
unsigned int v[N],val[N],plus[N],mul[N],size[N];
bool Root(int x){return ch[fa[x]][0]!=x&&ch[fa[x]][1]!=x;}
void push_up(int x)
{
val[x]=(val[ch[x][0]]+val[ch[x][1]]+v[x])%R;
size[x]=(size[ch[x][0]]+size[ch[x][1]]+1)%R;
}
void cal(int x,int Mul,int Plus)
{
if(!x)return;
v[x]=(v[x]*Mul+Plus)%R;
val[x]=(val[x]*Mul+Plus*size[x])%R;
plus[x]=(plus[x]*Mul+Plus)%R;
mul[x]=mul[x]*Mul%R;
}
void push_down(int x)
{
if(rev[x])
{
rev[ch[x][0]]^=1,rev[ch[x][1]]^=1;
swap(ch[x][0],ch[x][1]);rev[x]=0;
}
int l=ch[x][0],r=ch[x][1];
if(mul[x]!=1||plus[x]!=0)
cal(l,mul[x],plus[x]),cal(r,mul[x],plus[x]);
mul[x]=1;plus[x]=0;
}
void rotate(int x)
{
int y=fa[x],z=fa[y],l,r;
if(ch[y][0]==x)l=0;else l=1;r=l^1;
if(!Root(y))
if(ch[z][0]==y)ch[z][0]=x;else ch[z][1]=x;
fa[x]=z;fa[y]=x;fa[ch[x][r]]=y;
ch[y][l]=ch[x][r];ch[x][r]=y;
push_up(y),push_up(x);
}
void splay(int x)
{
int top=0; st[++top]=x;
for(int i=x;!Root(i);i=fa[i])st[++top]=fa[i];
while(top)push_down(st[top--]);
while(!Root(x))
{
int y=fa[x],z=fa[y];
if(!Root(y))
{
if(ch[z][0]==y^ch[y][0]==x)rotate(x);
else rotate(y);
}
rotate(x);
}
}
void access(int x)
{
for(int t=0;x;t=x,x=fa[x])
splay(x),ch[x][1]=t,push_up(x);
}
void make_root(int x)
{
access(x),splay(x),rev[x]^=1;
}
void link(int x,int y)
{
make_root(x),fa[x]=y;
}
void cut(int x,int y)
{
make_root(x),access(y),splay(y);
ch[y][0]=fa[x]=0; push_up(y);
}
int main()
{
n=sc(),m=sc();
for(int i=1;i<=n;i++)v[i]=val[i]=mul[i]=size[i]=1;
for(int i=1;i<n;i++)
{
int x=sc(),y=sc();
link(x,y);
}
while(m--)
{
char s[5];scanf("%s",s);
if(s[0]=='+')
{
int x=sc(),y=sc(),z=sc();
make_root(x),access(y),splay(y);
cal(y,1,z);
}
else if(s[0]=='-')
{
int x=sc(),y=sc(),a=sc(),b=sc();
cut(x,y);link(a,b);
}
else if(s[0]=='*')
{
int x=sc(),y=sc(),z=sc();
make_root(x),access(y),splay(y);
cal(y,z,0);
}
else if(s[0]=='/')
{
int x=sc(),y=sc();
make_root(x),access(y),splay(y);
printf("%d\n",val[y]);
}
}
return 0;
}