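思路：先按 DFS 序把每个子树询问转化为区间询问 [L[x], R[x]]，并按区间右端点离线排序。从左到右扫描 DFS 序，对每个值用一个数组记录它目前为止的出现位置；每加入一个新位置，只需在该值倒数第 k、k+1、k+2 次出现的位置上用树状数组做单点修改，使得区间和恰好等于区间内出现次数为 k 的值的个数。具体见代码：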
// Offline answer to subtree queries "how many distinct weights occur exactly
// k times in the subtree of x": the tree is flattened by DFS order so each
// subtree becomes an interval [L[x], R[x]]; queries are sorted by right
// endpoint and answered with a Fenwick tree while scanning positions.
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <map>
#include <queue>
#include <deque>
#include <cmath>
#include <vector>
#include <ctime>
#include <cctype>
#include <set>

using namespace std;

#define mem0(a) memset(a, 0, sizeof(a))
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
#define define_m int m = (l + r) >> 1
#define Rep(a, b) for(int a = 0; a < b; a++)
#define lowbit(x) ((x) & (-(x)))
#define constructInt4(name, a, b, c, d) name(int a = 0, int b = 0, int c = 0, int d = 0): a(a), b(b), c(c), d(d) {}
#define constructInt3(name, a, b, c) name(int a = 0, int b = 0, int c = 0): a(a), b(b), c(c) {}
#define constructInt2(name, a, b) name(int a = 0, int b = 0): a(a), b(b) {}

typedef double db;
typedef long long LL;
typedef pair<int, int> pii;
typedef multiset<int> msi;
typedef multiset<int>::iterator msii;
typedef set<int> si;
typedef set<int>::iterator sii;
typedef vector<int> vi;

const int dx[8] = {1, 0, -1, 0, 1, 1, -1, -1};
const int dy[8] = {0, -1, 0, 1, -1, 1, 1, -1};
const int maxn = 1e5 + 7;
const int maxm = 1e5 + 7;
const int maxv = 1e7 + 7;
const int MD = 1e9 +7;
const int INF = 1e9 + 7;
const double PI = acos(-1.0);
const double eps = 1e-10;

// Adjacency-list container: adj[s] is the list of edges leaving s.
template<class edge> struct Graph {
    vector<vector<edge> > adj;
    Graph(int n) { adj.clear(); adj.resize(n + 5); }
    Graph() { adj.clear(); }
    void resize(int n) { adj.resize(n + 5); }
    void add(int s, edge e){ adj[s].push_back(e); }
    // Removes the first edge equal to e from adj[s]; no-op if absent
    // (erasing end() would be undefined behaviour).
    void del(int s, edge e) {
        typename vector<edge>::iterator it = std::find(adj[s].begin(), adj[s].end(), e);
        if (it != adj[s].end()) adj[s].erase(it);
    }
    void clear() { adj.clear(); }
    vector<edge>& operator [](int t) { return adj[t]; }
};

// Fenwick (binary indexed) tree over positions 1..maxn.
template<class T> struct TreeArray {
    vector<T> c;
    int maxn;
    TreeArray(int n) { c.resize(n + 5); maxn = n + 2; }
    TreeArray() { c.clear(); maxn = 0; }
    // Zeroes every stored counter. Safe on an empty tree
    // (taking &c[0] of an empty vector is undefined behaviour).
    void clear() { std::fill(c.begin(), c.end(), T()); }
    void resize(int n) { c.resize(n + 5); maxn = n + 2; }
    // Point update; p must be >= 1 or the loop never terminates.
    void add(int p, T x) { while (p <= maxn) { c[p] += x; p += lowbit(p); } }
    // Prefix sum over [1, p].
    T get(int p) { T res = 0; while (p) { res += c[p]; p -= lowbit(p); } return res; }
    // Sum over [a, b].
    T range(int a, int b) { return get(b) - get(a - 1); }
};

// Orders queries by the right endpoint of their DFS interval.
bool cmp(const pair<pii, int> &a, const pair<pii, int> &b) {
    return a.first.second < b.first.second;
}

int a[maxn], b[maxn], c, n, k, L[maxn], R[maxn], vis[maxn], cc, out[maxn];

TreeArray<int> ts;
pair<pii, int> in[maxn];
Graph<int> G;

// Rank of x among the c compressed values stored in b[1..c].
int find(int x) { return lower_bound(b + 1, b + c + 1, x) - b; }

// Preorder DFS from u: assigns L[v] (entry position), R[v] (last position in
// v's subtree) and writes a[v] into b[L[v]]. Iterative with an explicit stack
// so a chain-shaped tree cannot overflow the call stack; children are visited
// in adjacency order, giving exactly the same numbering as the recursive form.
void DFS(int u) {
    vector<pii> st;  // (vertex, index of next neighbour to examine)
    vis[u] = 1;
    L[u] = ++cc;
    b[cc] = a[u];
    st.push_back(pii(u, 0));
    while (!st.empty()) {
        int v = st.back().first;
        int idx = st.back().second;
        if (idx < (int)G[v].size()) {
            st.back().second = idx + 1;
            int w = G[v][idx];
            if (!vis[w]) {
                vis[w] = 1;
                L[w] = ++cc;
                b[cc] = a[w];
                st.push_back(pii(w, 0));
            }
        } else {
            R[v] = cc;
            st.pop_back();
        }
    }
}

int main() {
    //freopen("in.txt", "r", stdin);
    int T, cas = 0, m;
    cin >> T;
    while (T--) {
        scanf("%d%d", &n, &k);
        for (int i = 1; i <= n; i++) {
            scanf("%d", a + i);
        }
        // Undirected tree on vertices 1..n.
        G.clear();
        G.resize(n);
        for (int i = 1, u, v; i < n; i++) {
            scanf("%d%d", &u, &v);
            G.add(u, v);
            G.add(v, u);
        }
        // Flatten: afterwards b[pos] is the weight of the vertex at DFS position pos.
        mem0(vis);
        cc = 0;
        DFS(1);
        memcpy(a, b, sizeof(b));
        // Coordinate-compress: a[pos] becomes a rank in [1, c].
        sort(b + 1, b + n + 1);
        c = unique(b + 1, b + n + 1) - b - 1;
        for (int i = 1; i <= n; i++) {
            a[i] = find(a[i]);
        }
        // A query on vertex x is the DFS interval [L[x], R[x]].
        cin >> m;
        for (int i = 0, x; i < m; i++) {
            scanf("%d", &x);
            in[i].first = make_pair(L[x], R[x]);
            in[i].second = i;
        }
        sort(in, in + m, cmp);
        // G is reused: G[v] holds the DFS positions of value v scanned so far.
        G.clear();
        G.resize(n);
        ts.clear();
        ts.resize(n);

        int last = 0;
        for (int i = 0; i < m; i++) {
            for (int j = last + 1; j <= in[i].first.second; j++) {
                G.add(a[j], j);
                vector<int> &pos = G[a[j]];
                int sz = pos.size();
                // Invariant per value with positions pos[0..sz-1] (sz >= k):
                // +1 at pos[sz-k] and -1 at pos[sz-k-1], so the sum over [l, j]
                // is 1 iff exactly k occurrences lie in [l, j]. Moving from
                // sz-1 to sz occurrences applies the delta below.
                // k < 1 is skipped: pos[sz - k] would be out of range.
                if (k >= 1 && sz >= k) {
                    ts.add(pos[sz - k], 1);
                    if (sz > k) {
                        ts.add(pos[sz - k - 1], -2);
                        if (sz > k + 1) ts.add(pos[sz - k - 2], 1);
                    }
                }
            }
            out[in[i].second] = ts.range(in[i].first.first, in[i].first.second);
            last = in[i].first.second;
        }
        if (cas) printf("\n");
        printf("Case #%d:\n", ++cas);
        for (int i = 0; i < m; i++) {
            printf("%d\n", out[i]);
        }
    }
    return 0;
}