poj 3321 Apple Trie

/*
  poj 3321 Apple Trie 
  这道题的关键是如何将一个树建成一个一维数组利用树状数组来解题!
  可以利用dfs()来搞定,我们在对一个节点深搜后,所经过的节点的数目就是该节点的子树的数目
  所以我们利用start[i]数组来记录 i 节点在一维数组的起始位置, 而end[i]则是记录i节点所有孩子 
  节点最后一个孩子节点在数组的位置,那么end[i]-start[i]+1,就是 i 节点(包括自身)和其所有孩子节点的
  数目。数组建好了,那么最后就是套用树状数组模板进行求解了! 
*/
#include<iostream> 
#include<vector>
#include<cstring>
#include<cstdio>
#define N 100005
using namespace std;
class node
{
public :
    int k;
    node *next;
    node()
    {
    	next=NULL;
    }
};

node trie[N];
//trie[i]记录的是所有是 i 节点 孩子节点组成的链表的头部 
int C[N], num[N];
int start[N], end[N];
int cnt, n;

void dfs(int cur)
{
    start[cur]=cnt;
    if(trie[cur].next==NULL)
    {
    	end[cur]=cnt;
        return;
    }
    for(node *p=trie[cur].next; p!=NULL; p=p->next)//遍历cur节点的所有孩子节点 
    {
    	++cnt;
    	dfs(p->k);
    }
    end[cur]=cnt;//深搜之后得到的cnt值就是cur节点最后一个孩子在一维数组中的位置 
}

int lowbit(int x)
{
   return x&(-x);
}

void init(int p, int k)
{
   int i;
   num[p]=k;
   for(i=p-lowbit(p)+1; i<=p; ++i)
      C[p]+=num[i];
}

int getSum(int p)
{
    int s=0;	
    while(p>0)
    {
    	s+=C[p];
    	p-=lowbit(p);
    }
    return s;
}

void update(int p, int k)
{
    while(p<=n)
    {
    	C[p]+=k;
    	p+=lowbit(p);
    }
}


int main()
{
   int i, u, v, m;
   char ch[2];
   int f;
   while(scanf("%d", &n)!=EOF)
   {
      cnt=1;
      memset(C, 0, sizeof(C));
      for(i=1; i<n; ++i)
      {
      	  scanf("%d%d", &u, &v);
      	  node *p=new node();
      	  p->k=v;
      	  p->next=trie[u].next;
      	  trie[u].next=p;
      }
      dfs(1);
      for(i=1; i<=n; ++i)
         init(i, 1);
      scanf("%d", &m);
      while(m--)
      {
	 scanf("%s%d", ch, &f);
	 if(ch[0]=='C')
	 {
	     if(num[f]==1)
	       {
	          update(start[f], -1);
	          num[f]=0;
	       }
	     else
	       {
	       	  update(start[f], 1);
	       	  num[f]=1;
	       } 
	 }
	 else
	    printf("%d\n", getSum(end[f])-getSum(start[f]-1));
      }
   }
   return 0;
}

/*
这道题利用二维数组建图也可以过,不过数组的大小还真是难以捉摸....
*/
#include<iostream> #include<vector> #include<cstring> #include<cstdio> #define N 100005 using namespace std; int node[N][100]; int C[N], num[N]; int start[N], end[N]; int cnt, n; void dfs(int cur) { int sz=node[cur][0]; start[cur]=cnt; if(sz==0) { end[cur]=cnt; return; } for(int i=1; i<=sz; ++i) { ++cnt; dfs(node[cur][i]); } end[cur]=cnt; } int lowbit(int x) { return x&(-x); } void init(int p, int k) { int i; num[p]=k; for(i=p-lowbit(p)+1; i<=p; ++i) C[p]+=num[i]; } int getSum(int p) { int s=0; while(p>0) { s+=C[p]; p-=lowbit(p); } return s; } void update(int p, int k) { while(p<=n) { C[p]+=k; p+=lowbit(p); } } int main() { int i, u, v, m; char ch[2]; int f; while(scanf("%d", &n)!=EOF) { cnt=1; for(i=1; i<=n; ++i) node[i][0]=0; memset(C, 0, sizeof(C)); for(i=1; i<n; ++i) { scanf("%d%d", &u, &v); node[u][++node[u][0]]=v; } dfs(1); for(i=1; i<=n; ++i) init(i, 1); scanf("%d", &m); while(m--) { scanf("%s%d", ch, &f); if(ch[0]=='C') { if(num[f]==1) { update(start[f], -1); num[f]=0; } else { update(start[f], 1); num[f]=1; } } else printf("%d\n", getSum(end[f])-getSum(start[f]-1)); } } return 0; }

  

  

你可能感兴趣的:(apple)