COT — Count on a Tree (SPOJ)

经典的函数式线段树（可持久化线段树）题目，存下代码备忘。


#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <queue>
#include <algorithm>
#include <vector>
#include <cstring>
#include <stack>
#include <cctype>
#include <utility>   
#include <map>
#include <string>  
#include <climits> 
#include <set>
#include <string>    
#include <sstream>
#include <utility>   
#include <ctime>
#include <bitset>

using std::priority_queue;
using std::vector;
using std::swap;
using std::stack;
using std::sort;
using std::max;
using std::min;
using std::pair;
using std::map;
using std::string;
using std::cin;
using std::cout;
using std::set;
using std::queue;
using std::string;
using std::stringstream;
using std::make_pair;
using std::getline;
using std::greater;
using std::endl;
using std::multimap;
using std::deque;
using std::unique;
using std::lower_bound;
using std::random_shuffle;
using std::bitset;
using std::upper_bound;
using std::multiset;

// Shorthand type names.
typedef long long               LL;
typedef unsigned long long      ULL;
typedef unsigned                UN;
typedef std::pair<int, int>     PAIR;
typedef std::multimap<int, int> MMAP;
typedef LL                      TY;
typedef long double             LF;

// Array capacities and numeric constants. Only MAXN, MAXE and MAXH are
// referenced elsewhere in this listing; the others are not used here.
const int    MAXN       = 100010;  // capacity of per-vertex arrays (vertices are 1-based)
const int    MAXM       = 55;
const int    MAXE       = 200010;  // capacity of edge[]: every undirected edge is stored twice
const int    MAXK       = 6;
const int    HSIZE      = 13131;
const int    SIGMA_SIZE = 26;
const int    MAXH       = 20;      // sparse-table levels; the Euler tour has < 2*MAXN steps
const int    INFI       = (INT_MAX - 1) >> 1;
const ULL    BASE       = 31;
const LL     LIM        = 1e13;
const int    INV        = -10000;
const LL     MOD        = 1000000007;
const double EPS        = 1e-7;
const LF     PI         = std::acos(-1.0);

// a = max(a, b), in place.
template<typename T> inline void checkmax(T &a, T b)
{
    if(b > a)
        a = b;
}

// a = min(a, b), in place.
template<typename T> inline void checkmin(T &a, T b)
{
    if(b < a)
        a = b;
}

// Absolute value for any T supporting comparison with 0 and unary minus.
template<typename T> inline T ABS(const T &a)
{
    if(a < 0)
        return -a;
    return a;
}

// Edge list threaded per tail vertex: first[u] is the index of the most
// recently added edge leaving u, edge[i].next chains to the one added before
// it, and -1 terminates a chain.
struct EDGE
{
    int v;      // head (target) vertex of the edge
    int next;   // next edge leaving the same tail vertex, or -1
} edge[MAXE];
int first[MAXN];
int rear;       // number of edges stored so far

// Empty the chains of vertices 0..n and forget every stored edge.
void init(int n)
{
    std::fill(first, first + n + 1, -1);
    rear = 0;
}

// Push the directed edge u -> v onto the front of u's chain.
void insert(int u, int v)
{
    const int id = rear++;
    edge[id].v = v;
    edge[id].next = first[u];
    first[u] = id;
}
// Lowest common ancestor by Euler tour + sparse table (range minimum on depth).
// The tree is read from the global adjacency list (first[] / edge[]) and is
// rooted at vertex 1. All tour arrays are 1-based.
struct LCA
{
    int E[MAXN*2];           // E[k]: vertex at step k of the Euler tour
    int dep[MAXN*2];         // dep[k]: depth of E[k] (root has depth 0)
    int pos[MAXN];           // pos[u]: first tour step at which u appears
    int back;                // number of tour steps written so far
    int table[MAXH][MAXN*2]; // table[i][j]: shallowest step within [j, j + 2^i - 1]

    // Record the Euler tour from vertex 1, then fill the sparse table level by level.
    void init()
    {
        back = 0;
        dfs(1, 0, -1);
        for(int j = 1; j <= back; ++j)
            table[0][j] = j;
        for(int lg = 1; (1 << lg) <= back; ++lg)
        {
            const int half = 1 << (lg - 1);
            for(int j = 1; j + 2*half - 1 <= back; ++j)
                table[lg][j] = comp(table[lg-1][j], table[lg-1][j+half]);
        }
    }

    // Visit u at depth d coming from parent f: log u, descend into every
    // neighbour other than f, and log u again after each child returns.
    void dfs(int u, int d, int f)
    {
        pos[u] = ++back;
        E[back] = u;
        dep[back] = d;
        for(int e = first[u]; e != -1; e = edge[e].next)
        {
            const int child = edge[e].v;
            if(child == f)
                continue;
            dfs(child, d + 1, u);
            ++back;
            E[back] = u;
            dep[back] = d;
        }
    }

    // Of two tour steps, return the shallower one (b on a tie).
    inline int comp(int a, int b)
    {
        if(dep[a] < dep[b])
            return a;
        return b;
    }

    // LCA of vertices a and b: the shallowest tour step between their first
    // occurrences, covering that range with two overlapping 2^lg windows.
    int query(int a, int b)
    {
        int lo = pos[a];
        int hi = pos[b];
        if(lo > hi)
            std::swap(lo, hi);
        const int len = hi - lo + 1;
        int lg = 0;
        while((2 << lg) <= len)   // ends at the largest lg with 2^lg <= len
            ++lg;
        return E[comp(table[lg][lo], table[lg][hi - (1 << lg) + 1])];
    }
} lca;

// Persistent ("functional") segment tree over compressed values 1..N, kept
// as a pool of nodes whose indices are handed out starting at 1 via ++back.
int root[MAXN];      // root[u]: version holding the values on the path 1..u; root[0]: empty version
int ls[MAXN*36];     // left child of a node
int rs[MAXN*36];     // right child of a node
int sum[MAXN*36];    // number of inserted values falling in the node's range
int back;            // most recently allocated node index

// Allocate an all-zero tree covering [l, r] and store its root index in rt.
// Leaves only get sum = 0; their ls/rs are never followed.
void build(int l, int r, int &rt)
{
    const int node = ++back;
    rt = node;
    sum[node] = 0;
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        build(l, mid, ls[node]);
        build(mid + 1, r, rs[node]);
    }
}

// Persistent insert of value num: copies the root-to-leaf path of version prt
// into fresh nodes (each count incremented by one) and shares every other
// subtree with prt. rt receives the root of the new version.
void change(int prt, int &rt, int l, int r, int num)
{
    const int node = ++back;
    rt = node;
    sum[node] = sum[prt] + 1;
    ls[node] = ls[prt];
    rs[node] = rs[prt];
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        if(num > mid)
            change(rs[prt], rs[node], mid + 1, r, num);
        else
            change(ls[prt], ls[node], l, mid, num);
    }
}

// fa[u]: parent of vertex u in the tree rooted at vertex 1 (0 for the root),
// written by dfs(). NOTE(review): nothing in this listing reads it.
int fa[MAXN];

int query(int frt, int lrt, int rrt, int l, int r, int rk, int fval)
{
    if(l == r) return l;
    int m = (l+r) >> 1;
    int sm = sum[ls[lrt]]+sum[ls[rrt]]-2*sum[ls[frt]];
    if(fval >= l && fval <= m) sm += 1;
    if(rk <= sm) return query(ls[frt], ls[lrt], ls[rrt], l, m, rk, fval);
    else return query(rs[frt], rs[lrt], rs[rrt], m+1, r, rk-sm, fval);
}

int arr[MAXN];   // arr[i]: value of vertex i; main replaces it with its rank in 1..N
int tab[MAXN];   // sorted distinct original values; tab[k-1] is the value of rank k
int N;           // number of distinct values

// Walk the tree from u (whose parent is f) over the global adjacency list:
// record the parent and build root[u] as root[f] with arr[u] inserted.
void dfs(int u, int f)
{
    fa[u] = f;
    change(root[f], root[u], 1, N, arr[u]);
    for(int e = first[u]; e != -1; e = edge[e].next)
    {
        const int child = edge[e].v;
        if(child != f)
            dfs(child, u);
    }
}

// Processes test cases until EOF. Each case reads n vertex values, n-1
// undirected edges, then m queries "u v k". Each query prints the k-th
// smallest vertex value on the path between u and v.
int main()
{
    int n, m;
    // Require both numbers. With "~scanf" a partial match that is not EOF
    // would keep the loop spinning on stale n/m.
    while(scanf("%d%d", &n, &m) == 2)
    {
        // Coordinate-compress vertex values to ranks 1..N; tab[k-1] maps rank k back.
        for(int i = 1; i <= n; ++i)
        {
            scanf("%d", arr+i);
            tab[i-1] = arr[i];
        }
        sort(tab, tab+n);
        N = unique(tab, tab+n)-tab;
        for(int i = 1; i <= n; ++i) arr[i] = lower_bound(tab, tab+N, arr[i])-tab+1;

        init(n);
        int u, v;
        for(int i = 1; i < n; ++i)
        {
            scanf("%d%d", &u, &v);
            insert(u, v);
            insert(v, u);
        }
        lca.init();

        // Reset the node pool, build the empty version, then one version per vertex.
        back = 0;
        build(1, N, root[0]);
        dfs(1, 0);

        int rk;
        for(int i = 0; i < m; ++i)
        {
            scanf("%d%d%d", &u, &v, &rk);
            // Compute the LCA once and use it for both its version and its own value.
            const int w = lca.query(u, v);
            const int kth = query(root[w], root[u], root[v], 1, N, rk, arr[w]);
            printf("%d\n", tab[kth-1]);
        }
    }
    return 0;
}


你可能感兴趣的：COT (SPOJ)