点击跳转
用 f [ u ] [ i ] f[u][i] f[u][i]表示以 u u u为根的子树中深度为 i i i的节点上的松鼠跑到节点 u u u上,还没打架的时候时有多少只
那么 u u u点的打架次数就是 ∑ [ f [ u ] [ i ] > 1 ] \sum [f[u][i]>1] ∑[f[u][i]>1]
暴力转移很好写, f [ u ] [ i ] = ∑ v m a x ( m i n ( 1 , f [ v ] [ i − 1 ] ) , f [ v ] [ i − 1 ] − 1 ) f[u][i]=\sum_{v} max(min(1,f[v][i-1]),f[v][i-1]-1) f[u][i]=∑vmax(min(1,f[v][i−1]),f[v][i−1]−1)
看到这个形式之后就直接想到长链剖分,短链往长链合并的时候直接按照上述公式暴力即可
但问题是长链上怎么弄?
长链上我需要解决:给区间中所有大于 1 1 1的数减去 1 1 1,还可以支持单点加一个正数,还可以支持查询区间中有多少个数大于 1 1 1
这个东西可以线段树搞
在每个叶子上维护真实值 v a l val val,然后整颗线段树的每个节点都维护一个 c n t cnt cnt表示这个子树中所对应区间上有多少个 v a l val val是 > 1 >1 >1的,再维护一下当前所有 v a l > 1 val>1 val>1的数的最小值
每次区间 − 1 -1 −1之后,都看一下根节点的 m n mn mn有没有变成 ≤ 1 \le 1 ≤1的数,如果变成了,那我就重构线段树(只重构 m n ≤ 1 mn \le 1 mn≤1的那些节点),一直修改到那些值变得 ≤ 1 \le 1 ≤1的叶子,然后自下而上重新算出 c n t cnt cnt
单点加就是直接找到那个叶子然后加上相应的值,并重构这个叶子到根节点的一条链
如果没有单点加的话,可以发现每个叶子一旦变成了 ≤ 1 \le 1 ≤1就永远不会再变回来了,因此每个变成 ≤ 1 \le 1 ≤1的叶子都引发了其到根节点的一次重构,且仅会引发一次,所以重构的复杂度是 O ( n l o g n ) O(nlogn) O(nlogn)。现在考虑单点加操作,单点加只会修改一个叶子的值,也就是说顶多会增加一个叶子重构的工作量。所以有多少次单点加,就会多出来多少 × l o g n \times logn ×logn的重构次数。所以最后重构这部分的复杂度是 O ( ( n + 单 点 加 的 次 数 ) l o g n ) O((n+单点加的次数)logn) O((n+单点加的次数)logn)。
查询就是普通的查询就行
#include
#include
#include
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 200010
#define maxe 400010
#define cl(x) memset(x,0,sizeof(x))
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
struct Graph
{
int etot, head[maxn], to[maxe], next[maxe], w[maxe];
void clear(int N)
{
for(int i=1;i<=N;i++)head[i]=0;
etot=0;
}
void adde(int a, int b, int c=0){to[++etot]=b;w[etot]=c;next[etot]=head[a];head[a]=etot;}
#define forp(_,__) for(auto p=__.head[_];p;p=__.next[p])
}G;
struct Longest_Chain_Decomposition
{
ll len[maxn], son[maxn], depth[maxn], istop[maxn];
void dfs(Graph& G, ll u, ll fa)
{
son[u]=0;
len[u]=1;
depth[u]=depth[fa]+1;
istop[u]=false;
forp(u,G)
{
ll v(G.to[p]); if(v==fa)continue;
dfs(G,v,u);
if(len[v]+1>len[u])len[u]=len[v]+1, son[u]=v;
}
forp(u,G)
{
ll v(G.to[p]); if(v==fa)continue;
if(v!=son[u])istop[v]=true;
}
}
void run(Graph& G, int root)
{
depth[0]=0, dfs(G,root,0);
istop[root]=true;
}
}lcd;
struct SegmentTree
{
ll mn[maxn<<2], val[maxn<<2], cnt[maxn<<2], add[maxn<<2], L[maxn<<2], R[maxn<<2];
void maketag_add(ll o, ll v)
{
add[o]+=v;
mn[o]+=v;
if(L[o]==R[o])
{
if(val[o]>=1)val[o]=max(1ll,val[o]+v);
}
}
void pushdown(ll o)
{
if(L[o]==R[o])return;
if(add[o])
{
maketag_add(o<<1,add[o]);
maketag_add(o<<1|1,add[o]);
add[o]=0;
}
}
void pushup(ll o)
{
mn[o]=min(mn[o<<1],mn[o<<1|1]);
cnt[o]=cnt[o<<1]+cnt[o<<1|1];
}
void rebuild(ll o)
{
if(mn[o]>1)return;
if(L[o]==R[o])
{
if(val[o]>1)mn[o]=val[o], cnt[o]=1;
else mn[o]=linf, cnt[o]=0;
return;
}
pushdown(o);
rebuild(o<<1), rebuild(o<<1|1);
pushup(o);
}
void build(ll o, ll l, ll r)
{
ll mid(l+r>>1);
L[o]=l, R[o]=r;
mn[o]=linf;
if(l==r)return;
build(o<<1,l,mid);
build(o<<1|1,mid+1,r);
pushup(o);
}
void Add(ll o, ll l, ll r, ll v)
{
ll mid(L[o]+R[o]>>1);
if(l<=L[o] and r>=R[o]){maketag_add(o,v);return;}
pushdown(o);
if(l<=mid)Add(o<<1,l,r,v);
if(r>mid)Add(o<<1|1,l,r,v);
pushup(o);
}
ll q(ll o, ll l, ll r)
{
ll mid(L[o]+R[o]>>1), ans(0);
if(l<=L[o] and r>=R[o])return cnt[o];
pushdown(o);
if(l<=mid)ans+=q(o<<1,l,r);
if(r>mid)ans+=q(o<<1|1,l,r);
return ans;
}
ll qv(ll o, ll pos)
{
ll mid(L[o]+R[o]>>1);
if(L[o]==R[o])return val[o];
pushdown(o);
if(pos<=mid)return qv(o<<1,pos);
else return qv(o<<1|1,pos);
}
void ins(ll o, ll pos, ll v)
{
ll mid(L[o]+R[o]>>1), ans(0);
if(L[o]==R[o])
{
val[o]+=v;
if(val[o]>1)mn[o]=val[o], cnt[o]=1;
else mn[o]=linf, cnt[o]=0;
return;
}
pushdown(o);
if(pos<=mid)ins(o<<1,pos,v);
if(pos>mid)ins(o<<1|1,pos,v);
pushup(o);
}
}segtree;
ll n, s, a[maxn], start[maxn], tot, ans;
void dfs(ll u, ll fa)
{
if(lcd.istop[u])
{
start[u] = tot;
tot += lcd.len[u];
}
if(lcd.son[u])
{
start[lcd.son[u]]=start[u]+1;
dfs(lcd.son[u],u);
segtree.Add(1,start[u]+1,start[u]+lcd.len[u]-1,-1);
segtree.rebuild(1);
}
segtree.ins(1,start[u],a[u]);
forp(u,G)
{
ll v(G.to[p]);
if(v==fa or v==lcd.son[u])continue;
dfs(v,u);
ll i; rep(i,0,lcd.len[v]-1)
{
ll add = segtree.qv(1,start[v]+i);
if(add>1)add--;
segtree.ins(1,start[u]+i+1,add);
}
}
ans -= segtree.q(1,start[u],start[u]+lcd.len[u]-1);
}
int main()
{
ll i, u, v;
n=read(), s=read();
rep(i,1,n)a[i]=read(), ans+=a[i];
rep(i,1,n-1)
{
u=read(), v=read();
G.adde(u,v), G.adde(v,u);
}
segtree.build(1,0,n-1);
lcd.run(G,s);
dfs(s,0);
printf("%lld",ans);
return 0;
}