首页 > 试题广场 >

树上上升序列

[编程题]树上上升序列
  • 热度指数:1464 时间限制:C/C++ 2秒,其他语言4秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
度度熊给定一棵树,树上的第个节点有点权。请你找出一条最长的路径,使得从沿着唯一路径走到的途中,点权不断严格递增。
换句话说,设路径为,则需要满足。输出最长满足条件的路径的长度。

输入描述:
第一行树的节点个数 , 接下来一行个数字,表示每个点的点权。接下来行,每行两个数代表树上的一条边,连接点
.


输出描述:
一行一个数字表示答案,即最长的长度。
示例1

输入

5
3 5 5 4 1
1 2
1 3
2 4
2 5

输出

2
示例2

输入

4
3 4 1 2
1 2
2 3
2 4

输出

2
建立邻接矩阵会超内存,所以用字典结构来保存连接信息。
递归代码如下:
n = int(input())
weights = list(map(int, input().split()))
links = dict()#连接字典
for i in range(n - 1):
    left, right = list(map(int, input().split()))
    if links.get(left - 1, -1) == -1:
        links[left - 1] = [right - 1]
    else:
        links[left - 1].append(right - 1)
    if links.get(right - 1, -1) == -1:
        links[right - 1] = [left - 1]
    else:
        links[right - 1].append(left - 1)
dp = [1] * n
def dfs(i):
    curlist = links.get(i, -1)
    if curlist == -1:
        return dp[i]
    for j in range(len(curlist)):
        if weights[curlist[j]] > weights[i]:
            dp[i] = max(dp[i], 1 + dfs(curlist[j]))
    return dp[i]
for i in range(n):
    if dp[i] != 1: #不为1说明已经在递归中访问过
        continue
    dp[i] = dfs(i)
print(max(dp))

发表于 2022-03-24 17:12:47 回复(0)
//邻接表
import java.util.*;

public class Main {
    static List<List<Integer>> graph = new ArrayList<>();
    static int max = 0;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[] num = new int[n];
        for (int i = 0; i < n; i++) {
            num[i] = sc.nextInt();
        }
        for (int i = 0; i < num.length; i++) {
            graph.add(new ArrayList<>());
        }
        // 构建图,邻接矩阵
        for (int i = 0; i < n - 1; i++) {
            int a = sc.nextInt() - 1;
            int b = sc.nextInt() - 1;
            graph.get(a).add(b);
            graph.get(b).add(a);
        }

        for (int i = 0; i < num.length - 1; i++) {
            dfs(num, i, 1);
        }
        System.out.println(max);

    }

    public static void dfs(int[] num, int start, int wayNum) {
        max = Math.max(max, wayNum);
        for (int i : graph.get(start)) {
            if (num[i] > num[start]) {
                dfs(num, i, wayNum + 1);
            }
        }
    }
}

发表于 2022-07-26 16:44:49 回复(0)
题目实际上是一个有向无环图。因为从小值走大值,所以按权值大小决定输入边的方向。用拓扑排序的方法求出最长路径。
#include <bits/stdc++.h>//ASI
typedef long long ll;
using namespace std;
int n,a[100005],d[100005],dp[100005],ans;
vector<int>e[100005];
int main()
{
    ios::sync_with_stdio(0),cin.tie(0);
    int i,j,x,y;
    cin>>n;
    for(i=1;i<=n;i++)
        cin>>a[i];
    for(i=1;i<n;i++)
    {
        cin>>x>>y;
        if(a[x]<a[y])
        e[x].push_back(y),d[y]++;
        else if(a[x]>a[y])
            e[y].push_back(x),d[x]++;
    }
    queue<int>q;
    for(i=1;i<=n;i++)
        if(!d[i])
        q.push(i);
    while(q.size())
    {
        int t=q.front();
        q.pop();
        for(i=0;i<e[t].size();i++)
        {
            y=e[t][i];
            dp[y]=max(dp[y],dp[t]+1);
            ans=max(ans,dp[y]);
            d[y]--;
            if(!d[y])
                q.push(y);
        }
    }
    cout<<ans+1;
    return 0;
}


发表于 2021-05-12 08:34:29 回复(0)
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

/**
 * @Author: LI
 * @Date: Created in 14:34 2022/9/13
 */

class TreeNode {
    int weight;
    List<TreeNode> friend;

    public TreeNode(int weight) {
        this.weight = weight;
        friend = new ArrayList<>();
    }
}

public class Main {
    static int max = 0;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        TreeNode[] nodes = new TreeNode[n];
        for (int i = 0; i < n; i++) {
            nodes[i] = new TreeNode(sc.nextInt());
        }
        for (int i = 0; i < n - 1; i++) {
            int a = sc.nextInt() - 1;
            int b = sc.nextInt() - 1;
            nodes[a].friend.add(nodes[b]);
            nodes[b].friend.add(nodes[a]);
        }

        for (TreeNode node : nodes) {
            dfs(node, 1);
        }
        System.out.println(max);
    }

    private static void dfs(TreeNode node, int preSum) {
        max = Math.max(max, preSum);
        for (TreeNode friend : node.friend) {
            if (friend.weight > node.weight) {
                dfs(friend, preSum + 1);
            }
        }
    }
}

发表于 2022-09-13 14:55:08 回复(0)
#include<iostream>
#include<vector>
using namespace std;

void dfs(vector<vector<int>>& dp,int i,vector<bool>& vis,vector<int>&res,vector<int>& wei){
    vis[i]=1;
    for(int j=0;j<dp[i].size();++j){
        // 如果这个dp[i][j]节点已经被访问过,或者从节点i出发不能使得节点dp[i][j]更大
        // 或者节点i不能更新dp[i][j]节点,我们直接返回
        if(vis[dp[i][j]] || res[dp[i][j]]>res[i] || wei[dp[i][j]]<=wei[i])
            continue;
        res[dp[i][j]]=res[i]+1;
        dfs(dp,dp[i][j],vis,res,wei);
    }
    vis[i]=0;
}

// 树形DP,类似数组中求最长连续递增字串
int main(int argc,char* argv[]){
    int n;
    cin>>n;
    vector<int> weight(n,0);
    for(int i=0;i<n;++i)
        cin>>weight[i];
    // 邻接表
    vector<vector<int>> dp(n,vector<int>());
    // 防止遍历已经遍历过的节点
    vector<bool> vis(n,0);
    // dp[i]为以节点i为尾的最长连续递增字串的长度
    vector<int> res(n,1);
    int first,second;
    for(int i=0;i<n-1;++i){
        cin>>first>>second;
        dp[first-1].emplace_back(second-1);
        dp[second-1].emplace_back(first-1);
    }
    for(int i=0;i<n;++i){
        dfs(dp,i,vis,res,weight);
    }
    int rs=0;
    // 找dp中最大的那个就是答案
    for(int i=0;i<n;++i)
        rs=max(rs,res[i]);
    cout<<rs<<endl;
    return 0;
}
发表于 2022-03-21 14:18:21 回复(0)
// golang
package main
import (
    "fmt"
)
/*
实际上是舍弃无向图中至少一半的边变成树或森林,
而且最长路径不一定是从树的根节点出发。
*/
func main() {
    n := 0
    _, err := fmt.Scan(&n)
    if err != nil{
        fmt.Println(0)
        return
    }
    w := make([]int, n+1)
    for i:=1;i<=n;i++ {
        wi := 0
        _, err := fmt.Scan(&wi)
        if err != nil{
            break
        }
        w[i] = wi
    }
    m := make(map[int] map[int]struct{})
    child := make([]bool, n+1)
    for {
        var s, t int
        _, err := fmt.Scan(&s, &t)
        if err != nil{
            break
        }
        if w[s] == w[t] {
            continue
        }
        if w[s] > w[t] {
            s, t = t, s
        }
        if _, ok := m[s]; !ok {
            m[s] = map[int]struct{}{}
        }
        m[s][t] = struct{}{}
    }
    ret := 0
    for i:=1;i<=n;i++ {
        last := []int{i}
        cret := 0
        for len(last) > 0 {
            cret++
            cur := []int{}
            for i := range last {
                s := last[i]
                for t, _ := range m[s] {
                    cur = append(cur, t)
                }
            }
            last = cur
        }
        if cret > ret {
            ret = cret
        }
    }
    fmt.Println(ret)
}
编辑于 2023-03-09 22:14:52 回复(0)
def dfs(cur,src,step):
    if book[cur]>step:
        return None
    book[cur] = step
    for neighbor in adjList[cur]:
        if neighbor!=src and weight[neighbor]>weight[cur]:
            dfs(neighbor,cur,step+1)
n = int(input())
weight = [0]
w = input()
for i in w.split(" "):
    weight.append(int(i))
adjList = [[] for i in range(n+1)] #把树当成图,构建邻接表
for i in range(n-1):
    a,b = map(int,input().split(' '))
    adjList[a].append(b)
    adjList[b].append(a)
book = [1 for i in range(n+1)]
for i in range(1,n+1):
    dfs(i,0,1)
    #print(book)
maxStep = 1
for i in range(1,n+1):
    maxStep = max(maxStep,book[i])
print(maxStep)

发表于 2022-03-04 20:55:16 回复(0)
这道题可以用并查集吗
发表于 2021-12-16 15:11:14 回复(1)


import java.util.ArrayList;
import java.util.LinkedList;
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Main solution = new Main();
        Scanner scanner = new Scanner(System.in);

        int n = scanner.nextInt();
        int[] weight = new int[n];
        for (int i = 0; i < n; i++) {
            weight[i] = scanner.nextInt();
        }
        ArrayList<Integer>[] lists = new ArrayList[n + 1];
        for (int i = 0; i < lists.length; i++) {
            lists[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; i++) {
            int u = scanner.nextInt();
            int v = scanner.nextInt();
            lists[u].add(v);
            lists[v].add(u);
        }

        int res = solution.solve(lists, weight);
        System.out.println(res);

    }


    static ArrayList<Integer> maxLen = new ArrayList<>();
    private int solve(ArrayList<Integer>[] lists, int[] weight) {

        LinkedList<Integer> path = new LinkedList<>();

        for (int i = 1; i < lists.length; i++) {
            path.add(i);
            loopBack(lists, weight, path, i);
            path.removeLast();
        }

        return maxLen.size();
    }

    private void loopBack(ArrayList<Integer>[] lists, int[] weight, LinkedList<Integer> path, int index) {

        if (maxLen.size() < path.size()) {
            maxLen = new ArrayList<>(path);
        }

        int curWeight = weight[index - 1];
        for (Integer integer : lists[index]) {
            if (weight[integer - 1] > curWeight) {
                path.add(integer);
                loopBack(lists, weight, path, integer);
                path.removeLast();
            }
        }
    }
}

发表于 2021-11-13 10:42:46 回复(0)

树形DP

def solve():
    n = int(input())
    weights = list(map(int, input().split()))
    from collections import defaultdict
    edges = defaultdict(list)
    for _ in range(n-1):
        u, v = map(int, input().split())
        edges[u-1].append(v-1)
        edges[v-1].append(u-1)
    dp = [-1] * n 

    def helper(root):
        if dp[root] != -1:
            return dp[root]
        res = 1
        for nxt in edges[root]:
            if weights[root] < weights[nxt]:
                res = max(res, 1+helper(nxt))
        dp[root] = res 
        return res 

    for i in range(n):
        helper(i)
    print(max(dp))
solve()
发表于 2021-10-18 16:03:20 回复(0)
来一个java风格的
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;

public class Main {
    static int[] memo;
    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(reader.readLine());
        int[] weights = new int[n+1];
        memo = new int[n+1];
        String[] strs = reader.readLine().split(" ");
        for (int i = 0; i < n; i++) {
            weights[i+1] = Integer.parseInt(strs[i]);
        }
        int[][] tree = new int[n+1][2];
        for (int i = 0; i < n-1; i++) {
            strs = reader.readLine().split(" ");
            int a = Integer.parseInt(strs[0]);
            int b = Integer.parseInt(strs[1]);
            if (weights[a] > weights[b]) {
                //tree[b][a] = true;
                if (tree[b][0] == 0) {
                    tree[b][0] = a;
                } else {
                    tree[b][1] = a;
                }
            } else if (weights[a] <= weights[b]) {
                //tree[a][b] = true;
                if (tree[a][0] == 0) {
                    tree[a][0] = b;
                } else {
                    tree[a][1] = b;
                }
            }
        }
        int res = 0;
        for (int i = 1; i <= n; i++) {
            int a = dfs(i, tree, n);
            res = Math.max(res, a);
        }
        System.out.println(res);

    }

    private static int dfs(int a, int[][] tree, int n) {
        if (memo[a] != 0) {
            return memo[a];
        }
        int res = 0;
        for (int i = 0; i < 2; i++) {
            if (tree[a][i] != 0) {
                int b = tree[a][i];
                res = Math.max(res, dfs(b, tree, n));
            }
        }
        memo[a] = res + 1;
        return res + 1;
    }
}
9/10 最有一个n=100000的过不去,也不知道为什么
发表于 2021-09-15 15:05:02 回复(0)
import sys
class Node(object):
    def __init__(self, val) -> None:
        self.val = val
        self.childs = []

n = int(input().strip())
nodes = list(map(lambda x:Node(val=int(x)), input().strip().split()))
for line in sys.stdin:
    if line.strip() == "":
        break
    father, child = map(int, line.strip().split())
    father, child = child-1, father-1
    nodes[father].childs.append(nodes[child])

res = float('-inf')

def dfs(node):
    global res
    if node == None:
        return 0, 0
    inc, dec = 0, 0
    
    for child in node.childs:
        cinc, cdec = dfs(child)
        if node.val > child.val:
            inc = max(inc, cinc)
        elif node.val < child.val:
            dec = max(dec, cdec)

    res = max(res, inc + dec + 1)
    #print(res)
    return inc + 1, dec + 1

dfs(nodes[0])
# print("result: {}".format(res))
print(res)


编辑于 2021-09-06 18:03:16 回复(0)