SPOJ COT2 Count on a tree II(树上莫队)

题目链接:http://www.spoj.com/problems/COT2/

You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.

We will ask you to perform the following operation:

  • u v : ask for how many different integers that represent the weight of nodes there are on the path from u to v.

 

Input

In the first line there are two integers N and M (N <= 40000, M <= 100000).

In the second line there are N integers. The i-th integer denotes the weight of the i-th node.

In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).

In the next M lines, each line contains two integers u v, which means an operation asking for how many different integers that represent the weight of nodes there are on the path from u to v.

Output

For each operation,print its result.

 

题目大意:给一棵树,每个点有一个权值。多次询问路径(a, b)上的点共有多少种不同的权值。

思路:参考VFK WC 2013 糖果公园 park 题解(此题比COT2要难。)

http://vfleaking.blog.163.com/blog/static/174807634201311011201627/

 

代码(2.37S):

  1 #include <bits/stdc++.h>

  2 using namespace std;

  3 

  4 const int MAXV = 40010;

  5 const int MAXE = MAXV << 1;

  6 const int MAXQ = 100010;

  7 const int MLOG = 20;

  8 

  9 namespace Bilibili {

 10 

 11 int head[MAXV], val[MAXV], ecnt;

 12 int to[MAXE], next[MAXE];

 13 int n, m;

 14 

 15 int stk[MAXV], top;

 16 int block[MAXV], bcnt, bsize;

 17 

 18 struct Query {

 19     int u, v, id;

 20     void read(int i) {

 21         id = i;

 22         scanf("%d%d", &u, &v);

 23     }

 24     void adjust() {

 25         if(block[u] > block[v]) swap(u, v);

 26     }

 27     bool operator < (const Query &rhs) const {

 28         if(block[u] != block[rhs.u]) return block[u] < block[rhs.u];

 29         return block[v] < block[rhs.v];

 30     }

 31 } ask[MAXQ];

 32 int ans[MAXQ];

 33 /// Graph

 34 void init() {

 35     memset(head + 1, -1, n * sizeof(int));

 36     ecnt = 0;

 37 }

 38 

 39 void add_edge(int u, int v) {

 40     to[ecnt] = v; next[ecnt] = head[u]; head[u] = ecnt++;

 41     to[ecnt] = u; next[ecnt] = head[v]; head[v] = ecnt++;

 42 }

 43 

 44 void gethash(int a[], int n) {

 45     static int tmp[MAXV];

 46     int cnt = 0;

 47     for(int i = 1; i <= n; ++i) tmp[cnt++] = a[i];

 48     sort(tmp, tmp + cnt);

 49     cnt = unique(tmp, tmp + cnt) - tmp;

 50     for(int i = 1; i <= n; ++i)

 51         a[i] = lower_bound(tmp, tmp + cnt, a[i]) - tmp + 1;

 52 }

 53 

 54 void read() {

 55     scanf("%d%d", &n, &m);

 56     for(int i = 1; i <= n; ++i) scanf("%d", &val[i]);

 57     gethash(val, n);

 58     init();

 59     for(int i = 1, u, v; i < n; ++i) {

 60         scanf("%d%d", &u, &v);

 61         add_edge(u, v);

 62     }

 63     for(int i = 0; i < m; ++i) ask[i].read(i);

 64 }

 65 /// find_block

 66 void add_block(int &cnt) {

 67     while(cnt--) block[stk[--top]] = bcnt;

 68     bcnt++;

 69     cnt = 0;

 70 }

 71 

 72 void rest_block() {

 73     while(top) block[stk[--top]] = bcnt - 1;

 74 }

 75 

 76 int dfs_block(int u, int f) {

 77     int size = 0;

 78     for(int p = head[u]; ~p; p = next[p]) {

 79         int v = to[p];

 80         if(v == f) continue;

 81         size += dfs_block(v, u);

 82         if(size >= bsize) add_block(size);

 83     }

 84     stk[top++] = u;

 85     size++;

 86     if(size >= bsize) add_block(size);

 87     return size;

 88 }

 89 

 90 void init_block() {

 91     bsize = max(1, (int)sqrt(n));

 92     dfs_block(1, 0);

 93     rest_block();

 94 }

 95 /// ask_rmq

 96 int fa[MLOG][MAXV];

 97 int dep[MAXV];

 98 

 99 void dfs_lca(int u, int f, int depth) {

100     dep[u] = depth;

101     fa[0][u] = f;

102     for(int p = head[u]; ~p; p = next[p]) {

103         int v = to[p];

104         if(v != f) dfs_lca(v, u, depth + 1);

105     }

106 }

107 

108 void init_lca() {

109     dfs_lca(1, -1, 0);

110     for(int k = 0; k + 1 < MLOG; ++k) {

111         for(int u = 1; u <= n; ++u) {

112             if(fa[k][u] == -1) fa[k + 1][u] = -1;

113             else fa[k + 1][u] = fa[k][fa[k][u]];

114         }

115     }

116 }

117 

118 int ask_lca(int u, int v) {

119     if(dep[u] < dep[v]) swap(u, v);

120     for(int k = 0; k < MLOG; ++k) {

121         if((dep[u] - dep[v]) & (1 << k)) u = fa[k][u];

122     }

123     if(u == v) return u;

124     for(int k = MLOG - 1; k >= 0; --k) {

125         if(fa[k][u] != fa[k][v])

126             u = fa[k][u], v = fa[k][v];

127     }

128     return fa[0][u];

129 }

130 /// modui

131 bool vis[MAXV];

132 int diff, cnt[MAXV];

133 

134 void xorNode(int u) {

135     if(vis[u]) vis[u] = false, diff -= (--cnt[val[u]] == 0);

136     else vis[u] = true, diff += (++cnt[val[u]] == 1);

137 }

138 

139 void xorPathWithoutLca(int u, int v) {

140     if(dep[u] < dep[v]) swap(u, v);

141     while(dep[u] != dep[v])

142         xorNode(u), u = fa[0][u];

143     while(u != v)

144         xorNode(u), u = fa[0][u],

145         xorNode(v), v = fa[0][v];

146 }

147 

148 void moveNode(int u, int v, int taru, int tarv) {

149     xorPathWithoutLca(u, taru);

150     xorPathWithoutLca(v, tarv);

151     //printf("debug %d %d\n", ask_lca(u, v), ask_lca(taru, tarv));

152     xorNode(ask_lca(u, v));

153     xorNode(ask_lca(taru, tarv));

154 }

155 

156 void make_ans() {

157     for(int i = 0; i < m; ++i) ask[i].adjust();

158     sort(ask, ask + m);

159     int nowu = 1, nowv = 1; xorNode(1);

160     for(int i = 0; i < m; ++i) {

161         moveNode(nowu, nowv, ask[i].u, ask[i].v);

162         ans[ask[i].id] = diff;

163         nowu = ask[i].u, nowv = ask[i].v;

164     }

165 }

166 

167 void print_ans() {

168     for(int i = 0; i < m; ++i)

169         printf("%d\n", ans[i]);

170 }

171 

172 void solve() {

173     read();

174     init_block();

175     init_lca();

176     make_ans();

177     print_ans();

178 }

179 

180 };

181 

182 int main() {

183     Bilibili::solve();

184 }
View Code

 

你可能感兴趣的:(count)