SPOJ 913 Query on a tree II 树链剖分

对于询问dist,树链剖分搞之,把边权转化到点上,然后注意细节就好(我在代码里标出来了,为了这个细节,wa了一屏)

对于询问kth,可以先求出两点(x和y)的lca,然后判断第k个数字是在x到lca的路径上还是y到lca的路径上,确定之后,倍增的寻找就好了~

 

View Code
  1 #include <iostream>

  2 #include <cstring>

  3 #include <cstdlib>

  4 #include <algorithm>

  5 #include <cstdio>

  6 

  7 #define N 50000

  8 #define M 100000

  9 

 10 using namespace std;

 11 

 12 int head[N],next[M],to[M],len[M];

 13 int n,tot,cnt;

 14 int fa[N],son[M],top[N],dat[N],sum[N<<2],dep[N],sz[N],pre[N],bh[N];

 15 int f[N][22],bit[22];

 16 int q[M];

 17 

 18 inline void init()

 19 {

 20     memset(head,-1,sizeof head); cnt=2; tot=0;

 21     memset(son,0,sizeof son);

 22     memset(fa,0,sizeof fa);

 23     memset(f,0,sizeof f);

 24     memset(sum,0,sizeof sum);

 25     bit[0]=1;

 26     for(int i=1;i<=20;i++) bit[i]=bit[i-1]<<1;

 27 }

 28 

 29 inline void prep()

 30 {

 31     int h=1,t=2,sta;

 32     q[1]=1; dep[1]=1;

 33     while(h<t)

 34     {

 35         sta=q[h++]; sz[sta]=1;

 36         for(int i=head[sta];~i;i=next[i])

 37             if(fa[sta]!=to[i])

 38             {

 39                 fa[to[i]]=sta;

 40                 f[to[i]][0]=sta;

 41                 pre[to[i]]=i^1;

 42                 dep[to[i]]=dep[sta]+1;

 43                 q[t++]=to[i];

 44             }

 45     }

 46     for(int j=t-1;j>=1;j--)

 47     {

 48         sta=q[j];

 49         for(int i=head[sta];~i;i=next[i])

 50             if(fa[sta]!=to[i])

 51             {

 52                 sz[sta]+=sz[to[i]];

 53                 if(sz[to[i]]>sz[son[sta]]) son[sta]=to[i];

 54             }

 55     }

 56     for(int i=1;i<t;i++)

 57     {

 58         sta=q[i];

 59         if(son[fa[sta]]==sta) top[sta]=top[fa[sta]];

 60         else top[sta]=sta;

 61     }

 62 }

 63 

 64 inline void rewrite()

 65 {

 66     for(int i=1;i<=n;i++)

 67         if(top[i]==i)

 68             for(int j=i;j;j=son[j])

 69             {

 70                 bh[j]=++tot;

 71                 dat[tot]=len[pre[j]];

 72             }

 73 }

 74 

 75 inline void lcainit()

 76 {

 77     for(int j=1;j<=20;j++)

 78         for(int i=1;i<=n;i++)

 79             f[i][j]=f[f[i][j-1]][j-1];

 80 }

 81 

 82 inline void pushup(int x)

 83 {

 84     sum[x]=sum[x<<1]+sum[x<<1|1];

 85 }

 86 

 87 inline void build(int u,int L,int R)

 88 {

 89     if(L==R) {sum[u]=dat[L];return;}

 90     int MID=(L+R)>>1;

 91     build(u<<1,L,MID); build(u<<1|1,MID+1,R);

 92     pushup(u);

 93 }

 94 

 95 inline void add(int u,int v,int w)

 96 {

 97     to[cnt]=v; len[cnt]=w; next[cnt]=head[u]; head[u]=cnt++;

 98 }

 99 

100 inline void read()

101 {

102     init();

103     scanf("%d",&n);

104     for(int i=1,a,b,c;i<n;i++)

105     {

106         scanf("%d%d%d",&a,&b,&c);

107         add(a,b,c); add(b,a,c);

108     }

109     prep();

110     rewrite();

111     build(1,1,tot);

112     lcainit();

113 }

114 

115 inline int querysum(int u,int L,int R,int l,int r)

116 {

117     if(l<=L&&R<=r) return sum[u];

118     int MID=(L+R)>>1,res=0;

119     if(l<=MID) res+=querysum(u<<1,L,MID,l,r);

120     if(MID<r) res+=querysum(u<<1|1,MID+1,R,l,r);

121     return res;

122 }

123 

124 inline int getsum(int x,int y)

125 {

126     int res=0;

127     while(top[x]!=top[y])

128     {

129         if(dep[top[x]]<dep[top[y]]) swap(x,y);

130         res+=querysum(1,1,tot,bh[top[x]],bh[x]);

131         x=fa[top[x]];

132     }

133     if(x==y) return res;//这句话好坑啊!把边权转移到点权上时会出现这个问题! 

134     if(bh[x]>bh[y]) swap(x,y);

135     res+=querysum(1,1,tot,bh[son[x]],bh[y]);//细节 

136     return res;

137 }

138 

139 inline int getlca(int x,int y)

140 {

141     if(dep[x]<dep[y]) swap(x,y);

142     for(int i=20;i>=0;i--)

143         if(dep[f[x][i]]>=dep[y]) x=f[x][i];

144     if(x==y) return x;

145     for(int i=20;i>=0;i--)

146         if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];

147     return f[x][0];

148 }

149 

150 inline int getlen(int x,int lca)

151 {

152     int res=0;

153     for(int i=20;i>=0;i--)

154         if(dep[f[x][i]]>=dep[lca]) x=f[x][i],res+=bit[i];

155     return res;

156 }

157 

158 inline int getnum(int x,int p)

159 {

160     int res=0;

161     for(int i=20;i>=0;i--)

162         if(res+bit[i]<=p) x=f[x][i],res+=bit[i];

163     return x;

164 }

165 

166 inline int getkth(int x,int y,int p)

167 {

168     int lca=getlca(x,y);

169     int lx=getlen(x,lca)+1;

170     int ly=getlen(y,lca)+1;

171     if(lx>=p) return getnum(x,p-1);

172     return getnum(y,lx+ly-p-1);

173 }

174 

175 inline void go()

176 {

177     char str[10];int a,b,c;

178     while(scanf("%s",str))

179     {

180         if(str[1]=='O') break;

181         if(str[0]=='K')

182         {

183             scanf("%d%d%d",&a,&b,&c);

184             printf("%d\n",getkth(a,b,c));

185         }

186         else

187         {

188             scanf("%d%d",&a,&b);

189             printf("%d\n",getsum(a,b));

190         }

191     }

192     puts("");

193 }

194 

195 int main()

196 {

197     int cas;scanf("%d",&cas);

198     while(cas--) read(),go();

199     return 0;

200 }

 

 

你可能感兴趣的:(query)