首页 > 试题广场 >

在二叉树中找到一个节点的后继节点

[编程题]在二叉树中找到一个节点的后继节点
  • 热度指数:2742 时间限制:C/C++ 2秒,其他语言4秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
二叉树中一个节点的后继节点指的是,二叉树的中序遍历的序列中的下一个节点。

输入描述:
第一行输入两个整数 n 和 root,n 表示二叉树的总节点个数,root 表示二叉树的根节点。

以下 n 行每行三个整数 fa,lch,rch,表示 fa 的左儿子为 lch,右儿子为 rch。(如果 lch 为 0 则表示 fa 没有左儿子,rch同理)

最后一行输入要询问的节点 node。


输出描述:
输出一个整数表示答案。(如果 node 是最后一个节点,则输出 0 )
示例1

输入

10 6
6 3 9
3 1 4
1 0 2
2 0 0
4 0 5
5 0 0
9 8 10
10 0 0
8 7 0
7 0 0
10

输出

0

备注:

中序遍历时用一个prev变量记录前一个节点的值,当这个值与询问节点值相等时,当前遍历到的节点就是要求的后继节点。
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.Stack;
import java.util.HashMap;

class TreeNode {
    public int val;
    public TreeNode left;
    public TreeNode right;
    public TreeNode(int val) {
        this.val = val;
        this.left = null;
        this.right = null;
    }
}

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] params = br.readLine().trim().split(" ");
        int n = Integer.parseInt(params[0]), rootVal = Integer.parseInt(params[1]);
        // 先建树
        TreeNode root = new TreeNode(rootVal);
        HashMap<Integer, TreeNode> map = new HashMap<>();
        map.put(root.val, root);
        for(int i = 0; i < n; i++){
            params = br.readLine().split(" ");
            int val = Integer.parseInt(params[0]);
            int leftVal = Integer.parseInt(params[1]);
            int rightVal = Integer.parseInt(params[2]);
            TreeNode node = map.get(val);
            if(leftVal != 0) {
                node.left = new TreeNode(leftVal);
                map.put(leftVal, node.left);
            }
            if(rightVal != 0) {
                node.right = new TreeNode(rightVal);
                map.put(rightVal, node.right);
            }
        }
        // 中序遍历
        int query = Integer.parseInt(br.readLine());
        int prev = -1, res = 0;
        Stack<TreeNode> stack = new Stack<>();
        while(!stack.isEmpty() || root != null){
            while(root != null){
                stack.push(root);
                root = root.left;
            }
            if(!stack.isEmpty()){
                root = stack.pop();
                if(prev == query){
                    // 前一个节点为询问节点了,当前节点为后继节点
                    res = root.val;
                    break;
                }
                prev = root.val;
                root = root.right;
            }
        }
        System.out.println(res);
    }
}

发表于 2021-11-16 11:47:35 回复(0)
emmm... 这种输入拿数组接比较方便

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.StreamTokenizer;


public class Main{
    
    public static void main(String[] args)throws Exception{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] commands = br.readLine().split(" ");
        int size = Integer.parseInt(commands[0]);
        int root = Integer.parseInt(commands[1]);
        
        int[] lr = new int[size+1];
        int[] rr = new int[size+1];
        int[] pr = new int[size+1];
        pr[root] = -1;  //root的父节点是-1

        for(int i=0;i<size;++i){
            commands = br.readLine().split(" ");
            int cur = Integer.parseInt(commands[0]);
            int left = Integer.parseInt(commands[1]);
            int right = Integer.parseInt(commands[2]);
            pr[left] = cur;
            pr[right] = cur;
            lr[cur]=left;
            rr[cur]=right;
        } 
        commands = br.readLine().split(" ");
        int x = Integer.parseInt(commands[0]);
		
        fun(lr,rr,pr,x);
    }
    
    private static void fun(int[] lr,int[] rr, int[] pr, int cur){
        if(rr[cur] != 0){
            //有右子树
            int temp = rr[cur];
            while(lr[temp]!=0){
                temp = lr[temp];
            }
            System.out.println(temp);
            return;
        }else{
            //往上找
            int temp = pr[cur];
            int son = cur;
            while(temp != -1){
                if(son == lr[temp]){
                    //如果是左孩子
                    System.out.println(temp);
                    return;
                }
                //否则继续往上找
                son=temp;
                temp = pr[temp];
            }
            System.out.println(0);
        }
    }
}


发表于 2020-11-19 23:31:48 回复(0)
#include<iostream>
#include<vector>
using namespace std;
struct treenode{    //树节点结构
    int par;
    int left;
    int right;
};
int main(){
    int n,root;
    cin>>n>>root;
    vector<treenode> tree(n+1);
    tree[0].left = 0;
    tree[0].right = 0;
    tree[0].par = 0;
    int cur_id, cur_l,cur_r;
    for(int i=0;i<n;i++){    //建树
        cin>>cur_id>>cur_l>>cur_r;
        tree[cur_id].left = cur_l;
        tree[cur_id].right = cur_r;
        if(cur_id == root)
            tree[cur_id].par = 0;
        tree[cur_l].par = cur_id;
        tree[cur_r].par = cur_id;
    }
    int find_node;
    int res = 0;
    cin >> find_node;
    if(tree[find_node].right != 0){ //如果查询节点有右子节点,则其右子节点的最左子节点即为后继节点
        int cur_node = tree[find_node].right;
        while(tree[cur_node].left != 0)
            cur_node = tree[cur_node].left;
        res = cur_node;
    }
    else{ //如果查询节点没有右子节点,则向上查找其父节点,第一个查找到的查询节点为其左子节点的父节点即为后继节点
        if(find_node == root)
            res = 0;
        else{
            int par_node = tree[find_node].par;
            int cur_node = find_node;
            while(cur_node == tree[par_node].right && cur_node != 0){
                cur_node = par_node;
                par_node = tree[cur_node].par;
            }
            if(cur_node == 0)    //查找到根节点依然无符合条件的父节点(查询节点均为所有父节点的右子节点),则查询节点无后继节点
                res = 0;
            else
                res = par_node;
        }
    }
    cout << res << endl;
    return 0;
}

发表于 2020-10-05 10:52:42 回复(0)
#include <bits/stdc++.h>
using namespace std;

const int M = 500000;
int n, r, fa[M], lch[M], rch[M];
int F(int r){
    if(rch[r]){
        int l = rch[r];
        while(lch[l])
            l = lch[l];
        return l;
    }else{
        int p = fa[r];
        while(p!=r && r!=lch[p]){
            r = p;
            p = fa[p];
        }
        return p==r?0:p;
    }
}

int main(){
    int x, y, z;
    scanf("%d%d", &n, &r);
    for(int i=0;i<n;i++){
        scanf("%d%d%d", &x, &y, &z);
        lch[x] = y;
        rch[x] = z;
        if(y)
            fa[y] = x;
        if(z)
            fa[z] = x;
    }
    scanf("%d", &x);
    printf("%d\n", F(x));
    return 0;
}

发表于 2020-06-18 00:40:02 回复(0)
import java.util.*;

class Node {
    public int val;
    public Node right;
    public Node left;
    public Node (int val) {
        this.val = val;
    }
}

public class Main {
    public static void main (String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        Node head = constructTree(sc, n);
        int k = sc.nextInt();
        Node res = findNext(head, k);
        if (res == null)
            System.out.println(0);
        else
            System.out.println(res.val);
    }
    
    public static Node constructTree (Scanner sc, int n) {
        HashMap<Integer, Node> map = new HashMap<>();
        int rootVal = sc.nextInt();
        Node root = new Node(rootVal);
        map.put(rootVal, root);
        for (int i = 0; i < n; i++) {
            int nodeVal = sc.nextInt();
            int leftVal = sc.nextInt();
            int rightVal = sc.nextInt();
            if (map.containsKey(nodeVal)) {
                Node tmpNode = map.get(nodeVal);
                Node leftNode = leftVal == 0 ? null : new Node(leftVal);
                Node rightNode = rightVal == 0 ? null : new Node(rightVal);
                tmpNode.left = leftNode;
                tmpNode.right = rightNode;
                if (leftVal != 0)
                    map.put(leftVal, leftNode);
                if (rightVal != 0)
                    map.put(rightVal, rightNode);
            }
        }
        return root;
    }
    
    public static Node findNext (Node head, int k) {
        if (head == null)
            return null;
        Node cur = head;
        Node result = null;
        boolean test = false;
        Node mostRight = null;
        while (cur != null) {
            mostRight = cur.left;
            if (mostRight != null) {
                while (mostRight.right != null && mostRight.right != cur) {
                    mostRight = mostRight.right;
                }
                if (mostRight.right == null) {
                    mostRight.right = cur;
                    cur = cur.left;
                    continue;
                }
                else {
                    mostRight.right = null;
                }
            }
            if (cur.val == k || test == true) {
                if (test == true) {
                    result = cur;
                    break;
                }
                test = true;
            }
            cur = cur.right;
        }
        return result;
    }
}


用Morris遍历一样ok,没必要非得用parent.
发表于 2021-06-20 19:09:51 回复(0)
# include<iostream>
# include<map>

using namespace std;

struct TreeNode{
    int val;
    TreeNode* parent;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int x){
        val = x;
        parent = left = right = nullptr;
    }
};

TreeNode* findNext(TreeNode* node){
    TreeNode* cur = node;
    if (cur->right != nullptr){
        cur = cur->right;
        while (cur->left != nullptr){
            cur = cur->left;
        }
        return cur;
    }
    else{
        while (cur->parent != nullptr && cur->parent->left != cur){
            cur = cur->parent;
        }
        return cur->parent;
    }
}

int main(){
    int n,rootval;
    cin>>n>>rootval;
    map<int,TreeNode*> mp;
    mp[0] = nullptr;
    for (int i=0; i<n; i++){
        int pa,lc,rc;
        cin >> pa >> lc >> rc;
        if (!mp.count(pa)) mp[pa] = new TreeNode(pa);
        if (!mp.count(lc)) mp[lc] = new TreeNode(lc);
        if (!mp.count(rc)) mp[rc] = new TreeNode(rc);
        mp[pa]->left = mp[lc];
        mp[pa]->right = mp[rc];
        if (mp[lc] != nullptr) mp[lc]->parent = mp[pa];
        if (mp[rc] != nullptr) mp[rc]->parent = mp[pa];
    }
    int visit;
    cin >> visit;
    TreeNode* nextNode = findNext(mp[visit]);
    int res = nextNode == nullptr? 0:nextNode->val;
    cout << res << endl;
    return 0;
}

发表于 2021-04-29 19:59:52 回复(0)
3 1
1 2 3
3 0 0
2 0 0
3
第21个用例给的不对,应该是
3 1
1 2 3
2 0 0
3 0 0
3
就这一个用例,我得重写生成二叉树的代码,佛了
发表于 2022-02-16 10:57:37 回复(0)
剑指Offer思路
#include<iostream>
using namespace std;
struct node {
    int left;
    int right;
    int parent;
}Tree[500001];
int main(){
    int n, root;
    while (cin >> n >> root) {
        Tree[root].parent = 0;
        for (int i = 0; i < n; i++) {
            int fa;
            cin >> fa;
            cin >> Tree[fa].left >> Tree[fa].right;
            Tree[Tree[fa].left].parent = fa;
            Tree[Tree[fa].right].parent = fa;
        }
        int pnode;
        cin >> pnode;
        if (pnode == 0)cout << "0" << endl;
        int pnext = 0;
        ///所求节点的右孩子不为空,则下一个节点是右孩子最左的节点
        if (Tree[pnode].right != 0) {
            int tmp = Tree[pnode].right;
            while (Tree[tmp].left != 0) {
                tmp = Tree[tmp].left;
            }
            pnext = tmp;
        }
        //所求节点的右孩子是空:
        //1-所求节点是它父节点的左节点,下一个节点是他父节点
        //2-所有节点是他父节点的右节点,下一个节点是:如果他父节点是他祖父节点的左子树,那么下一个节点是他祖父节点
        else if(Tree[pnode].parent!=0) {
            int pParent = Tree[pnode].parent;
            int cur = pnode;
            while (pParent != 0 && cur == Tree[pParent].right) {
                cur = pParent;
                pParent = Tree[cur].parent;
            }
            pnext = pParent;
        }
        cout << pnext << endl;
    }
    return 0;
}
	


编辑于 2020-08-24 11:14:58 回复(0)
#include <iostream>
using namespace std;
int find_next_node(int node, int *pa, int *lc, int *rc, int root)
{
    int rchild = rc[node];
    if (rchild) { //如果有右孩子沿着右孩子的左孩子方向找
        int lchild = rchild;
        while (lc[lchild])
            lchild = lc[lchild];
        return lchild;
    } else {  //如果没有右孩子沿着父亲节点的方向找
        int parent = pa[node];
        while (parent != root && node != lc[parent]) {
            node = parent;
            parent = pa[parent];
        }
        if(parent==root&&rc[parent])  return parent;//若找到根节点,但根节点有右孩子,输出根节点
        return parent == root ? 0 : parent;
    }
}
int main(void)
{
    int n, root;
    cin >> n >> root;
    int *pa = new int[n + 1];//父亲节点数组
    int *lc = new int[n + 1];//左孩子数组
    int *rc = new int[n + 1];//右孩子数组
    int parent, lchild, rchild;
    for (int i = 0; i < n; i++) {
        cin >> parent >> lchild >> rchild;
        lc[parent] = lchild;
        rc[parent] = rchild;
        if (lchild) pa[lchild] = parent;
        if (rchild) pa[rchild] = parent;
    }
    int node;
    cin >> node;
    cout << find_next_node(node, pa, lc, rc, root);
    return 0;
}

以上参考了讨论中其他同学的答案,改进了根节点有右孩子的情况。

编辑于 2020-08-07 16:26:18 回复(0)
// 通过中序递归遍历实现,欢迎拍砖
public class Testss {
	public static List<Integer> result = new ArrayList();
	public static void main(String[] args) {
		int[][] arr = {{10,6  },
			{6,3,9 },
			{3,1,4 },
			{1,0,2 },
			{2,0,0 },
			{4,0,5 },
			{5,0,0 },
			{9,8,10},
			{10,0,0},
			{8,7,0 },
			{7,0,0 },
			{6    }};
		int root = arr[0][1];
		int allNode = arr[0][0];
		int target = arr[arr.length-1][0];
		midTravel(arr, root);
		int tag = 0;
		for (int i=0;i< result.size();i++) {
			if (result.get(i) == target){
				tag = i;
				break;
			}
		}
		if (tag >= allNode-1) {
			System.out.println(0);
		} else {
			System.out.println(result.get(tag+1));
		}
	}

	// 通过中序递归遍历来实现,遍历结果保存到result中
	public static void midTravel(int[][] arr, int target){
		int left = getLeft(arr, target);
		if (left != 0) {
			midTravel(arr, left);
		}
		result.add(target);
		int right = getRight(arr, target);
		if (right != 0) {
			midTravel(arr, right);
		}
	}

	private static int getRight(int[][] arr, int target) {
		for (int i=1;i<arr.length-1;i++){
			if (arr[i][0] == target) {
				return arr[i][2];
			}
		}
		return 0;
	}

	private static int getLeft(int[][] arr, int target) {
		for (int i=1;i<arr.length-1;i++){
			if (arr[i][0] == target) {
				return arr[i][1];
			}
		}
		return 0;
	}
}

发表于 2020-07-24 23:31:14 回复(0)
#include <bits/stdc++.h>
using namespace std;

int next(vector<int> &lefts, vector<int> &rights, vector<int> &parents, int node){
    if(node == 0) return 0;
    if(rights[node]){
        //找右子树的最左下节点
        node = rights[node];
        while(lefts[node]){
            node = lefts[node];
        }
        return node;
    }
    int parent;
    while((parent = parents[node]) && node == rights[parent]){
        node = parent;
    }
    return parent;
}

int main(){
    int n, root;
    cin >> n >> root;
    int node, lch, rch;
    vector<int> lefts(n + 1), rights(n + 1), parents(n + 1);
    while(cin >> node >> lch >> rch){
        lefts[node] = lch;
        parents[lch] = node;
        rights[node] = rch;
        parents[rch] = node;
    }
    int target = node;
    cout << next(lefts, rights, parents, target) << endl;
    return 0;
}
发表于 2020-06-29 15:31:32 回复(0)
def nextNode(root,n,tree):
    global flag
    if root==0:
        return 0
    m=nextNode(tree[root][0],n,tree)
    if m>0:
        return m
    if flag:
        return root
    if root==n:
        flag=True
    m=nextNode(tree[root][1],n,tree)
    return m

if __name__=='__main__':
    n,root=list(map(int,input().split()))
    tree=[[0,0] for _ in range(n+1)]
    for _ in range(n):
        node=list(map(int,input().split()))
        tree[node[0]][0]=node[1]
        tree[node[0]][1]=node[2]
    m=int(input())
    flag=False
    ans=nextNode(root,m,tree)
    print(ans)

发表于 2020-04-26 13:16:00 回复(0)
#include <iostream>

using namespace std;

int find_next_node(int node, int *pa, int *lc, int *rc, int root)
{
    int rchild = rc[node];
    if (rchild) {
        int lchild = rchild;
        while (lc[lchild])
            lchild = lc[lchild];
        return lchild;
    } else {
        int parent = pa[node];
        while (parent != root && node != lc[parent]) {
            node = parent;
            parent = pa[parent];
        }
        return parent == root ? 0 : parent;
    }
}

int main(void)
{
    int n, root;
    cin >> n >> root;
    int *pa = new int[n + 1];
    int *lc = new int[n + 1];
    int *rc = new int[n + 1];
    int parent, lchild, rchild;
    for (int i = 0; i < n; i++) {
        cin >> parent >> lchild >> rchild;
        lc[parent] = lchild;
        rc[parent] = rchild;
        if (lchild) pa[lchild] = parent;
        if (rchild) pa[rchild] = parent;
    }
    int node;
    cin >> node;
    cout << find_next_node(node, pa, lc, rc, root);
    return 0;
}

发表于 2020-02-07 14:19:28 回复(0)
#include<bits/stdc++.h>
using namespace std;
int getNextNode(int root,int* lc,int* rc,int* pa)
{
    // 有右孩子
    if(rc[root])
    {
        int node = rc[root];
        while(lc[node])
        {
            node = lc[node];
        }
        return node;
    }
    // 无右孩子
    int node = 0;
    while(root)
    {
        if(pa[root] && root==lc[pa[root]])
            return pa[root];
        root = pa[root];
    }
    return 0;
}
int main()
{
    int n,root;
    cin>>n>>root;
    int p,l,r;
    int* lc = new int[n+1];
    int* rc = new int[n+1];
    // 记录每个结点的父节点
    int* pa = new int[n+1];
    for(int i=0;i<n;++i)
    {
        cin>>p;
        cin>>l>>r;
        lc[p] = l;
        rc[p] = r;
        if(l) pa[l] = p;
        if(r) pa[r] = p;
    }
    int target;
    cin>>target;
    cout<<getNextNode(target,lc,rc,pa);
    return 0;
}
发表于 2020-02-06 16:43:57 回复(0)
#include<iostream>
#include<vector>
using namespace std;

///在二叉树中找到一个节点的后继节点
void fun_inorder(vector<pair<int,int>>& node,int root,bool* flag,int pos,int* res)
{
	if (root == 0)
		return;
    fun_inorder(node, node[root].first,flag,pos,res);
    if (*flag)
	{
		*res = root;
		*flag = false;
	}
	else if (root == pos)
		*flag = true;
    fun_inorder(node, node[root].second, flag, pos, res);


}
void fun_findNext()
{
	int n, root;
	scanf("%d%d",&n,&root);
	vector<pair<int, int>> nodes(n+1);
	for (int i = 0; i < n; ++i)
	{
		int temp;
		scanf("%d",&temp);
		scanf("%d%d",&nodes[temp].first,&nodes[temp].second);
	}
	int res = 0, pos;
    bool flag = false;
	scanf("%d",&pos);
	fun_inorder(nodes,root,&flag,pos,&res);
	printf("%d\n",res);
}
int main()
{
    fun_findNext();
    return 0;
}


编辑于 2019-10-26 09:46:00 回复(0)
#include <iostream>
#include <vector>

using namespace std;

long cook(vector<vector<long>> c, long& find)
{
    auto findNode=[&c](long& node){
        for(long i=0; (size_t)i<c.size(); i++)
        {
            if (c[i][0]==node)
                return i;
        }
        return (long)(-1);
    };
    
    long pos = findNode(find);
    if (pos==-1)
        return 0;
    long right=c[pos][2];
    long left=c[pos][1];
    
    if ((long)0==right)
    {
        if (find==c[0][0])
            return 0;
        
        long tmp=find;
        while(tmp)
        {
            for (long k=0; (long)k<c.size(); k++)
            {
                if (tmp==c[k][1])//it is a left child
                    return c[k][0];
                else if (tmp==c[k][2])
                {
                    tmp=c[k][0];
                }
                else if (tmp==c[0][0])//is the root.
                {
                    return 0;
                }
            }
        }
        return 0;
    }
    
    long far=right;
    while(far)
    {
        pos=findNode(far);
        if (c[pos][1] != (long)0)
            far = c[pos][1];
        else
            return far;
    }
}

int main()
{
    long num;
    long root;
    long tmp1,tmp2,tmp3;
    long find;
    cin >> num >> root;
    vector<vector< long> > nodes(num);
    
    for (long i=0; i<num; i++)
    {
        cin>>tmp1>>tmp2>>tmp3;
        nodes[i].push_back(tmp1);
        nodes[i].push_back(tmp2);
        nodes[i].push_back(tmp3);
    }
    cin>>find;
    cout<<cook(nodes, find);
    
    return 0;
}
发表于 2019-08-02 21:22:25 回复(0)