《算法竞赛·快冲300题》每日一题:“01树”

算法竞赛·快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。
所有题目放在自建的OJ New Online Judge。
用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。

文章目录

  • 题目描述
  • 题解
  • C++代码
  • Java代码
  • Python代码

01树” ,链接: http://oj.ecustacm.cn/problem.php?id=1715

题目描述

【题目描述】 现在给你一个n个节点的树,而且每个节点有一个权值为0或者1。
   现在有m次询问,每次询问输入两个节点x和y,以及一个权值k。
   请你判断x和y的路径中是否存在权值为k的点。(包括x和y本身)
【输入格式】 输入第一行为两个正整数n和m,均为不超过10^5次方的正整数。
   第二行是一个长度为n的01字符串,表示从节点1到节点n的权值。
   接下来n-1行,每行两个数字u和v,表示节点u和v之间存在边。
   接下来m行,每行输入三个数字x,y,k。其中x,y不相同,k为0或者1。 。
【输出格式】 对于每一次询问,如果x和y的路径中包含权值为k的点,输出Yes,否则输出No 。
【输入样例】

5 5
11010
1 2
2 3
2 4
1 5
1 4 1
1 4 0
1 3 0
1 3 1
5 5 1

【输出样例】

Yes
No
Yes
Yes
No

题解

   本题简单的做法是先建树,然后每次查询用DFS搜索路径。任意两点之间有且只有一条路径,做一次DFS能找到这条路径,计算量O(n)。一共做m次查询,总复杂度O(mn),超时。
   不过,本题特殊在于每个点的权值是0或1,查询也是查有没有等于0或1的点。查询一条路径时,如果能确定所有点都是1,或所有点都是0,或有0有1,那么就得到了答案。
   把所有点按0和1分成多个子集,其中一些连通的1是一个子集,一些连通的0是一个子集。最后把整棵树分成很多权值为1的子集、权值为0的子集。权值为0的子集和权值为1的子集相邻。
   对一个查询“x,y,k”:
   (1)如果{x,y}属于一个子集,它们必然连通,且权值相同,权值为0或1。
   (2)如果{x,y}不属于一个子集,它们要么是相邻的两个不同权值的子集,要么它们之间的路径穿过了一个不同权值的子集,两种情况下的路径上有1也有0。
   以上讨论的实际上是并查集的操作。下面用带路径压缩的并查集编码,一次查询约为O(1),m次查询的总复杂度约为O(m)。。
【笔记】

C++代码

  

#include
using namespace std;
char str[100010];
int s[100010];  //并查集
int find_set(int x){                    //查询并查集,返回x的根
    if(x != s[x]) s[x] = find_set(s[x]);     //路径压缩
    return s[x];
}
void merge_set(int x, int y){           //合并
    x = find_set(x);   y = find_set(y);
    if(x != y)    s[x] = s[y];          //把x合并到y上,y的根成为x的根
}
int main(){
    int n, m;
    scanf("%d %d",&n,&m);
    scanf("%s",str+1);
    for(int i = 1; i <= n; i++)  s[i] = i;    //并查集初始化
    for(int i = 1; i < n; i++){
        int u, v;    scanf("%d %d",&u,&v);
        if(str[u] == str[v]) merge_set(u,v);  //合并
    }
    for(int i = 1; i <= m; i++){
        int x, y;  char k;    scanf("%d %d %c",&x,&y,&k);
        if(find_set(x) == find_set(y) && str[x] != k) //属于同一个子集,且权值不等于k
            puts("No");   //比cout快
        else                           //其他情况,既有0也有1
            puts("Yes");  //比cout快
    }
    return 0;
}

Java代码

import java.util.Scanner;
public class Main {
    static char[] str = new char[100010];
    static int[] s = new int[100010];
    static int findSet(int x) {
        if (x != s[x])       s[x] = findSet(s[x]);
        return s[x];
    }
    static void mergeSet(int x, int y) {
        x = findSet(x);
        y = findSet(y);
        if (x != y)      s[x] = s[y];
    }
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        String strInput = sc.next();
        strInput.getChars(0, strInput.length(), str, 1);
        for (int i = 1; i <= n; i++)    s[i] = i;
        for (int i = 1; i < n; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            if (str[u] == str[v])   mergeSet(u, v);
        }
        for (int i = 1; i <= m; i++) {
            int x = sc.nextInt();
            int y = sc.nextInt();
            char k = sc.next().charAt(0);
            if (findSet(x) == findSet(y) && str[x] != k)   System.out.println("No");
            else     System.out.println("Yes");
        }
    }
}

Python代码

import sys
sys.setrecursionlimit(1000000)    #注意要扩栈
str = [0] * 100010
s = [0] * 100010
def find_set(x):
    if x != s[x]:    s[x] = find_set(s[x])
    return s[x]
def merge_set(x, y):
    x = find_set(x)
    y = find_set(y)
    if x != y:   s[x] = s[y]
n, m = map(int, input().split())
str[1:n+1] = input()
for i in range(1, n+1):  s[i] = i
for i in range(n-1):
    u, v = map(int, input().split())
    if str[u] == str[v]:  merge_set(u, v)
for i in range(m):
    x, y, k = input().split()
    x = int(x)
    y = int(y)
    if find_set(x) == find_set(y) and str[x] != k:   print("No")
    else:     print("Yes")

你可能感兴趣的:(算法竞赛快冲300题,并查集)