题目大意:
给你一棵包含N个节点的树,设每条边一开始的边权为0,现在有两种操作:
1)给出参数U,V,C,表示把U与V之间的路径上的边权变成C(保证C≥0)
2)给出参数U,V,C,表示把U与V之间的路径上的边权加上max(C,路径上边权最小值的相反数)。
你需要统计出每次一操作过后树中边权为0的边有多少条。
树链剖分。边权下放为点权。
然后用线段树维护区间最小值,区间最小值的个数,以及区间中0的数量。
注意覆盖和增加标记的处理
#include
#include
#include
#include
#include
#define N 200003
#define inf 1000000000
using namespace std;
int tr[N*4],delta[N*4],cover[N*4],mn[N*4],ct[N*4];
int n,m,belong[N],son[N],v[N],nxt[N],point[N],size[N],tot,fa[N],deep[N],pos[N],sz;
void addedge(int x,int y)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs(int x,int f)
{
deep[x]=deep[f]+1; size[x]=1;
for (int i=point[x];i;i=nxt[i]){
if (v[i]==f) continue;
fa[v[i]]=x;
dfs(v[i],x);
size[x]+=size[v[i]];
if (size[son[x]]x]=v[i];
}
}
void dfs1(int x,int chain)
{
pos[x]=++sz; belong[x]=chain;
if (!son[x]) return;
dfs1(son[x],chain);
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa[x]&&v[i]!=son[x])
dfs1(v[i],v[i]);
}
void update(int now)
{
ct[now]=0; tr[now]=0;
mn[now]=min(mn[now<<1],mn[now<<1|1]);
if (mn[now]==mn[now<<1]) ct[now]+=ct[now<<1];
if (mn[now]==mn[now<<1|1]) ct[now]+=ct[now<<1|1];
if (mn[now]==0) tr[now]=ct[now];
}
void build(int now,int l,int r)
{
cover[now]=-inf;
if (l==r) {
tr[now]=ct[now]=1; mn[now]=0;
return;
}
int mid=(l+r)/2;
build(now<<1,l,mid);
build(now<<1|1,mid+1,r);
update(now);
}
void change(int now,int l,int r,int val)
{
mn[now]=val; ct[now]=(r-l+1);
if (mn[now]==0) tr[now]=ct[now];
else tr[now]=0;
delta[now]=0; cover[now]=val;
}
void add(int now,int l,int r,int val)
{
mn[now]+=val;
if (mn[now]==0) tr[now]=ct[now];
else tr[now]=0;
if(cover[now]==-inf) delta[now]+=val;
else cover[now]+=val;
}
void pushdown(int now,int l,int r)
{
int mid=(l+r)/2;
if (cover[now]!=-inf) {
change(now<<1,l,mid,cover[now]);
change(now<<1|1,mid+1,r,cover[now]);
cover[now]=-inf;
}
if (delta[now]) {
add(now<<1,l,mid,delta[now]);
add(now<<1|1,mid+1,r,delta[now]);
delta[now]=0;
}
}
void qjcover(int now,int l,int r,int ll,int rr,int val)
{
if (ll>rr) return;
if (ll<=l&&r<=rr) {
change(now,l,r,val);
return;
}
int mid=(l+r)/2;
pushdown(now,l,r);
if (ll<=mid) qjcover(now<<1,l,mid,ll,rr,val);
if (rr>mid) qjcover(now<<1|1,mid+1,r,ll,rr,val);
update(now);
}
void qjadd(int now,int l,int r,int ll,int rr,int val)
{
if (ll>rr) return;
if (ll<=l&&r<=rr) {
add(now,l,r,val);
return;
}
int mid=(l+r)/2;
pushdown(now,l,r);
if (ll<=mid) qjadd(now<<1,l,mid,ll,rr,val);
if (rr>mid) qjadd(now<<1|1,mid+1,r,ll,rr,val);
update(now);
}
void solve(int x,int y,int z)
{
while (belong[x]!=belong[y]){
if (deep[belong[x]]y]]) swap(x,y);
qjcover(1,1,n,pos[belong[x]],pos[x],z);
x=fa[belong[x]];
}
if (pos[x]>pos[y]) swap(x,y);
qjcover(1,1,n,pos[x]+1,pos[y],z);
}
int qjmin(int now,int l,int r,int ll,int rr)
{
if (ll>rr) return inf;
if (ll<=l&&r<=rr) return mn[now];
int mid=(l+r)/2; int ans=inf;
pushdown(now,l,r);
if (ll<=mid) ans=min(ans,qjmin(now<<1,l,mid,ll,rr));
if (rr>mid) ans=min(ans,qjmin(now<<1|1,mid+1,r,ll,rr));
return ans;
}
int find(int x,int y)
{
int ans=inf;
while (belong[x]!=belong[y]){
if (deep[belong[x]]y]]) swap(x,y);
ans=min(ans,qjmin(1,1,n,pos[belong[x]],pos[x]));
x=fa[belong[x]];
}
if (pos[x]>pos[y]) swap(x,y);
ans=min(ans,qjmin(1,1,n,pos[x]+1,pos[y]));
return ans;
}
void solve1(int x,int y,int z)
{
while (belong[x]!=belong[y]){
if (deep[belong[x]]y]]) swap(x,y);
qjadd(1,1,n,pos[belong[x]],pos[x],z);
x=fa[belong[x]];
}
if (pos[x]>pos[y]) swap(x,y);
qjadd(1,1,n,pos[x]+1,pos[y],z);
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1;iint x,y; scanf("%d%d",&x,&y);
addedge(x,y);
}
dfs(1,0); dfs1(1,1);
// for (int i=1;i<=n;i++) cout<<pos[i]<<" "; cout<1,1,n);
for (int i=1;i<=m;i++) {
int opt,x,y,z; scanf("%d%d%d%d",&opt,&x,&y,&z);
if (opt==1) solve(x,y,z);
if (opt==2) {
int t=find(x,y);
solve1(x,y,max(z,-t));
}
printf("%d\n",tr[1]-1);
}
}