首页 > 试题广场 >

在二叉树中找到两个节点的最近公共祖先(进阶)

[编程题]在二叉树中找到两个节点的最近公共祖先(进阶)
  • 热度指数:1071 时间限制:C/C++ 3秒,其他语言6秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
给定一棵二叉树,多次给出这棵树上的两个节点 o1 和 o2,请对于每次询问,找到 o1 和 o2 的最近公共祖先节点。

输入描述:
第一行输入两个整数 n 和 root,n 表示二叉树的总节点个数,root 表示二叉树的根节点。

以下 n 行每行三个整数 fa,lch,rch,表示 fa 的左儿子为 lch,右儿子为 rch。(如果 lch 为 0 则表示 fa 没有左儿子,rch同理)

第 n+2 行输入一个整数 m,表示询问的次数。

以下 m 行每行两个节点 o1 和 o2。


输出描述:
对于每组询问每行输出一个整数表示答案。
示例1

输入

8 1
1 2 3
2 4 5
4 0 0
5 0 0
3 6 7
6 0 0
7 8 0
8 0 0
4
4 5
5 2
6 8
5 8

输出

2
2
3
1

备注:

建立每个节点的高度,当查询o1 o2点时,先将两个点沿着father移动到同一高度。
然后一起移动到同一节点,即为最近公共祖先。


#include<iostream>
#include<vector>
using namespace std;
struct node
{
    int l;
    int r;
    int level;
    int parent=-1; 
};

void dfs(vector<node>&arr,int root,int level) //由root开始生成每棵树的level
{
    if(root==0)
        return;
    arr[root].level=level;
    dfs(arr,arr[root].l,level+1);
    dfs(arr,arr[root].r,level+1); 
}

int LCA(vector<node>&arr,int o1,int o2) //寻找o1 o2的最近祖先
{
    while(o1!=o2)
    {
        
        if(arr[o1].level>arr[o2].level)  //o1更低
        {
            o1=arr[o1].parent;
        }
        else if(arr[o1].level<arr[o2].level)
        {
            o2=arr[o2].parent;
        }
        else
        {
             o1=arr[o1].parent;
             o2=arr[o2].parent;  //两者一起向上  
        }
    }
    return o1;
}
int main()
{
    
    int n,root;
    cin>>n>>root;
    vector<node> arr(n+1);
    for(int i=0;i<n;i++)
    {
        int fa,lc,rc;
        cin>>fa>>lc>>rc;
        arr[fa].l=lc;
        arr[fa].r=rc;
        if(lc)
            arr[lc].parent=fa;
        if(rc)
            arr[rc].parent=fa;
    }
    dfs(arr,root,1);
    int m;
    cin>>m;
    while(m--)
    {
        int o1,o2;
        cin>>o1>>o2;
        cout<<LCA(arr,o1,o2)<<endl;   
    }
 
}


发表于 2021-01-21 19:20:09 回复(0)
书里面的代码,query部分有bug
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int rootVal = sc.nextInt();
        TreeNode root = buildTree(sc, n, rootVal);
        Record record = new Record(root);
        int m = sc.nextInt();
        StringBuilder cache = new StringBuilder();
        while (m-- > 0) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            cache.append(record.query(u, v)).append("\n");
        }
        System.out.println(cache);
    }


    private static TreeNode buildTree(Scanner sc, int n, int rootVal) {
        Map<Integer, TreeNode> map = new HashMap<>();
        while (n-- > 0) {
            int fa = sc.nextInt();
            int lch = sc.nextInt();
            int rch = sc.nextInt();
            TreeNode faNode = map.get(fa);
            if (faNode == null) {
                faNode = new TreeNode(fa);
                map.put(fa, faNode);
            }
            if (lch != 0) {
                TreeNode lchNode = map.get(lch);
                if (lchNode == null) {
                    lchNode = new TreeNode(lch);
                    map.put(lch, lchNode);
                }
                faNode.left = lchNode;
            }
            if (rch != 0) {
                TreeNode rchNode = map.get(rch);
                if (rchNode == null) {
                    rchNode = new TreeNode(rch);
                    map.put(rch, rchNode);
                }
                faNode.right = rchNode;
            }
        }
        return map.get(rootVal);
    }
}

class Record {
    private Map<Integer, HashMap<Integer, Integer>> map;  // a b c: a和b的父节点是c

    public Record(TreeNode head) {
        map = new HashMap<>();
        initMap(head);
        setMap(head);
    }

    private void initMap(TreeNode head) {
        if (head == null) {
            return;
        }
        map.put(head.val, new HashMap<>());
        initMap(head.left);
        initMap(head.right);
    }

    private void setMap(TreeNode head) {
        if (head == null) {
            return;
        }
        headRecord(head.left, head);
        headRecord(head.right, head);
        subRecord(head);
        setMap(head.left);
        setMap(head.right);
    }

    private void headRecord(TreeNode son, TreeNode head) {
        if (son == null) {
            return;
        }
        map.get(son.val).put(head.val, head.val);
        headRecord(son.left, head);
        headRecord(son.right, head);
    }

    private void subRecord(TreeNode head) {
        if (head == null) {
            return;
        }
        preLeft(head.left, head.right, head);
        subRecord(head.left);
        subRecord(head.right);
    }

    private void preLeft(TreeNode left, TreeNode right, TreeNode head) {
        if (left == null) {
            return;
        }
        preRight(left, right, head);
        preLeft(left.left, right, head);
        preLeft(left.right, right, head);
    }

    private void preRight(TreeNode left, TreeNode right, TreeNode head) {
        if (right == null) {
            return;
        }
        map.get(left.val).put(right.val, head.val);
        preRight(left, right.left, head);
        preRight(left, right.right, head);
    }

    public int query(int o1, int o2) {
        if (o1 == o2) {
            return o1;
        }
        if (map.containsKey(o1) && map.get(o1).containsKey(o2)) {
            return map.get(o1).get(o2);
        }
        if (map.containsKey(o2) && map.get(o2).containsKey(o1)) {
            return map.get(o2).get(o1);
        }
        return -1;
    }
}


发表于 2021-08-20 11:35:59 回复(0)
#include<iostream>
#include<vector>
#include<unordered_map>
using namespace std;
struct TreeNode {
    int val;
    int left;
    int right;
};
vector<TreeNode* > vec;
// 这种方法更练coding一些
unordered_map<TreeNode*, unordered_map<TreeNode*, TreeNode*>> maps;

void init_map(TreeNode* root){
    if(root == NULL)
        return;
    unordered_map<TreeNode*, TreeNode*> map_;
    maps[root] = map_;
    init_map(vec[root->left]);
    init_map(vec[root->right]);
}
void head_record(TreeNode* node, TreeNode* head){
    if(node == NULL)
        return;
    maps[node][head] = head;
    head_record(vec[node->left], head);
    head_record(vec[node->right], head);
}

void pre_right(TreeNode *l, TreeNode* r, TreeNode* root){
    if(r == NULL)
        return;
    maps[l][r] = root;
    pre_right(l, vec[r->left], root);
    pre_right(l, vec[r->right], root);
}

void pre_left(TreeNode* l, TreeNode* r, TreeNode* root){
    if(l == NULL)
        return;
    pre_right(l, r, root);
    pre_left(vec[l->left], r, root);
    pre_left(vec[l->right], r, root);
}

void sub_record(TreeNode* root){
    if(root == NULL)
        return;
    pre_left(vec[root->left], vec[root->right], root);
    // 原书貌似这里是个冗余操作,加上这两行会超时,去掉会过
    // 加入相当于每个点加了个O(n^2)的操作,所以最后相当于O(n^3).
    //sub_record(vec[root->left]);
    //sub_record(vec[root->right]);
}

void set_map(TreeNode* root){
    if(root == NULL)
        return;
    // 以root为公共祖先的点
    head_record(vec[root->left], root);
    head_record(vec[root->right], root);
    sub_record(root);
    set_map(vec[root->left]);
    set_map(vec[root->right]);
}
int main(){
    int n, root_val;
    cin>>n>>root_val;
    vec.resize(n+1);
    vec[0] = NULL;
    for(int i = 0;i<n;i++){
        TreeNode* node = new TreeNode();
        int fa, lch, rch;
        cin>>fa>>lch>>rch;
        node->left = lch;
        node->right = rch;
        node->val = fa;
        vec[fa] = node;
    }
    // 初始化
    init_map(vec[root_val]);
    // 开始放点
    set_map(vec[root_val]);
    int m;
    cin>>m;
    for(int i = 0;i<m;i++)
    {
       int o1, o2;
       cin>>o1>>o2;
       //cout<<i<<"-----------------------------"<<endl;
       if(maps[vec[o1]].count(vec[o2])!=0)
           cout<<maps[vec[o1]][vec[o2]]->val<<endl;
       if(maps[vec[o2]].count(vec[o1])!=0)
           cout<<maps[vec[o2]][vec[o1]]->val<<endl;       
    }
    return 0;
}

发表于 2021-05-15 15:13:36 回复(0)
#include<bits/stdc++.h>
using namespace std;
struct LCA{
    // 用哈希表记录下任意两个节点的最近公共祖先
    map<int,map<int,int>>mp;
    // 初始化,每个节点都成为这个map的一个键
    void initMap(int root,int* lc,int* rc)
    {
        if(!root) return;
        map<int,int>map_;
        mp.insert(make_pair(root,map_));
        initMap(lc[root],lc,rc);
        initMap(rc[root],lc,rc);
    }
    // 填完整这张表
    void setMap(int root,int* lc,int* rc)
    {
        if(!root) return;
        // 任意节点h和它的后代节点的LCA是h
        headRecord(lc[root],root,lc,rc);
        headRecord(rc[root],root,lc,rc);
        // h的左子树的每个节点和h右子树的每个节点的LCA是h
        subRecord(root,lc,rc);
        // 对每个结点均按此要求来填表
        setMap(lc[root],lc,rc);
        setMap(rc[root],lc,rc);
    }
    void headRecord(int n,int h,int* lc,int* rc)
    {
        if(!n) return ;
        mp[n].insert(make_pair(h,h));
        headRecord(lc[n],h,lc,rc);
        headRecord(rc[n],h,lc,rc);
    }
    void subRecord(int h,int* lc,int* rc)
    {
        if(!h) return;
        preLeft(lc[h],rc[h],h,lc,rc);
        subRecord(lc[h],lc,rc);
        subRecord(rc[h],lc,rc);
    }
    void preLeft(int l,int r,int h,int* lc,int* rc)
    {
        if(!l) return;
        preRight(l,r,h,lc,rc);
        preLeft(lc[l],r,h,lc,rc);
        preLeft(rc[l],r,h,lc,rc);
    }
    void preRight(int l,int r,int h,int* lc,int* rc)
    {
        if(!r) return;
        mp[l].insert(make_pair(r,h));
        preRight(l,lc[r],h,lc,rc);
        preRight(l,rc[r],h,lc,rc);
    }
    int query(int a,int b)
    {

        if(mp[a].find(b)!=mp[a].end())
            return mp[a][b];
        if(mp[b].find(a)!=mp[b].end())
            return mp[b][a];
    }
};
int main()
{
    int n,root;
    //cin>>n>>root;
    scanf("%d %d",&n,&root);
    int p;
    int* lc = new int[n+1];
    int* rc = new int[n+1];
    for(int i=0;i<n;++i)
    {
        scanf("%d",&p);
        scanf("%d %d",&lc[p],&rc[p]);
    }
    LCA* lca = new LCA();
    lca->initMap(root,lc,rc);
    lca->setMap(root,lc,rc);
    int m;
    //cin>>m;
    scanf("%d",&m);
    int a,b;
    while(m--)
    {
        //cin>>a>>b;
        //cout<<LCA(root,lc,rc,a,b)<<endl;
        scanf("%d %d",&a,&b);
        printf("%d\n",lca->query(a,b));
    }
    return 0;
}
发表于 2020-02-06 23:25:28 回复(0)