首页 > 试题广场 >

集合合并

[编程题]集合合并
  • 热度指数:2448 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 128M,其他语言256M
  • 算法知识视频讲解
给定若干个32位int数字集合,每个集合中的数字无重复,譬如:
  {1,2,3}  {2,5,6}  {8}
将其中交集不为空的集合合并,保证合并完成后所有集合之间无交集,输出合并后的集合个数以及最大集合中元素的个数。

输入描述:
输入格式:
1. 第一行为一个数字N,表示集合数。
2. 接下来N行,每行一个非空集合,包含若干个数字,数字之间用空格分开。
假设第i个集合的大小为Si,数据满足N<=100000,ΣSi<=500000


输出描述:
输出格式:
1. 第一行为合并后的集合个数。
2. 第二个为最大集合中元素的个数。
示例1

输入

3
1 2 3
2 5 6
8

输出

2
5
Python优化后的并查集
N = int(input())
sets = [list(map(int, input().split())) for _ in range(N)]
# parent用来记每个数的父节点
# nums用来记每个根节点的树中有几个节点
parent, nums = {}, {}

for _set in sets:
    # 对于每个集合,默认取第一个元素为根节点来构造树
    p = _set[0]
    # 根节点为p的树默认有0个元素
    nums.setdefault(p, 0)

    for s in _set:
        # 如果当前数没有出现在parent中,说明是个新的数,把它添加进来并以p作为它的父节点
        if s not in parent:
            parent[s] = p
            nums[p] += 1
        # 如果当前数已经出现过,它必然有一个父节点,且可以根据父节点向上找到根节点,
        # 于是把一路经过的所有节点都指向p,相当于一个层级压缩,让p成为新的根节点
        # 同时把原先的根节点中的计数赋给p
        else:
            while parent[s] != p:
                temp = parent[s]
                parent[s] = p
                s = temp
            parent[s] = p
            if s != p and s in nums:
                nums[p] += nums[s]
                del nums[s]

print(len(nums.keys()))
print(max(nums.values()))




编辑于 2019-09-06 09:33:41 回复(0)
/*
并查集算法
*/
#include <bits/stdc++.h>
using namespace std;

class DisjointSet
{
private:
    std::unordered_map<int, int> parent;
    std::unordered_map<int, int> rank; // 秩

public:
    void add(int x)
    {
        if(parent.find(x) == parent.end()) {
            parent[x] = x;
            rank[x] = 0;
        }
    }
    int find(int x)
    {
        // 查找根节点,并包含路径压缩,提高运行效率
        return x == parent[x] ? x : (parent[x] = find(parent[x]));
    }
    void to_union(int x1, int x2)
    {
        int f1 = find(x1);
        int f2 = find(x2);
        if (f1 == f2) return;
        // 按秩合并,find-union操作最坏的运行时间提高至O(log n)
        if (rank[f1] > rank[f2])
            parent[f2] = f1;
        else {
            parent[f1] = f2;
            if (rank[f1] == rank[f2])
                ++rank[f2];
        }
    }
    void printRes()
    {
        int cnt = 0, len_max;
        map<int, int> set;
        for(auto it = parent.begin(); it != parent.end(); it++) {
            find(it->first);  // 将所有节点到根节点的距离压缩至1步
            if(set.find(it->second) == set.end()) set[it->second] = 0;
            set[it->second]++; // 统计合并后每个集合的大小
            if(it->first == it->second) cnt++; // 当根节点为本身时,集合数加一
        }
        for(auto p = set.begin(); p != set.end(); p++)
            len_max = max(len_max, p->second);
        cout << cnt << endl << len_max << endl;
//        for(auto it = parent.begin(); it != parent.end(); it++) {
//            cout << it->first << " " << it->second << endl;
//        }
    }
};


int main(void)
{
    DisjointSet ans = DisjointSet();
    int n, a, b;
    cin >> n;
    while(n--) {
        scanf("%d", &a);
        ans.add(a);
        while(getchar() != '\n') {
            scanf("%d", &b);
            ans.add(b);
            ans.to_union(a, b);
        }
    }
    ans.printRes();
    return 0;
}

发表于 2019-07-15 14:19:48 回复(5)
思路很简单,使用并查集来求解。当然,并查集常用的优化套路全都要用上,如:路径压缩、根据rank进行合并。但是在本题的数据量下,老老实实读完数据然后构建并查集是过不了的!必须一边读数据一边构建并查集,同时还要一边做集合的合并,才能勉强AC!
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.HashMap;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String line;
        while((line = br.readLine()) != null){
            int n = Integer.parseInt(line);
            UnionFind uf = new UnionFind();
            // 边读取数据边对并查集中的元素进行合并
            for(int i = 0; i < n; i++){
                String[] str = br.readLine().split(" ");
                int[] set = new int[str.length];
                for(int j = 0; j < set.length; j++){
                    set[j] = Integer.parseInt(str[j]);
                    uf.add(set[j]);
                    if(j > 0) {
                        uf.union(set[j], set[j - 1]);     // 合并集合
                    }
                }
            }
            // 获得连通分量数
            System.out.println(uf.getCount());
            // 获得最大连通分量大小
            System.out.println(uf.getMaxSize());
        }
    }
}

class UnionFind {
    int count;        // 连通分量数
    int maxSize;      // 最大集合的规模
    private HashMap<Integer, Integer> parent;       // 节点->根
    private HashMap<Integer, Integer> rank;         // 节点->树的大小
    public UnionFind(){
        parent = new HashMap<Integer, Integer>();
        rank = new HashMap<Integer, Integer>();
        maxSize = 1;
    }
    
    public void add(int num) {
        if(!parent.containsKey(num)){
            parent.put(num, num);
            rank.put(num, 1);
            count++;
        }
    }
    
    private int find(int x){
        int root = x;
        while(parent.get(root) != root){
            root = parent.get(root);
        }
        // 把沿途节点的根都修改为father
        while(x != root){
            int temp = parent.get(x);
            parent.put(x, root);
            x = temp;
        }
        return root;
    }

    public void union(int x, int y){
        int rootX = find(x);
        int rootY = find(y);
        if(rootX == rootY){
            return;
        }
        int rankX = rank.get(rootX);
        int rankY = rank.get(rootY);
        if(rankX < rankY){
            parent.put(rootX, rootY);      // y的树高就将x合并到y
            rank.put(rootY, rankX + rankY);
            maxSize = Math.max(maxSize, rank.get(rootY));
        }else{
            // 高度相等时随便合并,但是树的最大高度会增加
            parent.put(rootY, rootX);
            rank.put(rootX, rankX + rankY);
            maxSize = Math.max(maxSize, rank.get(rootX));
        }
        count--;
    }
    
    public int getCount(){
        return count;
    }
    
    public int getMaxSize(){
        return maxSize;
    }
}

发表于 2022-01-06 15:24:52 回复(0)
#include <iostream>
#include <cstring>
#include <algorithm>
#include <unordered_map>

using namespace std;

unordered_map<int, int> p;
unordered_map<int, int> cnt;

int find(int x) {
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

void merge(int pa, int pb) {
    if(cnt[pa] >= cnt[pb]) {
        p[pb] = p[pa];
        cnt[pa] += cnt[pb];
        cnt[pb] = 0;
    } else {
        p[pa] = p[pb];
        cnt[pb] += cnt[pa];
        cnt[pa] = 0;
    }
}

int main() {
    int n;
    cin >> n;
    for(int i=0; i<n; i++) {
       int a, b;
       cin >> a;
       if(!p.count(a)) {
           p[a] = a;
           cnt[a] = 1;
       } 
       while(getchar() != '\n') {
           cin >> b;
           if(!p.count(b)) {
               p[b] = find(a);
               cnt[find(a)] ++ ;
           } else {
               int pb = find(b), pa = find(a);
               if(pb != pa) {
                   merge(pb, pa);
               }
           }
       }
    }
    
    int res1 = 0, res2 = 0;
    for(auto v : cnt) {
        if(v.second != 0) {
            res1 ++ ;
            res2 = max(res2, v.second);
        }
    }
    cout << res1 << endl << res2 << endl;
    return 0;
}

发表于 2021-10-29 16:05:29 回复(1)
自己做的,通过率73.3%;Python编写
思路:在添加数据的时候就进行合并操作。这个想法源自插入排序,我将整个集合群作为一个列表,方便对元素操作;每个元素是个集合,这样元素(集合)之间进行并交集运算很方便。假定左边部分为没有交集的集合群,每次新添加的集合放到列表尾部,遍历其之前的各个集合,如果有交集就标记索引,遍历结束后,将索引值对应的集合与新添加的集合取并集,并把原先标记的索引处的元素弹出(因为它已经成为大并集的一部分了,自身没有存在的价值了)
while True:
    try:
        n = int(input())
        a = []
        maxlth = 0
        for i in range(n):
            data = {int(x) for x in input().split()}
            a.append(data)
            maxlth = max(maxlth, len(data))
            j = len(a)
            b = []
            while j > 1:
                if len(a[-1] & a[j-2]) > 0:
                    b.append(j-2)
                j -= 1
            for k in b:
                a[-1] |= a[k]
                maxlth = max(maxlth, len(a[-1]))
                a.pop(k)
        print(len(a))
        print(maxlth)
    except:
        break
讨论区也有并查集算法,大家可以学习一下,我也在学习

发表于 2020-07-16 09:58:48 回复(0)
#include <bits/stdc++.h>
int main()
{
    int N;
    scanf("%d",&N);
    getchar();
        std::map<int,int> mp;
        std::vector<int> vi[N];        //序号
        std::vector<int> res[N],cpy[N];
        int c;
        for(int i=0;i<N;++i)
        {
            while(scanf("%d",&c)!=EOF)
            {
                if(mp.find(c)!=mp.end())
                {
                    vi[mp[c]].push_back(i);
                    vi[i].push_back(mp[c]);
                }
                mp[c]=i;
                res[i].push_back(c);
                char ch = getchar();
                if (ch == '\n')
                    break;
            }
        }
        int bark[N];
        int t;
        int maxv=0,numc=0;
        std::set<int> temp;
        memset(bark,0,sizeof(bark));
        for(int i=0;i<N;++i)
        {
            if(bark[i]==0)
            {
                std::queue<int> qt;
                qt.push(i);
                while(!qt.empty())
                {
                    t=qt.front();
                    if(bark[t]==0)
                    {
                        cpy[i].push_back(t);
                        bark[t]=1;
                        for(int j=0;j<vi[t].size();++j)
                            if(bark[vi[t][j]]==0 )
                                qt.push(vi[t][j]);
                    }
                    qt.pop();
                }
            }
        }
        for(int i=0;i<N;++i)
            if(!cpy[i].empty())
            {
                numc++;
                temp.clear();
                for(int j=0;j<cpy[i].size();++j)
                {
                    for(int k=0;k<res[cpy[i][j]].size();++k)
                        temp.insert(res[cpy[i][j]][k]);
                }
                maxv=std::max(maxv,int(temp.size()));
            }
         std::cout<<numc<<std::endl;
        std::cout<<maxv<<std::endl;
      
}

发表于 2020-04-27 20:17:51 回复(1)
抄了一手路径压缩过了
不然只能过90%+。。
问题是用C++只能过70+。。。
parent = {}
size = {}
def find(x):
    if x==parent[x]:
        return x
    else:
        parent[x]=find(parent[x])
        return  parent[x]
#
'''
def find(a):
    cur = a
    while(cur!=parent[cur]):
        cur = parent[cur]
    parent[a]=cur
    return cur
'''
#
def union(a,b):
    if a not in parent.keys():
        parent[a]=a
        size[a]=1
    if b not in parent.keys():
        parent[b]=b
        size[b] = 1
    if a==b:
        return
    fa = find(a)
    fb = find(b)
    if fa==fb:
        return
    if size[fa]>=size[fb]:
        size[fa]=size[fa]+size[fb]
        parent[fb]=fa
    else:
        size[fb]=size[fa]+size[fb]
        parent[fa]=fb
m = int(input())
for _ in range(m):
    cur = list(map(int,input().split()))
    union(cur[0],cur[0])
    for i in range(1,len(cur)):
        union(cur[0],cur[i])
a = set()
for i in parent.keys():
    a.add(find(i))
print(len(a))
print(max(size.values()))


发表于 2019-10-14 15:43:00 回复(0)

有 java ac 过的吗,我已经优化到我的极致了,还是卡 66.67% 超时,我太难了。

我的一些优化:

  1. 路径压缩
  2. 基于rank合并
  3. 动态计算最大值
  4. 在一行中合并,采用分治法,复杂度从O(n^2)降为O(nlogn)。
import java.util.*;
public class Main {
    private static Map<Integer, Integer> parent = new HashMap<>();
    private static Map<Integer, Integer> size = new HashMap<>();
    private static Map<Integer, Integer> rank = new HashMap<>();
    private static int max = 0;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = Integer.valueOf(sc.nextLine());
        for(int i = 0; i < n; i ++) {
            String[] strs = sc.nextLine().split("\\s+");
            int[] arr = new int[strs.length];
            for(int j = 0; j < strs.length; j ++) {
                Integer num = Integer.valueOf(strs[j]);
                arr[j] = num;
                parent.putIfAbsent(num, num);
                size.putIfAbsent(num, 1);
                rank.putIfAbsent(num, 1);
            }

            unionSet(arr, 0, arr.length - 1);
        }

        int count = 0;
        for(Integer k : parent.keySet()) {
            if(k.equals(parent.get(k))) count ++;
        }

        System.out.println(count);
        System.out.println(max);
    }

    public static int unionSet(int[] arr, int start, int end) {
        if(start >= end) return find(arr[start]);

        int mid = (start + end) >>> 1;
        int leftRoot = unionSet(arr, start, mid);
        int rightRoot = unionSet(arr, mid + 1, end);

        return union(leftRoot, rightRoot);
    }

    public static int find(int x) {
        if(x != parent.get(x)) {
            parent.put(x, find(parent.get(x)));
        }

        return parent.get(x);
    }

    public static int union(int p, int q) {
        int pRoot = find(p);
        int qRoot = find(q);
        if(pRoot == qRoot) return pRoot;

        if(rank.get(pRoot) < rank.get(qRoot)) {
            parent.put(pRoot, qRoot);
            size.put(qRoot, size.get(pRoot) + size.get(qRoot));
            max = Math.max(max, size.get(qRoot));
            return qRoot;
        } else {
            parent.put(qRoot, pRoot);
            size.put(pRoot, size.get(pRoot) + size.get(qRoot));
            max = Math.max(max, size.get(pRoot));
            if(rank.get(pRoot).equals(rank.get(qRoot))) {
                rank.put(pRoot, rank.get(pRoot) + 1);
            }
            return pRoot;
        }
    }
}
编辑于 2019-08-23 11:34:24 回复(1)
 import java.util.*;

public class Main {
    public static void main(String[] args){
        Scanner scanner=new Scanner(System.in);
        int n=scanner.nextInt();
        scanner.nextLine();
        Set<Integer>[] set=new HashSet[n];
        for(int i=0;i<n;i++){
            String s=scanner.nextLine();
            set[i]=new HashSet<Integer>();
            String[] str=s.split(" ");
            for(int j=0;j<str.length;j++)
                set[i].add(Integer.parseInt(str[j]));
        }
        int count=n;//记录集合各数
        int max=set[0].size();//记录最大集合里有多少个元素
        int sign;//每一次并集最后一个集合的位置。
        for(int i=0;i<n-1;i++) {
            sign=i;
            if(set[i].isEmpty())
                continue;
            for (int j = i+1; j < n; j++) {
                if (set[j].isEmpty())
                    continue;
                if (compare(set[i],set[j])) {
                    set[i].addAll(set[j]);
                    set[j].clear();
                    count--;
                    sign=j;
                    if (max < set[i].size()) {
                        max = set[i].size();
                    }
                }
            }
            if(sign!=i) {
                set[sign].addAll(set[i]);//换到后面可以把之前没法并集,现在可以并集的集合,并集。
            }
        }
        System.out.println(count);
        System.out.println(max);
    }
    public static boolean compare(Set<Integer> set1,Set<Integer> set2){
        Iterator<Integer> iterator=set2.iterator();
        while(iterator.hasNext()){
            if(set1.contains(iterator.next()))
                return true;
        }
        return false;
    }
}
有没有大佬可以抢救下,超时了。
编辑于 2019-07-29 23:43:57 回复(0)
import java.util.Scanner;
import java.util.HashMap;
import java.util.List;
import java.util.ArrayList;

class UF {
    int maxCount;    // 最大集合的元素个数
    int count;    // 集合个数
    HashMap<Integer, Integer> parent; // value 是 key 所在集合的根节点
    HashMap<Integer, Integer> counter; // value 是 key 所在集合的元素个数
    public UF() {
        count = 0;
        maxCount = 0;
        parent = new HashMap<Integer, Integer>();
        counter = new HashMap<Integer, Integer>();
    }
    public int count() {
        return count;
    }
    public int maxCount() {
        return maxCount;
    }
    public int find(int p) {
        int root = p;
        while (root != parent.get(root)) {
            root = parent.get(root);
        }
        // 路径压缩
        while (p != root) {
            int temp = parent.get(p);
            parent.put(p, root);
            p = temp;
        }
        return root;
    }
    
    public void union(int p, int q) {
        int i = find(p);
        int j = find(q);
        if (i == j) return;
        if (counter.get(i) < counter.get(j)) {
            parent.put(i, j);
            counter.put(j, counter.get(j) + counter.get(i));
            if (maxCount < counter.get(j)) maxCount = counter.get(j);
        } else {
            parent.put(j, i);
            counter.put(i, counter.get(i) + counter.get(j));
            if (maxCount < counter.get(i)) maxCount  = counter.get(i);
        }
        count--;
    }
}

public class Main {
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = Integer.valueOf(sc.nextLine());
        List<int[]> list = new ArrayList<>();
        UF uf = new UF();
        while (sc.hasNextLine()) {
            String[] strs = sc.nextLine().trim().split(" ");
            int[] nums = new int[strs.length];
            for (int i = 0; i < nums.length; i++) {
                nums[i] = Integer.valueOf(strs[i]);
                if (!uf.parent.containsKey(nums[i])) {
                    uf.parent.put(nums[i], nums[i]);
                    uf.counter.put(nums[i], 1);
                    uf.count++;
                }
            }
            int p = nums[0];
            for (int i = 1; i < nums.length; i++) {
                list.add(new int[] {p, nums[i]});
            }
        }
        for (int[] pair : list) {
            uf.union(pair[0], pair[1]);
        }
        System.out.println(uf.count);
        System.out.println(uf.maxCount);
    }
}

求助各位大佬们,我这个代码case通过率是66.67%,超时了,不知道怎么优化了
编辑于 2019-07-23 21:55:48 回复(0)
牛逼,666

发表于 2019-07-15 21:16:33 回复(0)