首页 > 试题广场 >

两个有序数组间相加和的Topk问题

[编程题]两个有序数组间相加和的Topk问题
  • 热度指数:3302 时间限制:C/C++ 2秒,其他语言4秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
给定两个有序数组arr1和arr2,再给定一个整数k,返回来自arr1和arr2的两个数相加和最大的前k个,两个数必须分别来自两个数组
按照降序输出
[要求]
时间复杂度为

输入描述:
第一行三个整数N, K分别表示数组arr1, arr2的大小,以及需要询问的数
接下来一行N个整数,表示arr1内的元素
再接下来一行N个整数,表示arr2内的元素


输出描述:
输出K个整数表示答案
示例1

输入

5 4
1 2 3 4 5
3 5 7 9 11

输出

16 15 14 14

备注:
保证
要先排序一下。。。。。。
发表于 2019-12-25 17:19:59 回复(1)
import java.util.*;
public class Main {
    //放入大根堆中的结构
    static class Node {
        public int index1;  //arr1中的位置
        public int index2;  //arr2中的位置
        public int sum;     //arr1[index1]+arr2[index2]
        public Node(int i1, int i2, int s) {
            index1 = i1;
            index2 = i2;
            sum = s;
        }
    }

    public static int[] topKSum(Integer[] arr1, Integer[] arr2, int topK) {
        if (arr1 == null || arr2 == null || topK < 1) {
            return null;
        }
        topK = Math.min(topK, arr1.length * arr2.length);
        int[] res = new int[topK];
        int resIndex = 0;
        //自定义比较器,实现大根堆
        PriorityQueue<Node> maxHeap = new PriorityQueue<>((N1, N2) -> N2.sum - N1.sum);
        // set[i][j] == false , arr1[i] arr2[j] 之前没进过堆
        // set[i][j] == true , arr1[i] arr2[j] 之前进过堆
        //boolean[][] set = new boolean[arr1.length][arr2.length];
        //使用hashset解决超内存问题
        HashSet<String> positionSet = new HashSet<>();
        //从右下角开始
        int i1 = arr1.length - 1;
        int i2 = arr2.length - 1;
        maxHeap.add(new Node(i1, i2, arr1[i1] + arr2[i2]));
        //set[i1][i2] = true;
        positionSet.add(i1 + "_" + i2);
        while (resIndex != topK) {
            Node curNode = maxHeap.poll();
            res[resIndex++] = curNode.sum;
            i1 = curNode.index1;
            i2 = curNode.index2;
//            if (i1 - 1 >= 0 && set[i1 - 1][i2] == false) {
//                set[i1 - 1][i2] = true;
//                maxHeap.add(new Node(i1 - 1, i2, arr1[i1 - 1] + arr2[i2]));
//            }
//            if (i2 - 1 >= 0 && set[i1][i2 - 1] == false) {
//                set[i1][i2 - 1] = true;
//                maxHeap.add(new Node(i1, i2 - 1, arr1[i1] + arr2[i2 - 1]));
//            }
            if (i1 - 1 >= 0 && !positionSet.contains(i1 - 1 + "_" + i2)) {
                positionSet.add(i1 - 1 + "_" + i2);
                maxHeap.add(new Node(i1 - 1, i2, arr1[i1 - 1] + arr2[i2]));
            }
            if (i2 - 1 >= 0 && !positionSet.contains(i1 + "_" + (i2 - 1))) {
                positionSet.add(i1 + "_" + (i2 - 1));
                maxHeap.add(new Node(i1, i2 - 1, arr1[i1] + arr2[i2 - 1]));
            }
        }
        return res;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int k = in.nextInt();
        Integer[] arr1 = new Integer[n];
        Integer[] arr2 = new Integer[n];
        for (int i = 0; i < n; i++) {
            arr1[i] = in.nextInt();
        }
        for (int i = 0; i < n; i++) {
            arr2[i] = in.nextInt();
        }
        //要将输入的两个数字排序
        Arrays.sort(arr1);
        Arrays.sort(arr2);
        int[] res = topKSum(arr1, arr2, k);
        for (int re : res) {
            System.out.print(re + " ");
        }
    }
}

编辑于 2020-08-23 00:19:13 回复(0)
补一个python3版本的。两个坑,需要排序,而且第一行只有两个数,因为数组长度相同。
基本思路就是广度优先搜索+优先队列。所有的和可以看成一个N*N的矩阵,右下角的数是a1[N-1]+a2[N-1],是最大的,先把他加入优先队列q,要以元组形式加入(val, N-1,N-1),后面两个是它的坐标
然后开始循环,结束条件是得到了前k个最大的值。
循环中,
1.先从q中取得最大值(val,i,j),加入结果集中;
2.然后把这个最大值的相邻两个元素(i-1,j) 和(i,j-1)也加入队列q,前提是这两个没有重复加入,用一个哈希表就可以判断;添加元素的复杂度是log(k)。         原理就是,q中剩余的元素有可能是下一个要取的最大值,同样当前的最大值(val,i,j)的两个相邻元素(i-1,j) (i,j-1)也有可能是下一个最大值,至于更远的,如(i-1,j-1) (i-2,j) (i,j-2),它们比(i-1,j) (i,j-1)要小,不可能成为下一个最大值。
当得到k个值后就跳出循环,输出。
import sys
import heapq as hq
f=sys.stdin
# f=open('a.txt','r')
mnk=f.readline().strip().split()
m,k=[int(x) for x in mnk]
# 题干说第一行有三个数,但实际只有两个,数组长度相同
n=m
# 取负值,因为用的最小堆,最后反过就行了
a1=f.readline().strip().split()
a1=[-int(x) for x in a1]
a2=f.readline().strip().split()
a2=[-int(x) for x in a2]

# 要排序,真坑
a1=sorted(a1,reverse=True)
a2=sorted(a2,reverse=True)

visited=set()
visited.add((m-1,n-1))
que=[(a1[m-1]+a2[n-1], m-1,n-1)]#分别是值,ij坐标
res=[]
while len(res)<k:
    # 先取当前老大
    val,i,j=hq.heappop(que)
    res.append(val)
    # 加进新的候选人
    if i-1>=0 and (i-1,j) not in visited:
        hq.heappush(que, ( a1[i-1]+a2[j] ,i-1,j))
        visited.add( (i-1,j) )
    if j-1>=0 and (i,j-1) not in visited:
        hq.heappush(que, ( a1[i]+a2[j-1] ,i,j-1))
        visited.add( (i,j-1) )
# 取反输出
for i in range(len(res)):
    print(-res[i], end='')
    if i!=len(res):
        print(' ',end='')


发表于 2020-08-06 13:11:27 回复(0)
本地是对的,不知道这个为啥就不对,那位大佬指导一下
#include<bits/stdc++.h>
using namespace std;
int main()
{
    int n,k;
    cin>>n>>k;
    vector<long>arr1(n),arr2(n);
    for(int i=0;i<n;i++)
        cin>>arr1[i];
    for(int i=0;i<n;i++)
        cin>>arr2[i];
    int x=n-1,y=n-1;
    if(arr1[x]>=arr2[y])
    {
        for(int i=0;i<k;i++)
        {
            
            if(x>0&&arr1[x]+arr2[y]<arr1[x-1]+arr2[n-1])
            {
                cout<<arr1[x-1]+arr2[n-1]<<" ";
                x--;
                y=n-1;
            }
            else
            {
                cout<<arr1[x]+arr2[y]<<" ";
                y--;
                if(y<0)
                {
                    y=n-1;
                    x--;
                }
            }
        } 
    }
    else
    {
        for(int i=0;i<k;i++)
        {
            
            if(y>0&&arr1[x]+arr2[y]<arr1[n-1]+arr2[y-1])
            {
                cout<<arr1[n-1]+arr2[y-1]<<" ";
                x=n-1;
                y--;
            }
            else
            {
                cout<<arr1[x]+arr2[y]<<" ";
                x--;
                if(x<0)
                {
                    x=n-1;
                    y--;
                }
            }
        } 
    }
    return 0;
}

发表于 2019-09-05 16:48:15 回复(0)
package class18;
// 牛客的测试链接:
// https://www.nowcoder.com/practice/7201cacf73e7495aa5f88b223bbbf6d1
// 不要提交包信息,把import底下的类名改成Main,提交下面的代码可以直接通过
// 因为测试平台会卡空间,所以把set换成了动态加和减的结构
// 请同学们务必参考如下代码中关于输入、输出的处理
// 这是输入输出处理效率很高的写法
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.StreamTokenizer;
import java.util.Comparator;
import java.util.HashSet;
import java.util.PriorityQueue;
public class Code04_TopKSumCrossTwoArrays {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StreamTokenizer in = new StreamTokenizer(br);
        PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));
        while (in.nextToken() != StreamTokenizer.TT_EOF) {
            int N = (int) in.nval;
            in.nextToken();
            int K = (int) in.nval;
            int[] arr1 = new int[N];
            int[] arr2 = new int[N];
            for (int i = 0; i < N; i++) {
                in.nextToken();
                arr1[i] = (int) in.nval;
            }
            for (int i = 0; i < N; i++) {
                in.nextToken();
                arr2[i] = (int) in.nval;
            }
            int[] topK = topKSum(arr1, arr2, K);
            for (int i = 0; i < K; i++) {
                out.print(topK[i] + " ");
            }
            out.println();
            out.flush();
        }
    }
    // 放入大根堆中的结构
    public static class Node {
        public int index1;// arr1中的位置
        public int index2;// arr2中的位置
        public int sum;// arr1[index1] + arr2[index2]的值
        public Node(int i1, int i2, int s) {
            index1 = i1;
            index2 = i2;
            sum = s;
        }
    }
    // 生成大根堆的比较器
    public static class MaxHeapComp implements Comparator<Node> {
        @Override
        public int compare(Node o1, Node o2) {
            return o2.sum - o1.sum;
        }
    }
    public static int[] topKSum(int[] arr1, int[] arr2, int topK) {
        if (arr1 == null || arr2 == null || topK < 1) {
            return null;
        }
        int N = arr1.length;
        int M = arr2.length;
        topK = Math.min(topK, N * M);
        int[] res = new int[topK];
        int resIndex = 0;
        PriorityQueue<Node> maxHeap = new PriorityQueue<>(new MaxHeapComp());
        HashSet<String> set = new HashSet<>();
        int i1 = N - 1;
        int i2 = M - 1;
        maxHeap.add(new Node(i1, i2, arr1[i1] + arr2[i2]));
        set.add(i1 + "_" + i2);
        while (resIndex != topK) {
            Node curNode = maxHeap.poll();
            res[resIndex++] = curNode.sum;
            i1 = curNode.index1;
            i2 = curNode.index2;
            set.remove(i1 + "_" + i2);
            if (i1 - 1 >= 0 && !set.contains((i1 - 1) + "_" + i2)) {
                set.add((i1 - 1) + "_" + i2);
                maxHeap.add(new Node(i1 - 1, i2, arr1[i1 - 1] + arr2[i2]));
            }
            if (i2 - 1 >= 0 && !set.contains(i1 + "_" + (i2 - 1))) {
                set.add(i1 + "_" + (i2 - 1));
                maxHeap.add(new Node(i1, i2 - 1, arr1[i1] + arr2[i2 - 1]));
            }
        }
        return res;
    }
}
发表于 2023-04-17 09:42:30 回复(0)
java的
import java.util.*;

public class Main {

    static class Node {
        int x;
        int y;

        public Node (int x, int y) {
            this.x = x;
            this.y = y;
        }

    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int K = sc.nextInt();
        Integer[] arr1 = new Integer[N];
        Integer[] arr2 = new Integer[N];
        for (int i = 0; i < N; i++) {
            arr1[i] = sc.nextInt();
        }
        for (int i = 0; i < N; i++) {
            arr2[i] = sc.nextInt();
        }
        
        Arrays.sort(arr1, new Comparator<Integer>() {
            @Override
            public int compare(Integer o1, Integer o2) {
                return o2 - o1;
            }
        });
        
        Arrays.sort(arr2, new Comparator<Integer>() {
            @Override
            public int compare(Integer o1, Integer o2) {
                return o2 - o1;
            }
        });
        
        List<Integer> res = new ArrayList<>();

        PriorityQueue<Node> queue = new PriorityQueue<>(new Comparator<Node>() {
            @Override
            public int compare(Node o1, Node o2) {
                return arr1[o2.x] + arr2[o2.y] - arr1[o1.x] - arr2[o1.y];
            }
        });

        int[] visitX = new int[N];
        int[] visitY = new int[N];
        for (int i = 0; i < N; i++) {
            visitX[i] = -1;
            visitY[i] = -1;
        }

        queue.add(new Node(0, 0));
        while (!queue.isEmpty()) {
            Node curr = queue.remove(); 
            visitX[curr.x] = curr.y;
            visitY[curr.y] = curr.x;
            res.add(arr1[curr.x] + arr2[curr.y]);
            if (K == res.size()) {
                break;
            }
            if (curr.x + 1 < N) {
                if (curr.y == 0 || visitX[curr.x + 1] == curr.y - 1) {
                    queue.add(new Node(curr.x + 1, curr.y));
                }
            }
            if (curr.y + 1 < N) {
                if (curr.x == 0 || visitY[curr.y + 1] == curr.x - 1) {
                    queue.add(new Node(curr.x, curr.y + 1));
                }
            }
        }

        if (res.size() > 0) {
            for (int i = 0; i < res.size() - 1; i++) {
            System.out.print(res.get(i) + " ");
        }
            System.out.print(res.get(res.size() - 1));
        }
    }

}

发表于 2020-08-21 22:41:58 回复(0)
#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>
using namespace std;
struct node {
    int rows/*行*/, cols/*列*/;
    int val;
    node(int row, int col, int v) : rows(row), cols(col), val(v){}
    bool operator<(const node& rhs) const {
         return val < rhs.val;
    }
};
void Solution(vector& a, vector& b, int n, int k) {
    vector ret(k, 0);
    int count = 0;
    priority_queue max_heap; 
    vector visited(n*n, false);
    visited[n*n-1] = true;
    max_heap.push(node(n-1, n-1, a[n-1] + b[n-1]));
    while(count!=k) {
        auto Node = max_heap.top();
        max_heap.pop();
        int row = Node.rows;
        int col = Node.cols;
        ret[count++] = Node.val;
        if (row-1 >= 0 && !visited[(row-1) * n + col]) {
            visited[(row-1) * n + col] = true;
            max_heap.push(node(row-1, col, a[row-1] + b[col]));
        }
        if (col-1 >= 0 && !visited[row * n + col - 1]) {
            visited[row * n + col - 1] = true;
            max_heap.push(node(row, col-1, a[row] + b[col-1]));
        }
    }
    for(int i = 0; i<k; i++)  {
        cout << ret[i];
        if (i!=k-1)
            cout << " ";
    }
}
void Input(vector& a, vector& b, int N) {
    for(int i = 0; i> a[i];
    for(int i = 0; i> b[i];
}
int main() {
    int N, K;
    while (cin >> N >> K) {
        vector a(N, 0);
        vector b(N, 0);
        Input(a, b, N);
        sort(a.begin(), a.end());
        sort(b.begin(), b.end());
        Solution(a, b, N, K);
    }
}


发表于 2020-08-18 22:05:51 回复(0)
要先排序 我滴个天
#include<iostream>
#include<string>
#include<vector>
#include<queue>
#include<set>
#include<algorithm>

using namespace std;

class Node{
  public:
    int lindex;
    int rindex;
    int sum;
  public:
    Node(int l,int r,int s){
        lindex = l;
        rindex = r;
        sum =s;
    }
    //必须重载
    friend bool operator< (Node l,Node r){
        return l.sum < r.sum;//从小到大排序
    }
    
};

class Solution {
    public:
      vector<int> getMaxTopK(vector<int> arr1,vector<int> arr2,int K){
          if(arr1.size() == 0 || arr2.size()==0 || arr1.size()*arr2.size() < K)
              return {};
          vector<int> res;
          priority_queue<Node> que;
          set<string> s;
          //先放入最大的
          que.push(Node(arr1.size()-1,arr2.size()-1,arr1[arr1.size()-1] + arr2[arr2.size()-1]));
          s.insert(to_string(arr1.size()-1) + to_string(arr2.size()-1));
          int index = 0;
          //int len = arr1.size() * arr2.size();
          //K = min(K, len);//len1*len2表示最多能够组成的结果的个数
          while(index < K && !que.empty()){
              Node temp = que.top();
              que.pop();
              res.push_back(temp.sum);
              //判断  找不到
              if(s.count(to_string(temp.lindex-1) + to_string(temp.rindex)) == 0){
                  que.push(Node(temp.lindex-1,temp.rindex,arr1[temp.lindex-1]+arr2[temp.rindex]));
                  s.insert(to_string(temp.lindex-1) + to_string(temp.rindex));
              }
              if(s.count(to_string(temp.lindex) + to_string(temp.rindex-1)) == 0){
                  que.push(Node(temp.lindex,temp.rindex-1,arr1[temp.lindex] + arr2[temp.rindex-1]));
                  s.insert(to_string(temp.lindex) + to_string(temp.rindex-1));
              }
              index ++;
              
          }
          return res;
          
      }
      
    
};


int main(){
    int N,N2,K;
    cin>>N>>K;
    vector<int> arr1(N);
    vector<int> arr2(N);
    vector<int> res;
   // cout<<"dsds"<<endl;
    for(int i=0;i<N;i++){
        cin>>arr1[i];
    }
    //cout<<"dsdsdsd"<<endl;
    for(int i=0;i<N;i++){
        cin>>arr2[i];
    }
    sort(arr1.begin(),arr1.end());
    sort(arr2.begin(),arr2.end());
    //cout<<"dsdsdsdsdsd"<<endl;
    //cout<<arr1[0]<<arr2[0]<<K<<endl;
    Solution c1;
    res = c1.getMaxTopK(arr1,arr2,K);
    for(int i=0;i<res.size();i++){
        cout<<res[i]<<" ";
    }
    cout<<endl;
    return 0;
}

发表于 2020-05-25 23:43:18 回复(0)
这题目也是醉了,面试出题的时候没有最下方的N和K的取值范围,从题目描述看N妥妥是数组长度,哪来的超过长度一说。。
另外***也不是有序数组,我真是服了
发表于 2020-04-15 22:28:09 回复(0)
有序数组????
发表于 2019-08-15 15:32:29 回复(0)