You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
u v k : ask for the kth minimum weight on the path from node u to node v
Input
In the first line there are two integers N and M (N, M <= 100000).
In the second line there are N integers. The ith integer denotes the weight of the ith node
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v)
In the next M lines, each line contains three integers u v k, which means an operation asking for the kth minimum weight on the path from node u to node v
Output
For each operation, print its result.
给一棵无根树和m个询问,每次询问树上从u到v的路径上第k小的点权。
在接触树上第k大之前先做了HDU-4757,这题是树上的可持久化01字典树,所以一上来就很有思路。
树上路径问题的关键是lca。和上面那题一样,每个节点以父亲的版本为上一个版本新建可持久化线段树,这样root[x]就保存了从根到x这条链上的所有权值。查询时把u、v两棵树相加,此时根到lca的部分被算了两次,所以减去lca的树,再减去lca父亲的树(lca本身要保留一份),就差分出了u到v的路径树:cnt = T(u) + T(v) − T(lca) − T(fa(lca))。
ac代码:
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;
// Capacity for tree nodes: N <= 100000 plus slack for 1-based indexing.
const int maxn = 100005;

// n: number of tree nodes; m: number of queries.
int n, m;
// in[]: the raw weights; init() sorts and de-duplicates it so that it maps a
//       rank back to its weight.
// num[]: each node's weight, replaced in init() by its 1-based rank in in[].
int in[maxn], num[maxn];
// Adjacency lists of the tree.
vector<int> G[maxn];
// p[u][i]: the 2^i-th ancestor of u (0 above the root);
// dep[u]: depth of u, with the root (node 1) at depth 1.
int p[maxn][20], dep[maxn];

// One node of the persistent segment tree built over weight ranks.
struct Node {
    int val;        // how many inserted ranks fall inside this node's range
    Node *lc, *rc;  // children; the shared `null` sentinel stands for an empty subtree
};

// root[u]: the tree version built by dfs for node u.
Node *root[maxn];
// Static node pool. Each insertion allocates one node per level (about 18 for
// n <= 100000), so maxn * 20 slots are enough.
Node pool[maxn * 20];
// Last slot handed out; allocation is `++tail`, so pool[0] is never used.
Node *tail = pool;
// Shared empty node, created by init() with both children pointing to itself.
Node *null;
// Persistent point-increment. Returns a new version of `pre` (which covers
// [l, r]) with the count at leaf `pos` increased by one.
// Only the nodes on the path down to `pos` are freshly taken from the pool;
// the sibling subtree at each level is shared with `pre`.
// If `pos` lies outside [l, r], `pre` itself is returned and nothing is allocated.
Node* update(Node *pre, int l, int r, int pos) {
    if (l > pos || r < pos) {
        return pre;  // nothing changes in this range: reuse the old subtree
    }
    Node *fresh = ++tail;
    if (l == r) {
        fresh->val = pre->val + 1;
        return fresh;
    }
    const int mid = (l + r) >> 1;
    if (pos <= mid) {
        fresh->lc = update(pre->lc, l, mid, pos);
        fresh->rc = pre->rc;  // right half untouched: share it
    } else {
        fresh->lc = pre->lc;  // left half untouched: share it
        fresh->rc = update(pre->rc, mid + 1, r, pos);
    }
    fresh->val = fresh->lc->val + fresh->rc->val;
    return fresh;
}
// Traverses the tree from `u`, treating `fa` as u's parent. For every node x
// visited, with parent q, it:
//   - builds root[x] = root[q] plus one more occurrence of num[x];
//   - records p[x][0] = q and dep[x] = dep[q] + 1.
//
// The traversal uses an explicit stack rather than recursion. On a
// path-shaped tree with N = 100000 nodes, a recursive walk nests 100000 calls
// deep, which can overflow the default thread stack.
// A node is always finished before any of its children is popped, so
// root[q] and dep[q] are ready when a child reads them.
void dfs(int u, int fa) {
    vector<pair<int, int> > pending;  // (node, parent) pairs still to process
    pending.push_back(make_pair(u, fa));
    while (!pending.empty()) {
        const int x = pending.back().first;
        const int parent = pending.back().second;
        pending.pop_back();

        root[x] = update(root[parent], 1, n, num[x]);
        p[x][0] = parent;
        dep[x] = dep[parent] + 1;

        const vector<int> &adj = G[x];
        for (size_t i = 0; i < adj.size(); i++) {
            const int y = adj[i];
            if (y != parent) {
                pending.push_back(make_pair(y, x));
            }
        }
    }
}
// Returns the lowest common ancestor of u and v, using the binary-lifting
// table p[][] (p[x][i] is the 2^i-th ancestor of x) and the depths dep[].
int lca(int u, int v) {
    // Make v the deeper of the two nodes.
    if (dep[v] < dep[u]) {
        swap(u, v);
    }
    // Lift v up to u's depth, taking a 2^bit jump for every set bit of the
    // remaining depth difference.
    for (int bit = 0; bit < 20; ++bit) {
        const int gap = dep[v] - dep[u];
        if ((gap >> bit) & 1) {
            v = p[v][bit];
        }
    }
    if (u == v) {
        return u;  // u was an ancestor of v
    }
    // Jump both nodes up as long as their ancestors still differ. Afterwards
    // both nodes are the children of the LCA on their respective branches.
    for (int bit = 19; bit >= 0; --bit) {
        const int upU = p[u][bit];
        const int upV = p[v][bit];
        if (upU != upV) {
            u = upU;
            v = upV;
        }
    }
    return p[u][0];
}
// Walks four tree versions in lockstep over [l, r]. It returns the leaf index
// holding the k-th smallest element of the combined count u + v - top - topfa.
// main() calls it with the versions of the two path endpoints (u, v), of their
// LCA (top), and of the LCA's parent (topfa). The result is a rank into in[].
int query(Node *topfa, Node *top, Node *u, Node *v, int l, int r, int k) {
    while (l != r) {
        const int mid = (l + r) >> 1;
        // Number of path elements whose rank falls in the left half [l, mid].
        const int leftCount = (u->lc->val + v->lc->val) - top->lc->val - topfa->lc->val;
        if (k <= leftCount) {
            topfa = topfa->lc;
            top = top->lc;
            u = u->lc;
            v = v->lc;
            r = mid;
        } else {
            k -= leftCount;  // skip everything in the left half
            topfa = topfa->rc;
            top = top->rc;
            u = u->rc;
            v = v->rc;
            l = mid + 1;
        }
    }
    return l;
}
// One-time setup, run after the input has been read:
//  1. creates the shared empty sentinel `null` and points root[0..n] at it;
//  2. coordinate-compresses the weights. After this step, in[1..distinct]
//     holds the sorted distinct weights and num[i] holds the 1-based rank of
//     node i's weight, so that in[num[i]] is the original weight;
//  3. builds root[], p[][0] and dep[] by traversing from node 1;
//  4. fills the remaining levels of the binary-lifting table p[][].
void init() {
    null = ++tail;
    null->val = 0;
    null->lc = null->rc = null;
    for (int i = 0; i <= n; i++) {
        root[i] = null;
    }
    sort(in + 1, in + n + 1);
    // std::unique only compacts the distinct values to the front. The slots
    // after the returned end hold unspecified values and need not be sorted,
    // so the binary search must be limited to in[1..distinct]. Searching all
    // of in[1..n] breaks lower_bound's sorted-range precondition.
    const int distinct = static_cast<int>(unique(in + 1, in + n + 1) - (in + 1));
    for (int i = 1; i <= n; i++) {
        num[i] = static_cast<int>(lower_bound(in + 1, in + distinct + 1, num[i]) - in);
    }
    dfs(1, 0);
    // p[j][i] = the 2^(i-1)-th ancestor of the 2^(i-1)-th ancestor of j.
    for (int i = 1; i < 20; i++) {
        for (int j = 1; j <= n; j++) {
            p[j][i] = p[p[j][i - 1]][i - 1];
        }
    }
}
// Reads N, M, the node weights and the N-1 edges, then builds all structures
// with init(). Each "u v k" query is answered with the k-th smallest weight on
// the u-v path, obtained by differencing the tree versions as
// root[u] + root[v] - root[lca] - root[parent(lca)].
int main() {
    scanf("%d%d", &n, &m);
    for (int node = 1; node <= n; ++node) {
        scanf("%d", in + node);
        num[node] = in[node];  // raw weight for now; init() turns it into a rank
        G[node].clear();
    }
    int u, v, k;
    for (int edge = 1; edge < n; ++edge) {
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    init();
    while (m--) {
        scanf("%d%d%d", &u, &v, &k);
        const int anc = lca(u, v);
        const int pos = query(root[p[anc][0]], root[anc], root[u], root[v], 1, n, k);
        printf("%d\n", in[pos]);
    }
    return 0;
}