题解 | #两个有序数组间相加和的Topk问题#

// 牛客的测试链接: // https://www.nowcoder.com/practice/7201cacf73e7495aa5f88b223bbbf6d1 // 不要提交包信息,把import底下的类名改成Main,提交下面的代码可以直接通过 // 因为测试平台会卡空间,所以把set换成了动态加和减的结构

import java.util.Scanner; import java.util.Comparator; import java.util.HashSet; import java.util.PriorityQueue;

public class Main {

public static void main(String[] args) {
    Scanner scanner = new Scanner(System.in);
    int N = scanner.nextInt();
    int K = scanner.nextInt();
    int[] arr1 = new int[N];
    int[] arr2 = new int[N];
    for (int i = 0; i < N; i++) {
        arr1[i] = scanner.nextInt();
    }

    for (int i = 0; i < N; i++) {
        arr2[i] = scanner.nextInt();
    }
    int[] topK = topKSum(arr1,arr2,K);
    for (int i = 0; i < K; i++) {
        System.out.print(topK[i] + " ");
    }
    System.out.println();
    scanner.close();
}

public static class Node{
    public int index1;
    public int index2;
    public int sum;

    public Node(int index1,int index2,int sum){
        this.index1 = index1;
        this.index2 = index2;
        this.sum = sum;  //arr1[index1]+arr2[index2]
    }
}

//o2.sum - o1.sum : 降序 排序比较器
public static class NodeComparator implements Comparator<Node>{

    @Override
    public int compare(Node o1, Node o2) {
        return o2.sum - o1.sum;
    }
}

public static int[] topKSum(int[] arr1, int[] arr2, int topK) {
    int N = arr1.length;
    int M = arr2.length;
    topK = Math.min(topK,N * M);

    // res的长度可能是topK,也可能是N*M,向res填sum时,
    // 需要将进行下标换算( i1 * M + i2),确保每个sum都是 arr1[i1] + arr2[i2]得到的唯一sum(sum可能相同,但组成sum的来源不同)
    int[] res = new int[topK];
    int resIndex = 0;
    // 大根堆,最大元素放在堆顶
    PriorityQueue<Node> maxHeap = new PriorityQueue<Node>(new NodeComparator());
    HashSet<Long> set = new HashSet<>();
    int i1 = N -1;
    int i2 = M -1;
    set.add(calcuIndex(i1,i2,M));
    maxHeap.add(new Node(i1,i2,arr1[i1] + arr2[i2]));
    while (resIndex != topK){
        Node curNode = maxHeap.poll();
        res[resIndex++] = curNode.sum;
        i1 = curNode.index1;
        i2 = curNode.index2;
        set.remove(calcuIndex(i1,i2,M));
        // 切记,一定要保证set集合和maxHeap中的数据要同步,maxHeap poll过元素,set中也一定要移除对应元素,避免脏数据

        //  每个sum 都对应了一个resIndex,!set.contains(calcuIndex(i1-1,i2,M)可以判断当前得到的sum是否在之前遍历过
        if (i1 - 1 >= 0 && !set.contains(calcuIndex(i1-1,i2,M))){
            maxHeap.add(new Node(i1-1,i2,arr1[i1-1] + arr2[i2]));
            set.add(calcuIndex(i1-1,i2,M));
        }

        if (i2 - 1 >= 0 && !set.contains(calcuIndex(i1,i2-1,M))){
            maxHeap.add(new Node(i1,i2-1,arr1[i1] + arr2[i2-1]));
            set.add(calcuIndex(i1,i2-1,M));

        }

    }
    return res;

}

public static long calcuIndex(int i1,int i2,int M){
    return (long) i1 * (long)M + (long)i2;

}

}

全部评论

相关推荐

点赞 收藏 评论
分享
牛客网
牛客企业服务