首页 > 试题广场 >

小红的推理树平衡链路

[编程题]小红的推理树平衡链路
  • 热度指数:1051 时间限制:C/C++ 2秒,其他语言4秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
小红正在分析一个用于智能体推理的决策树。树上的每个节点都记录了一个整数分值,表示模型在这一层做出某个选择时带来的偏移量。
如果一条路径同时满足下面三个条件,小红就把它称为一条平衡路径:
1. 路径可以从树中的任意节点开始;
2. 从起点出发后,每一步都只能走向当前节点的左子节点或右子节点,也就是说整条路径必须始终向下延伸,不能回到父节点,也不能分叉;
3. 路径上所有节点的分值之和恰好为 0,并且这条路径包含的节点个数至少为 2。
现在给定这棵二叉树的层序遍历结果,请你帮助小红统计这棵树中一共有多少条平衡路径。

输入描述:
第一行一个整数 `n`,表示层序遍历列表中的元素个数,`1 <= n <= 200000`。
第二行给出 `n` 个元素,按从上到下、从左到右的顺序描述这棵树。每个元素要么是一个整数,表示对应节点的分值;要么是字符串 `None`,表示该位置为空。保证第一个元素不是 `None`,非空节点的分值满足 `-10^9 <= val <= 10^9`,并且输入一定能够唯一确定一棵合法二叉树。
建树时采用常见的层序规则:每次从队列中取出一个非空节点,再按顺序读取它的左孩子和右孩子;如果某个孩子为 `None`,则对应位置不建立节点。


输出描述:
输出一个整数,表示平衡路径的总数。
注意答案可能超过 32 位整数范围。
示例1

输入

3
0 0 0

输出

2

说明

这棵树的根节点有两个孩子,且它们的分值都为 0。以根节点为起点向左走可以得到一条长度为 2 的路径,节点和为 0;向右走同样满足条件,所以答案是 2。
import sys
from collections import Counter

class TreeNode:
    def __init__(self, val = 0, left = None, right = None):
        self.val = val
        self.left = left
        self.right = right

sys.setrecursionlimit(200000)

def build_tree_from_list(lst):
    if not lst or lst[0] is None:
        return None
    root = TreeNode(lst[0])
    queue = [root]
    i = 1
    while queue and i < len(lst):
        node = queue.pop(0)
        if lst[i] is not None:
            node.left = TreeNode(lst[i])
            queue.append(node.left)
        i += 1
        if i < len(lst) and lst[i] is not None:
            node.right = TreeNode(lst[i])
            queue.append(node.right)
        i += 1
    return root

def count_balanced_paths(root):
    cnt = Counter()
    cnt[0] = 1
    ans = 0
    zero_cnt = 0

    def dfs(node, s):
        nonlocal ans, zero_cnt
        if node is None:
            return

        s += node.val
        if node.val == 0:
            zero_cnt += 1

        ans += cnt[s]
        cnt[s] += 1
        dfs(node.left, s)
        dfs(node.right, s)
        cnt[s] -= 1

    dfs(root, 0)
    return ans - zero_cnt

def main():
    n = int(sys.stdin.readline().strip())
    line = sys.stdin.readline().strip().split()
    tree = [int(x) if x != "None" else None for x in line]

    root = build_tree_from_list(tree)
    result = count_balanced_paths(root)
    print(result)

if __name__ == "__main__":
    main()
发表于 2026-05-26 12:10:32 回复(0)
import java.io.*;
import java.util.*;

class TreeNode{
    long val;
    TreeNode left;
    TreeNode right;
    TreeNode(long val){
        this.val = val;
        this.left = null;
        this.right = null;
    }
}
public class Main {
    static long ans = 0;
    static long zeroCount = 0;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine());
        if(n == 0){
            System.out.println(0);
            return;
        }
        String[] arr = br.readLine().split(" ");
        // 读取数据

        // 利用队列,以及层序遍历结果,构建树
        Queue<TreeNode> queue = new LinkedList<>();

        TreeNode root = new TreeNode(Long.parseLong(arr[0]));
        queue.offer(root);
        int idx = 1;

        while(!queue.isEmpty() && idx < n){
            TreeNode curr = queue.poll();
            if(idx < n && !arr[idx].equals("None")){
                curr.left = new TreeNode(Long.parseLong(arr[idx]));
                queue.offer(curr.left);
            }
            idx++;
            if(idx < n && !arr[idx].equals("None")){
                curr.right = new TreeNode(Long.parseLong(arr[idx]));
                queue.offer(curr.right);
            }
            idx++;
        }

        // 树建好了

        // 收集所有非空node,这些非null节点,可以作为遍历的起点
        List<TreeNode> allNodes = new ArrayList<>();
        Queue<TreeNode> nodeQueue = new LinkedList<>();
        nodeQueue.offer(root);

        while(!nodeQueue.isEmpty()){
            TreeNode node = nodeQueue.poll();
            allNodes.add(node);
            if(node.left != null) nodeQueue.offer(node.left);
            if(node.right != null) nodeQueue.offer(node.right);
        }


        Map<Long, Integer> map = new HashMap<>();
        map.put(0L, 1); // 初始化,前缀和0出现1次
        dfs(root, 0L, map);


        System.out.println(ans - zeroCount);
    }

    private static void dfs(TreeNode node, long sum, Map<Long, Integer> map){
        if(node == null) return;

        sum += node.val;

        if(node.val == 0){
            zeroCount++;
        }
        ans += map.getOrDefault(sum, 0);

        map.put(sum, map.getOrDefault(sum, 0) + 1);

        dfs(node.left, sum, map);
        dfs(node.right, sum, map);

        map.put(sum, map.get(sum) - 1);
    }
}

发表于 2026-05-13 18:18:55 回复(0)