题目
先做树链剖分，再按 dfn 序建立线段树：同一条重链上的结点 dfn 连续，同一棵子树内的结点 dfn 也连续。这样，“某点到根的路径”可以拆成 O(log n) 段 dfn 区间，“某点的子树”恰好是一段 dfn 区间，两类操作都化为线段树上的区间求和与区间赋值。
#include<cstdio>
#define MAXN 100000

// Tree input: nodes are numbered 0..n-1 and node 0 is the root.
int n;               // number of nodes (read in init)
int fa[MAXN+10];     // fa[v]: parent of v for v = 1..n-1 (fa[0] stays 0)
int size[MAXN+10];   // size[v]: number of nodes in v's subtree, v included (filled by dfs1)

// Heavy-light decomposition results, filled by dfs2.
int pos[MAXN+10];    // pos[v]: 1-based position of v on its heavy chain, counted from the chain top
int length[MAXN+10]; // length[t]: node count of the chain whose top is t (appears write-only in the visible code)
int bl[MAXN+10];     // bl[v]: top node of the heavy chain containing v (0 for the root's chain)
int Q;               // number of operations (read in solve)
int dfn[MAXN+10];    // dfn[v]: 1-based preorder index, heavy child visited first
int dcnt;            // last dfn handed out
int ed[MAXN+10];     // ed[v]: largest dfn inside v's subtree, so the subtree is [dfn[v], ed[v]]

bool f[MAXN+10];     // NOTE(review): appears unused in the visible code
// Read the next non-negative decimal integer from `in` (stdin by default)
// into x. Every character before the first digit (spaces, letters, '-') is
// skipped, so negative numbers are NOT supported; the value is assumed to fit
// in an int. The character that terminates the number is pushed back so the
// next read (Read or scanf) sees it. If the stream ends before any digit is
// found, x is left unchanged and the function simply returns.
void Read(int &x,FILE *in=stdin){
    // Must be int, not char: with an unsigned char type `c != EOF` would
    // never become false and the loop would spin forever at end of input.
    int c;
    while((c=getc(in))!=EOF){
        if(c<'0'||c>'9')
            continue;
        x=c-'0';
        while((c=getc(in))>='0'&&c<='9')
            x=x*10+(c-'0');
        ungetc(c,in);   // no-op when c == EOF
        return;
    }
}
// Singly linked adjacency-list entry for one parent -> child edge.
struct node{
    int v;        // child endpoint of the edge
    node *next;   // next edge leaving the same parent; null ends the list
};
node *adj[MAXN+10];   // adj[u]: head of u's child list (null when u has no children)
node edge[MAXN+10];   // static edge pool; addedge pre-increments ecnt, so edge[0] is never handed out
node *ecnt=edge;      // most recently handed-out pool slot
// One segment-tree node over dfn positions. Node i has children 2i and 2i+1.
struct seg_tree{
    int sum;   // sum of the values at positions [l, r]; callers in this file only ever assign 0 or 1
    int tag;   // pending "assign tag to the whole range" for the children; tag >= 0 means pending, -1 means none
    int l;     // first dfn position covered
    int r;     // last dfn position covered
};
// 4*MAXN slots are enough for the 2i / 2i+1 layout over up to MAXN positions.
seg_tree tree[MAXN*4+10];
// Lazy push-down: hand node i's pending range assignment (tree[i].tag) to both
// children. Each child's sum becomes (its length) * tag and it inherits the
// tag; node i is then left with no pending assignment (tag = -1).
// insert/find call this only when tree[i].tag >= 0 and node i partially
// overlaps their range, i.e. i is an internal node.
void clear(int i){
    const int val=tree[i].tag;
    for(int child=i*2;child<=i*2+1;++child){
        tree[child].sum=val*(tree[child].r-tree[child].l+1);
        tree[child].tag=val;
    }
    tree[i].tag=-1;
}
// Build the segment-tree skeleton: node i covers dfn interval [l, r]; its
// children are 2i (covering [l, mid]) and 2i+1 (covering [mid+1, r]).
// Every node is explicitly reset to sum 0 (all positions 0) and tag -1 (no
// pending assignment, the sentinel clear/insert/find test with `tag >= 0`).
// The tree is therefore valid even if this storage was used before, instead
// of relying on static zero-initialisation, where tag 0 would be a live
// "assign 0" tag. An empty range (l > r, e.g. n == 0) just records the bounds
// and stops instead of recursing forever.
void build(int i,int l,int r){
    tree[i].l=l;
    tree[i].r=r;
    tree[i].sum=0;
    tree[i].tag=-1;
    if(l>=r)
        return;
    int mid=(l+r)>>1;
    build(i<<1,l,mid);
    build((i<<1)|1,mid+1,r);
}
// Range assignment: set every dfn position in [l, r] (intersected with node
// i's interval) to val, then refresh tree[i].sum.
// A fully covered node is updated in O(1) and keeps val as a lazy tag for its
// children. A partially covered node first pushes its own pending tag down
// via clear(), then recurses into both halves.
void insert(int i,int l,int r,int val){
    const int lo=tree[i].l;
    const int hi=tree[i].r;
    if(lo>r||hi<l)              // no overlap
        return;
    if(l<=lo&&hi<=r){           // node lies entirely inside [l, r]
        tree[i].sum=(hi-lo+1)*val;
        tree[i].tag=val;
        return;
    }
    if(tree[i].tag>=0)
        clear(i);
    const int left=i<<1;
    insert(left,l,r,val);
    insert(left+1,l,r,val);
    tree[i].sum=tree[left].sum+tree[left+1].sum;
}
// Range sum: return the sum of the values at dfn positions in [l, r]
// (intersected with node i's interval). Pending tags on partially covered
// nodes are pushed down with clear() before descending.
int find(int i,int l,int r){
    const int lo=tree[i].l;
    const int hi=tree[i].r;
    if(l<=lo&&hi<=r)            // node lies entirely inside [l, r]
        return tree[i].sum;
    if(hi<l||r<lo)              // no overlap
        return 0;
    if(tree[i].tag>=0)
        clear(i);
    const int left=i*2;
    const int leftSum=find(left,l,r);
    return leftSum+find(left+1,l,r);
}
// Prepend the edge u -> v to u's adjacency list, using the next free slot of
// the static edge pool.
void addedge(int u,int v){
    ++ecnt;
    ecnt->v=v;
    ecnt->next=adj[u];
    adj[u]=ecnt;
}
void dfs1(int u){
size[u]=1;
for(node *p=adj[u];p;p=p->next){
dfs1(p->v);
size[u]+=size[p->v];
}
}
// Second HLD pass (needs size[] from dfs1). Visits u's subtree in preorder,
// always descending into the heavy child first, so every heavy chain -- and
// every subtree -- occupies a contiguous range of dfn values.
//   u   : current node
//   len : 1-based position of u on its heavy chain (1 = chain top)
// Sets dfn[u] and ed[u] (largest dfn inside u's subtree). When a chain ends at
// a leaf, also fills bl[] (chain top) and pos[] for every node of that chain,
// plus length[top] (number of nodes on the chain).
// NOTE(review): recursion depth equals the tree height, which can approach n
// on path-shaped input -- confirm the stack is large enough for MAXN nodes.
void dfs2(int u,int len){
dfn[u]=++dcnt;
// n acts as a "no child" sentinel: nodes are 0..n-1, so size[n] stays 0 and
// any real child has a larger size. With strict '>', ties keep the first child seen.
int heavy=n;
for(node *p=adj[u];p;p=p->next)
if(size[p->v]>size[heavy])
heavy=p->v;
if(heavy==n){
// u has no children, so it is the bottom of its heavy chain.
ed[u]=dcnt;
int tp=u,i;
// Climb len-1 parents to reach the top of the chain.
for(i=1;i<len;i++)
tp=fa[tp];
length[tp]=len;
// Walk from the leaf back up to the top, numbering positions len..1.
// u is reused as the walk cursor (ed[u] was already recorded above).
for(i=len;i;i--){
pos[u]=i;
bl[u]=tp;
u=fa[u];
}
return;
}
// The heavy child continues the current chain...
dfs2(heavy,len+1);
// ...and every light child starts a new chain at position 1.
for(node *p=adj[u];p;p=p->next)
if(p->v!=heavy)
dfs2(p->v,1);
ed[u]=dcnt;
}
// Read the tree: n, then the parent of each node 1..n-1 (node 0 is the root).
// Then build the segment tree over dfn positions 1..n and run both HLD passes.
void init(){
    Read(n);
    int child=1;
    while(child<n){
        Read(fa[child]);
        addedge(fa[child],child);
        ++child;
    }
    build(1,1,n);
    dfs1(0);
    dfs2(0,1);
}
// Set every node on the path from a up to the root (node 0) to 1, and return
// how many of those nodes were 0 beforehand.
// The path is processed one heavy-chain segment [dfn[top], dfn[a]] at a time;
// the root's chain is the one whose top is 0.
int query(int a){
    int pathLen=0;   // nodes on the path (sum of pos[] over the segments)
    int alreadySet=0; // of those, how many were already 1
    for(;;){
        const int top=bl[a];
        pathLen+=pos[a];
        alreadySet+=find(1,dfn[top],dfn[a]);
        insert(1,dfn[top],dfn[a],1);
        if(top==0)
            break;
        a=fa[top];
    }
    return pathLen-alreadySet;
}
// Process Q operations. Each is a word followed by a node id a:
//   word starting with 'i' : print query(a) -- how many nodes on the path from
//                            a to the root were 0 -- and set that path to 1;
//   any other word         : print how many positions in a's subtree are 1,
//                            then reset the whole subtree [dfn[a], ed[a]] to 0.
// Stops early if the input ends before an operation is complete, instead of
// re-running the previous word/id or reading an uninitialised id.
void solve(){
    char s[15];
    int a;
    Read(Q);
    while(Q--){
        // Width 14 keeps scanf inside s[15] (14 chars + terminating NUL).
        if(scanf("%14s",s)!=1)
            break;
        a=-1;               // Read only produces non-negative values...
        Read(a);
        if(a<0)             // ...so -1 means the node id was missing.
            break;
        if(s[0]=='i')
            printf("%d\n",query(a));
        else{
            printf("%d\n",find(1,dfn[a],ed[a]));
            insert(1,dfn[a],ed[a],0);
        }
    }
}
// Entry point: read and decompose the tree, then answer every operation.
int main()
{
    init();
    solve();
    return 0;
}