其实就是之前用线段树维护的东西改成用splay维护
比较麻烦的就是翻转操作,需要把翻转的位置提出来插入的另一个splay中,翻转后再插回去
具体的细节还是见代码吧。
#include
#include
#include
#include
#include
#define N 100003
#define inf 1000000000
#define LL long long
using namespace std;
int tot,cnt,sz,n,m,k,point[N],v[N],nxt[N],belong[N],deep[N],son[N],f[N];
int ch[N][3],root,rt,a[N],pos[N],fa[N],rev[N];
LL mx[N],mn[N],sum[N],val[N],tag[N],size[N];
struct data{
int x,l,r;
}c[100];
char s[20];
void add(int x,int y)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs(int x,int f1)
{
deep[x]=deep[f1]+1; size[x]=1;
for (int i=point[x];i;i=nxt[i]){
if (v[i]==f1) continue;
f[v[i]]=x;
dfs(v[i],x);
size[x]+=size[v[i]];
if (size[v[i]]>size[son[x]]) son[x]=v[i];
}
}
void dfs1(int x,int chain)
{
pos[x]=++sz; a[sz]=x; belong[x]=chain;
if (!son[x]) return;
dfs1(son[x],chain);
for (int i=point[x];i;i=nxt[i])
if (v[i]!=f[x]&&v[i]!=son[x]) dfs1(v[i],v[i]);
}
void change(int now,LL delta)
{
val[now]+=delta; mn[now]+=delta; mx[now]+=delta;
sum[now]+=size[now]*delta; tag[now]+=delta;
}
void pushdown(int now)
{
if (tag[now]) {
change(ch[now][0],tag[now]);
change(ch[now][1],tag[now]);
tag[now]=0;
}
if (rev[now]) {
swap(ch[now][1],ch[now][0]);
rev[ch[now][0]]^=1; rev[ch[now][1]]^=1;
rev[now]=0;
}
}
void update(int now)
{
int l=ch[now][0]; int r=ch[now][1]; tag[now]=0;
sum[now]=mx[now]=mn[now]=val[now]; size[now]=1;
if (l) sum[now]+=sum[l],size[now]+=size[l],
mx[now]=max(mx[now],mx[l]),mn[now]=min(mn[now],mn[l]);
if (r) sum[now]+=sum[r],size[now]+=size[r],
mx[now]=max(mx[now],mx[r]),mn[now]=min(mn[now],mn[r]);
}
int get(int x)
{
return ch[fa[x]][1]==x;
}
void rotate(int x)
{
int y=fa[x]; int z=fa[y];
pushdown(y); pushdown(x); int which=get(x);
ch[y][which]=ch[x][which^1]; fa[ch[x][which^1]]=y;
if (z) ch[z][ch[z][1]==y]=x;
fa[x]=z; fa[y]=x; ch[x][which^1]=y;
update(y); update(x);
}
void splay(int x,int tar,int &RT)
{
for (int f1;(f1=fa[x])!=tar;rotate(x))
if (fa[f1]!=tar) rotate(get(x)==get(f1)?f1:x);
if (!tar) RT=x;
}
int find(int now,int x)
{
while (true) {
pushdown(now);
if (size[ch[now][0]]>=x) now=ch[now][0];
else {
x-=size[ch[now][0]];
if (x==1) return now;
x--;
now=ch[now][1];
}
}
}
int build(int l,int r)
{
if (l>r) return 0;
if (l==r) {
++cnt; size[cnt]=1;
return cnt;
}
int mid=(l+r)/2; int now=++cnt;
ch[now][0]=build(l,mid-1); fa[ch[now][0]]=now;
ch[now][1]=build(mid+1,r); fa[ch[now][1]]=now;
update(now);
return now;
}
void init()
{
rt=++cnt; int now=++cnt;
ch[rt][1]=now; fa[now]=rt; size[now]=1;
size[rt]=2;
}
int work(int x,int y)
{
int aa=find(root,x-1); int bb=find(root,y+1);
splay(aa,0,root); splay(bb,aa,root);
return ch[ch[root][1]][0];
}
void solve(int x,int y,int z)
{
while (belong[x]!=belong[y]){
if (deep[belong[x]]int t=work(pos[belong[x]],pos[x]);
change(t,z);
x=f[belong[x]];
}
if (deep[x]>deep[y]) swap(x,y);
int t=work(pos[x],pos[y]);
change(t,z);
}
LL calc(int x,int y,int opt)
{
LL ans=(opt==3?inf:0);
while (belong[x]!=belong[y]) {
if (deep[belong[x]]int t=work(pos[belong[x]],pos[x]);
if (opt==1) ans+=sum[t];
if (opt==2) ans=max(ans,mx[t]);
if (opt==3) ans=min(ans,mn[t]);
x=f[belong[x]];
}
if (deep[x]>deep[y]) swap(x,y);
int t=work(pos[x],pos[y]);
if (opt==1) ans+=sum[t];
if (opt==2) ans=max(ans,mx[t]);
if (opt==3) ans=min(ans,mn[t]);
return ans;
}
void invert(int x,int y)
{
int top=1; tot=0; pushdown(rt);
while (belong[x]!=belong[y]) {
if (deep[belong[x]]int t=work(pos[belong[x]],pos[x]); ch[ch[root][1]][0]=0;
fa[t]=0; update(ch[root][1]); update(root);
++tot; c[tot].x=pos[belong[x]]-1;
int len=pos[x]-pos[belong[x]]+1;
c[tot].l=top+1; c[tot].r=top+len;
int aa=find(rt,top); int bb=find(rt,top+1);
splay(aa,0,rt); splay(bb,aa,rt);
ch[ch[rt][1]][0]=t; fa[t]=ch[rt][1]; rev[t]^=1;
update(ch[rt][1]); update(rt);
top=c[tot].r;
x=f[belong[x]];
}
if (deep[x]>deep[y]) swap(x,y);
int t=work(pos[x],pos[y]);
ch[ch[root][1]][0]=0;
fa[t]=0; update(ch[root][1]); update(root);
++tot; c[tot].x=pos[x]-1;
int len=pos[y]-pos[x]+1;
c[tot].l=top+1; c[tot].r=top+len;
int aa=find(rt,top); int bb=find(rt,top+1);
splay(aa,0,rt); splay(bb,aa,rt);
ch[ch[rt][1]][0]=t; fa[t]=ch[rt][1]; rev[t]^=1;
update(ch[rt][1]); update(rt);
top=c[tot].r;
rev[rt]^=1;
for (int i=tot;i>=1;i--) {
int aa=find(rt,c[i].l-1); int bb=find(rt,c[i].r+1);
splay(aa,0,rt); splay(bb,aa,rt);
int t=ch[ch[rt][1]][0];
ch[ch[rt][1]][0]=0; fa[t]=0; update(ch[rt][1]); update(rt);
aa=find(root,c[i].x); bb=find(root,c[i].x+1);
splay(aa,0,root); splay(bb,aa,root);
ch[ch[root][1]][0]=t; fa[t]=ch[root][1]; rev[t]^=1;
update(ch[root][1]); update(root);
}
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d%d",&n,&m,&k);
for (int i=1;iint x,y; scanf("%d%d",&x,&y);
add(x,y);
}
dfs(k,0);
memset(size,0,sizeof(size));
sz=1; dfs1(k,k);
root=build(1,n+2); init();
for (int i=1;i<=m;i++) {
int x,y,z; scanf("%s%d%d",s+1,&x,&y);
if (s[1]=='I'&&s[3]=='c') scanf("%d",&z),solve(x,y,z);
if (s[1]=='S') printf("%lld\n",calc(x,y,1));
if (s[1]=='M'&&s[2]=='a') printf("%lld\n",calc(x,y,2));
if (s[1]=='M'&&s[2]=='i') printf("%lld\n",calc(x,y,3));
if (s[1]=='I'&&s[3]=='v') invert(x,y);
}
}