Portal
这题咕了好久,因为前面一直认为线段树合并是非常玄学的,就一直没有打(虽然现在也认为它很玄学
线段树合并只对一些相似度很低的线段树进行合并,例如大范围的权值线段树就可以,具体操作是这样的:
int merge(int x,int y,int l,int r){
if(!x || !y) return x+y;
int t=++num;
if(l==r){balabala;return ;}
int mid=(l+r)/2;
ls[t]=merge(ls[x],ls[y],l,mid);
rs[t]=merge(rs[x],rs[y],mid+1,r);
update(t);
}
当然您也可以选择不新开节点,这样空间复杂度就一定有保证了。
否则的话空间复杂度和时间复杂度是相关的,这就变得很玄学。
我们来讨论一下这个东西,若有两棵线段树在一个点重复的话,那么就会耗费log n的时间来开出一条新链,否则不耗费时间。
所以说相似度要很低才行,否则时间空间都爆炸。
这道题就很好做了。
首先树上差分,然后在x,y打上一个加标记,在lca(x,y)的地方打上一个删除标记。
那么我们需要支持的事情就是单点修改,访问连续1区间的贡献之和。
这个东西维护一个前缀连续1,后缀连续1,贡献和 就可以了。
#include
#include
#include
#include
#include
using namespace std;
const int N=500010,M=9000010;
int n,m,num;
struct edge{
int y,next,c;
}s[N<<1];
int first[N],len;
int f[N][20],dep[N];
vector A[N],D[N];
int ls[M],rs[M],lmax[M],rmax[M],root[N],tot[M];
long long ans=0;
const long long mod=1e9+7;
void read(int&x){
char ch=getchar();x=0;
while(ch<'0' || ch>'9') ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
}
void ins(int x,int y,int c){
s[++len]=(edge){y,first[x],c};first[x]=len;
s[++len]=(edge){x,first[y],c};first[y]=len;
}
void dfs(int x,int fa){
for(int i=first[x];i!=0;i=s[i].next)if(s[i].y!=fa){
int y=s[i].y;
f[y][0]=x;for(int k=1;k<=19;k++) f[y][k]=f[f[y][k-1]][k-1];
dep[y]=dep[x]+1;
dfs(y,x);
}
}
int lca(int x,int y){
if(dep[y]=0;k--) if(dep[f[y][k]]>=dep[x]) y=f[y][k];
if(x==y) return x;
for(int k=19;k>=0;k--) if(f[x][k]!=f[y][k]) x=f[x][k],y=f[y][k];
return f[x][0];
}
long long gv(int x){
return 1ll*(1+x)*x/2;
}
void update(int x,int l,int r){
int mid=(l+r)/2;lmax[x]=lmax[ls[x]];rmax[x]=rmax[rs[x]];
if(lmax[ls[x]]==mid-l+1) lmax[x]+=lmax[rs[x]];
if(rmax[rs[x]]==r-mid) rmax[x]+=rmax[ls[x]];
tot[x]=(tot[ls[x]]+tot[rs[x]])%mod+(gv(rmax[ls[x]]+lmax[rs[x]])-gv(rmax[ls[x]])-gv(lmax[rs[x]]))%mod;
tot[x]%=mod;
}
int merge(int x,int y,int l,int r){
if(!x || !y) return x+y;
if(l==r){
if(lmax[x]) return x;
else return y;
}
int mid=(l+r)/2;
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
update(x,l,r);
return x;
}
void insert(int&now,int x,int l,int r){
if(now==0) now=++num;
if(l==r){
lmax[now]=rmax[now]=tot[now]=1;
return ;
}
int mid=(l+r)/2;
if(x<=mid) insert(ls[now],x,l,mid);
else insert(rs[now],x,mid+1,r);
update(now,l,r);
}
void del(int now,int x,int l,int r){
if(now==0) return ;
if(l==r){
lmax[now]=rmax[now]=tot[now]=0;
return ;
}
int mid=(l+r)/2;
if(x<=mid) del(ls[now],x,l,mid);
else del(rs[now],x,mid+1,r);
update(now,l,r);
}
void get_ans(int x,int fa,int c){
for(int i=first[x];i!=0;i=s[i].next) if(s[i].y!=fa){
int y=s[i].y;
get_ans(y,x,s[i].c);
root[x]=merge(root[x],root[y],1,m);
}
for(int i=0;i