【题解】BZOJ 3674 可持久化并查集加强版

传送门
题目是可持久化并查集加强版,其实并没有加强,只是原题可以用离线算法水过,而这道题才是用来练可持久化并查集的板子题。
首先对于学习可持久化并查集有一个先决条件,就是学会用可持久化线段树实现可持久化数组,如果不会的可以戳这。
接下来我们就来讲讲怎么用可持久化数组实现可持久化并查集。
讲解所需要的图其实在这里面已经贴出来了,我在这里就不重贴了。主要讲讲该如何实现可持久化并查集的各个操作。
普通并查集只有查找和合并两个操作,相似的,可持久化并查集也有这两个操作。

init

和可持久化数组类似,建树,记为第0版本。

void build(int &k, int l, int r) {  //数组模拟链表,因为线段树有多棵“缠在一起”,所以很难使用类似堆得表示方法存储
    k = ++cnt;  
    if (l == r) {
        v[k] = l; //v中存储的内容和普通并查集类似
        return;
    }  
    int mid = (l + r) >> 1;  
    build(lc[k], l, mid); //递归建树
    build(rc[k], mid + 1, r);  
}  

查找

先码上普通并查集的查找代码,再根据它进行修改

void find(int x) {
    if (v[x] == x) return x;
    else {
        int ret = find(v[x]);
        v[x] = ret;
        return ret;
    } 
}

本来写三目运算可以更加简洁,但是这里为了对比,就不加了。
上面的代码已经加入了路径压缩,我们先不考虑这个,先考虑查找它的最终的父亲。
考虑到线段树中叶子节点存储的就是该节点的直属父亲(先不考虑路径压缩),直接线段树单点查询即可,知道找的结点的父亲就是它本身时停止操作。
接下来考虑路径压缩。
像上面的普通并查集代码一样,对沿途所有结点进行修改,修改操作一会儿给出。
完整的查找代码:

int query(int k, int l, int r, int pos) {  //单点修改部分
    if (l == r) return v[k];  
    int mid = (l + r) >> 1;  
    if (pos <= mid) 
        return query(lc[k], l, mid, pos);  
    else return query(rc[k], mid + 1, r, pos);  
} 
int find(int &root, int x) {  
    int tmp = query(root, 1, n, x);   //找到直属父亲
    if (tmp == x) return x;  
    else {  
        int ret = find(root, tmp);  //找到最终的父亲
        insert(root, root, 1, n, x, ret);  //路径压缩
        return ret;  
    }  
}  

合并

函数就是上面的insert(),具体函数及其参数为void insert(int x, int &y, int l, int r, int pos, int val),意思是需要将x的内容搬到y中,并且将pos的值改成val,具体操作原理在这中已经给出了。具体操作就是每次对某个历史版本进行修改时,对于所有包含该位置的区间结点全部新开一个,并与其父节点连边,对于其他结点,由于不需要发生改动,所以直接连接即可。
完整合并代码:

void insert(int x, int &y, int l, int r, int pos, int val) {  
    y = ++cnt;  
    if (l == r) {
        v[y] = val;
        return;
    }  
    int mid = (l + r) >> 1;  
    lc[y] = lc[x]; rc[y] = rc[x];  //先复制左子树右子树再递归更新
    if (pos <= mid) insert(lc[x], lc[y], l, mid, pos, val);  
    else insert(rc[x], rc[y], mid + 1, r, pos, val);  
}  

完整代码

#include  
#include  
#include  
#include  
#include  
#include  
#define maxn 200005  
#define maxm 10000000  
using namespace std;  
int n, m, p, x, y, cnt, lastans, ans;  
int rt[maxn], v[maxm], lc[maxm], rc[maxm];  
inline int getint() {  
    int x = 0, f = 1;
    char ch = getchar();  
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }  
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }  
    return x * f;  
}  
void build(int &k, int l, int r) {  
    k = ++cnt;  
    if (l == r) {
        v[k] = l;
        return;
    }  
    int mid = (l + r) >> 1;  
    build(lc[k], l, mid);
    build(rc[k], mid + 1, r);  
}  
void insert(int x, int &y, int l, int r, int pos, int val) {  
    y = ++cnt;  
    if (l == r) {
        v[y] = val;
        return;
    }  
    int mid = (l + r) >> 1;  
    lc[y] = lc[x]; rc[y] = rc[x];  
    if (pos <= mid) insert(lc[x], lc[y], l, mid, pos, val);  
    else insert(rc[x], rc[y], mid + 1, r, pos, val);  
}  
int query(int k, int l, int r, int pos) {  
    if (l == r) return v[k];  
    int mid = (l + r) >> 1;  
    if (pos <= mid) 
        return query(lc[k], l, mid, pos);  
    else return query(rc[k], mid + 1, r, pos);  
}  
int find(int &root, int x) {  
    int tmp = query(root, 1, n, x);   
    if (tmp == x) return x;  
    else {  
        int ret = find(root, tmp);  
        insert(root, root, 1, n, x, ret);  
        return ret;  
    }  
}  
int main() {  
    n = getint(); m = getint();  
    build(rt[0],1,n);  
    for(int i = 1; i <= m; i++) {  
        int opt = getint();  
        if (opt == 1) {  
            int x = getint() ^ lastans;
            int y = getint() ^ lastans;
            int fx = find(rt[i - 1], x), fy = find(rt[i - 1], y);  
            if (fx == fy) rt[i] = rt[i - 1];  //如果已经在同个集合,直接版本复制
            else insert(rt[i - 1], rt[i], 1, n, fx, fy);  //合并
        }  
        else if (opt == 2) {  
            int x = getint() ^ lastans;  
            rt[i] = rt[x];  //直接复制版本
        }  
        else {  
            int x = getint() ^ lastans;
            int y = getint() ^ lastans;
            int fx = find(rt[i - 1], x), fy = find(rt[i - 1], y);  
            lastans = fx == fy ? 1 : 0;  
            printf("%d\n", lastans);
            rt[i] = rt[i - 1];  //版本复制
        }  
    }  
    return 0;  
}  

你可能感兴趣的:(线段树,可持久化数据结构,BZOJ)