首页 > 试题广场 >

树上最短链

[编程题]树上最短链
  • 热度指数:2384 时间限制:C/C++ 2秒,其他语言4秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
在一个地区有 个城市以及 条无向边,每条边的时间边权都是 ,并且这些城市是联通的,即这个地区形成了一个树状结构。每个城市有一个等级。
现在小强想从一个城市走到另一个不同的城市,并且每条边经过至多一次,同时他还有一个要求,起点和终点城市可以任意选择,但是等级必须是相同的。
但是小强不喜欢走特别远的道路,所以他想知道时间花费最小是多少。

进阶:时间复杂度,空间复杂度

输入描述:
第一行一个正整数 ,含义如题面所述。
第二行  个正整数 ,代表每个城市的等级。
接下来 行每行两个正整数 ,代表一条无向边。
保证给出的图是一棵树。




输出描述:
仅一行一个整数代表答案,如果无法满足要求,输出 
示例1

输入

3
1 2 1
1 2
2 3

输出

2
根据WwSsXx的解法,换成BFS,感谢dalao
    import java.util.*;
    import java.math.*;
    public class Main{
        static int[] level;
        static ArrayList<Integer>[] lists;
        static int res = Integer.MAX_VALUE;
        public static void main(String []args){
            Scanner in = new Scanner(System.in);
            int n = in.nextInt();
            level = new int[n];
            lists = new ArrayList[n];
            for(int i=0;i<n;i++){
                level[i] = in.nextInt();
                lists[i] = new ArrayList<Integer>();
            }
            for(int i=0;i<n-1;i++){
                int x = in.nextInt()-1;
                int y = in.nextInt()-1;
                lists[x].add(y);
                lists[y].add(x);
            }
            for(int i=0;i<n;i++){
                Queue<Integer> que = new LinkedList<>();
                boolean []vis = new boolean[n];
                que.offer(i);
                vis[i] = true;
                int length = 0;
                while(!que.isEmpty()){
                    int size = que.size();
                    int flag= 0;
                    for(int j=0;j<size;j++){
                        int temp = que.poll();
                        if(temp!=i&&level[temp]==level[i]){
                            res = Math.min(res,length);
                            flag =1;
                            break;
                        }
                        for(int x:lists[temp]){
                            if(!vis[x]){
                                que.offer(x);
                                vis[x] = true;
                            }
                        }
                    }
                    if(flag==1) break;
                    length++;
                }
            }
            if(res==Integer.MAX_VALUE){
                res = -1;
            }
            System.out.println(res);
        }
    }


发表于 2021-06-10 16:48:14 回复(0)
需要注意的是,这道题目卡常,下面的代码第一个开动态数组的能过,开静态数组的不给过。
尝试了一下堆优化版的dijstra,同样也不给过。
\\bfs 能过
#include <bits/stdc++.h>
using namespace std;

#define ull unsigned long long;
#define pi 3.14;

typedef long long LL;
typedef pair<int,int> PII;
const int mod=1e9+7;
const double E=1e-8;
const int N=5001;
vector<vector<int>> arr;

int n,res=INT_MAX;
unordered_map<int,int> mp_2;

void solution(){

}

void bfs(int start){
    vector<bool> st(n+1);
    vector<int> dist(n+1);
    if(res==1) return;
    queue<int> q;
    q.push(start);
    dist[start]=0;
    st[start]=true;
    while(q.size()){
        auto t=q.front();
        q.pop();

        if(t!=start&&mp_2[t]==mp_2[start]){
            res=min(res,dist[t]);
            break;
        }
        for(auto &x:arr[t]){
            if(!st[x]){
                st[x]=true;
                dist[x]=dist[t]+1;
                q.push(x);
            }
        }
    } 

}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin>>n;
    arr=vector<vector<int>>(n+1);
    for(int i=1;i<=n;i++){
        int c;
        cin>>c;
        mp_2[i]=c;
    }
    for(int i=0;i<n-1;i++){
        int a,b;
        cin>>a>>b;
        arr[a].push_back(b);
        arr[b].push_back(a);
    }
    for(int i=1;i<=n;i++){
        bfs(i);
    }
    if(res==INT_MAX) res=-1;
    cout<<res<<endl;
    return 0;
}
//开静态数组不给过
#include <bits/stdc++.h>
using namespace std;

#define ull unsigned long long;
#define pi 3.14;

typedef long long LL;
typedef pair<int,int> PII;
const int mod=1e9+7;
const double E=1e-8;
const int N=5001;
int e[N],ne[N],h[N],dist[N],idx;
bool st[N];
int n,res=INT_MAX;
unordered_map<int,int> mp_2;
void add(int a,int b){
    e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}

void solution(){

}

void bfs(int start){
    memset(st,0,sizeof st);
    memset(dist,0,sizeof dist);
    if(res==1) return;
    queue<int> q;
    q.push(start);
    dist[start]=0;
    st[start]=true;
    while(q.size()){
        auto t=q.front();
        q.pop();

        if(t!=start&&mp_2[t]==mp_2[start]){
            res=min(res,dist[t]);
            break;
        }
        for(int i=h[t];~i;i=ne[i]){
            int j=e[i];
            if(!st[j]){
                dist[j]=dist[t]+1;
                q.push(j);
                st[j]=true;
            } 
        }
    } 

}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    idx=0;
    memset(h,-1,sizeof h);
    cin>>n;
    for(int i=1;i<=n;i++){
        int c;
        cin>>c;
        mp_2[i]=c;
    }
    for(int i=0;i<n-1;i++){
        int a,b;
        cin>>a>>b;
        add(a,b);
        add(b,a);
    }
    for(int i=1;i<=n;i++){
        bfs(i);
    }
    cout<<res<<endl;
    return 0;
}
\\堆优化版的最短路
#include <bits/stdc++.h>
using namespace std;

#define ull unsigned long long;
#define pi 3.14;

typedef long long LL;
typedef pair<int,int> PII;
const int mod=1e9+7;
const double E=1e-8;
const int N=5001;
vector<vector<int>> arr;

int n,res=INT_MAX;
unordered_map<int,vector<int>> mp;
unordered_map<int,int> mp_2;

void solution(){

}

void dijstra(int start){
    if(res==1) return;
    vector<int> dist(n+1,0x3f3f3f3f);
    vector<bool> st(n+1);
    priority_queue<PII,vector<PII>,greater<PII>> q;
    q.push({0,start});
    while(q.size()){
        auto t=q.top();
        q.pop();

        int node=t.second,d=t.first;
        if(st[node]) continue;
        st[node]=true;

        for(auto &x:arr[node]){
            if(dist[x]>d+1){
                dist[x]=d+1;
                q.push({dist[x],x});
            }
        }
    }
    int c=mp_2[start];
    for(auto &x:mp[c]){
        if(x!=start) res=min(res,dist[x]);
        if(res==1) return;
    } 
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin>>n;
    arr=vector<vector<int>>(n+1);
    for(int i=1;i<=n;i++){
        int c;
        cin>>c;
        mp[c].push_back(i);
        mp_2[i]=c;
    }
    for(int i=0;i<n-1;i++){
        int a,b;
        cin>>a>>b;
        arr[a].push_back(b);
        arr[b].push_back(a);
    }
    for(int i=1;i<=n;i++){
        dijstra(i);
    }
    if(res==INT_MAX||res==0x3f3f3f3f) res=-1;
    cout<<res<<endl;
    return 0;
}
发表于 2022-04-04 21:05:47 回复(0)
链式前向星存储,空间复杂度O(N)
n个BFS,时间复杂度O(N2)
//极小连通子图=生成树 任意两点之间有且只有一条路径 且图中没有环
import java.util.*;
public class Main{
    public static void main(String[]args){
        Scanner sc=new Scanner(System.in);
        int n=sc.nextInt();
        int[]grade=new int[n+1];
        for(int i=0;i<n;i++)
            grade[i+1]=sc.nextInt();
        //链式前向星存储图
        int[][]edge=new int[2*n-2][2];    //边表,从0开始
        int[]head=new int[n+1];    //指向所连接的第一条边,从1开始
        Arrays.fill(head,-1);
        int cnt=0;    //边的数量
        for(int i=0;i<n-1;i++){
            int a=sc.nextInt();
            int b=sc.nextInt();
            //添加两条边
            edge[cnt][0]=b;
            edge[cnt][1]=head[a];
            head[a]=cnt++;
            edge[cnt][0]=a;
            edge[cnt][1]=head[b];
            head[b]=cnt++;
        }
        int ans=0x3fffffff;
        for(int i=1;i<=n;i++){    //计算i到其他所有节点的最短距离
            boolean[]visit=new boolean[n+1];
            Queue<int[]>que=new LinkedList<>();
            que.offer(new int[]{i,0});  
            visit[i]=true;
            while(!que.isEmpty()){
                int[]f=que.poll();
                int k=f[0];
                int h=f[1];
                if(k!=i && grade[k]==grade[i]){
                    ans=Math.min(ans,h);
                }
                int t=head[k];
                while(t!=-1){    //遍历所有i能到达的节点
                    int e=edge[t][0];
                    if(!visit[e]){
                        visit[e]=true;
                        que.offer(new int[]{e,h+1});
                    }
                    t=edge[t][1];
                }
            }
        }
        if(ans!=0x3fffffff)
            System.out.println(ans);
        else
            System.out.println(-1);
    }
}


发表于 2022-06-28 17:34:17 回复(0)
import collections
n = int(input())
edge = [[] for i in range(n)]
a=list(map(int, input().split()))
for i in range(n-1):
    x,y = map(int, input().split())
    edge[x-1].append(y-1)
    edge[y-1].append(x-1)
ans = n
d = collections.defaultdict(list)
for i in range(n):
    d[a[i]].append(i)
if len(d)==len(a):
    print(-1)
else:
    for k,v in d.items():
        if len(v)==1:
            continue
        else:
            for root in v:
                stack = [root]
                deep = 0
                visited=set()
                while stack:
                    tmp = []
                    while stack:
                        node = stack.pop()
                        visited.add(node)
                        if a[node]==a[root] and node!=root:
                            ans = min(ans, deep)
                            break
                        for i in edge[node]:
                            if i not in visited:
                                tmp.append(i)
                    deep+=1
                    stack = tmp.copy()
                    if deep>=ans:
                        break
    print(ans)

发表于 2021-08-02 11:05:56 回复(0)
5000的数据量直接dfs就行了。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
int n,a[5005],root,ans=1e9;
vector<int> e[5005];
void dfs(int r,int f,int deep)
{
    if(r!=root&&a[r]==a[root])
        ans=min(ans,deep);
    for(int i=0; i<e[r].size(); i++)
        if(e[r][i]!=f)
            dfs(e[r][i],r,deep+1);
}
int main()
{
    ios::sync_with_stdio(0),cin.tie(0);
    int i,j,x,y;
    cin>>n;
    for(i=1; i<=n; i++)
        cin>>a[i];
    for(i=1; i<n; i++)
    {
        cin>>x>>y;
        e[x].push_back(y);
        e[y].push_back(x);
    }
    for(i=1; i<=n; i++)
    {
        root=i;
        dfs(i,0,0);
    }
    cout<< (ans==1e9?-1:ans);
    return 0;
}

发表于 2021-05-02 18:55:37 回复(0)
题目给的复杂度很明显,我们需要遍历点对,点对满足等级相等的直接计算距离,套用LCA求最近公共祖先,用深度求距离即可。
//
// Created by SANZONG on 2020/8/10.
//
//lca模板

#include "bits/stdc++.h"

using namespace std;
const int N = 500005;
int cnt, head[N * 2];
int lg[N];
int f[N][50];
int depth[N];
struct node {
    int to, next;
} a[2 * N];
int lev[N];

void add(int u, int v) {
    a[++cnt].next = head[u];
    a[cnt].to = v;
    head[u] = cnt;
}

void dfs(int u, int fa) {
    f[u][0] = fa;
    depth[u] = depth[fa] + 1;
    for (int i = 1; i <= lg[depth[u]]; ++i)       //2^i = 2^(i-1)+2^(i-1)
        f[u][i] = f[f[u][i - 1]][i - 1];
    for (int i = head[u]; i; i = a[i].next) {
        int v = a[i].to;
        if (v == fa) continue;
        dfs(v, u);
    }
}

int LCA(int x, int y) {
    //先到同深度,再俩个一起找最近点。
    if (depth[x] < depth[y]) {
        swap(x, y);
    }
    while (depth[x] > depth[y]) {
        x = f[x][lg[depth[x] - depth[y]]];
    }
    if (x == y)
        return x;
    for (int k = lg[depth[x]]; k >= 0; --k) {
        if (f[x][k] != f[y][k]) {
            x = f[x][k];
            y = f[y][k];
        }
    }
    return f[x][0];
}

int main() {
    int n;
    cin >> n;
    for (int j = 1; j <= n; ++j) {
        cin >> lev[j];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        add(u, v);
        add(v, u);
    }
    for (int i = 0; i <= n; ++i) {
        lg[i] = i == 0 ? -1 : lg[i >> 1] + 1;
    }
    dfs(1, 0);
    int mi = 1e9;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j < i; ++j) {
            if (lev[i] == lev[j])
                mi = min(mi, depth[i] + depth[j] - 2 * depth[LCA(i, j)]);
        }
    }
    
    cout << (mi>=1e9?-1:mi) << endl;
}


发表于 2022-02-04 19:00:29 回复(1)
#include<iostream>
#include<bits/stdc++.h>
using namespace std;
//其实就是求深度(因为节点的长度是1)图论中的深度
vector<int> g[5005];
int rank_d[5005];
int ans = INT_MAX;
int root = 0;

void dfs(int cur,int before,int deep){
   
    if(cur!=root&&rank_d[cur]==rank_d[root]){
        ans = min(ans,deep);
    }
    //遍历节点下的所有元素
    for(int i=0;i<g[cur].size();i++){
       if(g[cur][i]!=before){
           //这要保证 不与上一个节点重复即可
           dfs(g[cur][i],cur,deep+1);
           
       }
    }
}
int main(){
    int n;
    cin>>n;
    //等级
    for(int i=1;i<=n;i++){
        int temp = 0;
        cin>>temp;
        rank_d[i] = temp;
     }
    //处理图
    for(int j=1;j<n;j++){
        int x,y;
        cin>>x>>y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    
    //深度遍历图的每一个节点
    for(int i=1;i<=n;i++){
        root  = i;
        dfs(i,0,0);
    }
    
    cout<<(ans==INT_MAX?-1:ans);
    
    return 0;
}

编辑于 2021-08-15 21:58:29 回复(0)
import java.util.*;

public class Niuke7 {
    public static void main(String[] args) {
        Scanner s = new Scanner(System.in);
        int n = s.nextInt();
        //节点
        int[] nodes = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            nodes[i] = s.nextInt();
        }
        //边,用邻接表表示图
        List<List<Integer>> edgs = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            edgs.add(new ArrayList<>());
        }
        for(int i = 0; i < n - 1; i++){
            int u = s.nextInt();
            int v = s.nextInt();
            //构造无向图,应该双向表示
            edgs.get(u).add(v);
            edgs.get(v).add(u);
        }
        int min = Integer.MAX_VALUE;
        for (int i = 1; i <= n; i++) {
            min = bfs(edgs, nodes, i) == -1? min : Math.min(min, bfs(edgs, nodes, i));
        }
        System.out.println(min == Integer.MAX_VALUE ? -1 : min);
    }
    public static int bfs(List<List<Integer>> edgs, int[] nodes, int i){
        boolean[] used = new boolean[nodes.length];
        Queue<int[]> que = new LinkedList<>();
        que.add(new int[]{i, 0});
        used[i] = true;
        while(!que.isEmpty()){
            int[] node = que.poll();
            int index = node[0];
            int path = node[1];
            for(int v : edgs.get(index)){
                if(used[v]){
                    continue;
                }
                if(nodes[v] == nodes[i]){
                    return path + 1;
                }
                used[v] = true;
                que.add(new int[]{v, path + 1});
            }
        }
        return -1;
    }
}

发表于 2022-07-19 16:13:12 回复(0)
简单的DFS
#include <bits/stdc++.h>
#include <climits>
#include <vector>
using namespace std;
vector<int> A;
vector<vector<int>> G;
int mintime=INT_MAX;
//从u结点出发找值value的点并返回最小步数,找不到就是INT_MAX
int dfs(int u,int par,int root){
    if(u!=root&&A[u]==A[root]) return 0;
    int steps=INT_MAX;
    for(int v:G[u]){
        if(v==par) continue;
        steps=min(steps,dfs(v,u,root));
    }
    if(steps==INT_MAX) return steps;
    else return steps+1;
}

int main() {
    int n;cin>>n;
    A.resize(n+1,0);
    G.resize(n+1,vector<int> (0));
    for(int i=1;i<=n;i++){
        cin>>A[i];
    }
    for(int i=1;i<=n-1;i++){
        int u,v;
        cin>>u>>v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    for(int i=1;i<=n;i++){
       mintime=min(mintime, dfs(i,0,i));
    }
    int res= mintime==INT_MAX?-1:mintime;
    cout<<res<<endl;
}
// 64 位输出请用 printf("%lld")


发表于 2023-04-16 19:48:27 回复(0)
BFS
#include <climits>
#include <iostream>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <queue>
using namespace std;
unordered_map<int,unordered_set<int>> grades;
int n=0;

int minPath(unordered_set<int>& points, vector<vector<int>>& edges){
    int res=INT_MAX;
    for(const int& point:points){
        vector<bool> visited(n+1,false);
        int minPath=0;
        queue<int> q;
        q.push(point);
        visited[point]=true;
        while(!q.empty()){
            int len=q.size();
            for(int i=0; i<len; ++i){
                int city=q.front();
                q.pop();
                for(int& next:edges[city]){
                    if(!visited[next]){
                        visited[next]=true;
                        if(points.find(next)!=points.end()){
                            q=queue<int>();
                            i=len;
                            break;
                        }
                        else
                            q.push(next);
                    }
                }
            }
            ++minPath;
        }
        res=min(res,minPath);
    }
    return res;
}

int main() {
    int res=INT_MAX, k=1;
    cin>>n;
    vector<vector<int>> edge(n+1,vector<int>());
    while(k<=n){
        int input;
        cin>>input;
        grades[input].insert(k++);
    }
    k=1;    
    while(k<n){
        int u,v;
        cin>>u>>v;
        edge[u].push_back(v);
        edge[v].push_back(u);
        k++;
    }
    for(auto& [_,points]:grades){
        int len=points.size();
        if(len>1){
            res=min(res,minPath(points,edge));
        }
    }
    res=res==INT_MAX ? -1 : res;
    cout<<res;
}
发表于 2023-03-19 16:26:37 回复(0)
BFS居然就过了
#include <vector>
#include <limits.h>
#include <iostream>
#include <queue>

std::vector<int> T[5000];
std::vector<int> V;
int n;

int main()
{
	std::cin >> n;
	for (int i = 0; i < n; i++)
	{
		int x;
		std::cin >> x;
		V.push_back(x);
	}

	for (int i = 0; i < n-1; i++)
	{
		int x,y;
		std::cin >> x >> y;
		T[x - 1].push_back(y - 1);
		T[y - 1].push_back(x - 1);
	}

	int min = INT_MAX;
	for (int i = 0; i < n-1; i++)
	{
		int target_v = V[i];
		std::vector<int> vis(n);
		vis[i] = true;
		std::fill(vis.begin(), vis.end(), false);
		std::queue<std::pair<int, int> > Q; // <node, step>
		Q.push(std::pair<int, int>(i, 0));
		while (!Q.empty())
		{
			std::pair<int, int> node = Q.front();
			Q.pop();
			if (V[node.first] == target_v && node.first != i)
			{
				min = std::min(min, node.second);
				break;
			}
			for (size_t j = 0; j < T[node.first].size(); j++)
			{
				if (!vis[T[node.first][j]])
				{
					vis[T[node.first][j]] = true;
					Q.push(std::pair<int, int>(T[node.first][j], node.second + 1));
				}
			}
		}
	}

	if (min == INT_MAX)
		std::cout << -1;
	else
		std::cout << min;

	return 0;
}


发表于 2022-09-08 21:24:17 回复(0)
def bfs(city:int, djs:list, bian:dict, n: int, dj:int):
    # 从城市city出发,广度优先搜索寻找dj的城市
    visited = [False] * n
    length = 0
    q = [(city, 0)]  # (城市,长度)
    visited[city] = True
    while len(q) > 0:
        length += 1  #
        city, s = q.pop(0)
        for i in bian[city]:
            if visited[i]:  # 已经访问过了
                continue
            if djs[i] == dj:
                return s+1
            visited[i] = True
            q.append((i, s+1))
    return n


def func(dengji:dict, djs:list, bian:dict, n:int):
    # dengji是[dengji, [城市编号,]] bian是(城市编号:[城市编号])
    length = n
    for dj, cities in dengji.items():
        if len(cities) <= 1:
            # 该等级只有一个城市
            continue
        for city in cities:  # 该等级有多个城市,则从城市city出发宽度优先搜索
            length = min(bfs(city, djs, bian, n, dj), length)
    return length


def main():
    # 城市编号居然是从1开始
    n = int(input())  # n个城市
    dengji = list(map(int, input().split()))  # n个城市的等级
    bian = {i: [] for i in range(n)}  # key是城市编号,value是一个list,表示相邻城市
    for _ in range(n-1):
        u, v = list(map(int, input().split()))
        bian[u-1].append(v-1)
        bian[v-1].append(u-1)
    dengji_dict = {}  # key是等级,item是城市编号列表
    for i, dj in enumerate(dengji):
        if dj in dengji_dict:
            dengji_dict[dj].append(i)
        else:
            dengji_dict[dj] = [i]
    if len(dengji_dict) == n:  # n个城市有n个等级,
        return -1
    return func(dengji_dict, dengji, bian, n)

print(main())

发表于 2022-05-11 12:26:25 回复(0)

结合以上大佬们给出的题解,遍历每一个节点,使用bfs寻找任意两节点的最短路径。题目中增加限制条件:节点要在同一级,所以使用领接表存储每个节点相连的节点。

import java.util.*;
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int[] grade = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            grade[i] = in.nextInt();
        }
        List> edges = new ArrayList();
        for (int i = 0; i <= n; i++) {
            edges.add(new ArrayList());
        }
        for (int i = 0; i < n - 1; i++) {
            int u = in.nextInt();
            int v = in.nextInt();
            // 双向边
            edges.get(u).add(v);
            edges.get(v).add(u);
        }
        int min = Integer.MAX_VALUE;
        for (int i = 1; i <= n; i++) {
            min = Math.min(min, bfs(edges, grade, i));
        }
        System.out.println(min == Integer.MAX_VALUE ? -1 : min);
    }
    private static int bfs(List> edges, int[] grade, int st) {
        Queue queue = new LinkedList();
        int n = grade.length;
        boolean[] marked = new boolean[n];
        queue.add(st);
        marked[st] = true;
        int cur = 0;
        while (!queue.isEmpty()) {
            int size = queue.size();
            cur++;
            while (size-- > 0) {
                int t = queue.poll();
                for (int x : edges.get(t)) {
                    if (marked[x]) continue;
                    if (grade[x] == grade[st])
                        return cur;
                    queue.add(x);
                    marked[x] = true;
                }
            }
        }
        return Integer.MAX_VALUE;
    }
}
发表于 2022-04-19 22:48:34 回复(0)

O(n^2)暴力

  • 计算以每个点为根的与根节点等级相同的节点的最小深度
  • 每次DFS复杂度为O(n), 共DFSn
#include <bits/stdc++.h>

using namespace std;
const int N = 5010;
int A[N];
vector<int> g[N];
int ans = INT_MAX;

void dfs(int u, int fa, int root, int dep){
    if (A[u] == A[root] and u != root) 
        ans = min(ans, dep);

    for (auto& v : g[u]) {
        if (v == fa)
            continue;
        dfs(v, u, root, dep + 1);
    }
}

void solve(){
    int n;
    cin >> n;
    for (int i = 1; i <= n; i ++ )
        cin >> A[i];
    for (int i = 1; i < n; i ++ ) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int i = 1; i <= n; i ++ )
        dfs(i, -1, i, 0);

    if (ans == INT_MAX)
        ans = -1;
    cout << ans << endl;
}

int main()
{   
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    int t;
    // cin >> t;
    t = 1;
    while (t -- )
        solve();
    return 0;
}
发表于 2021-10-19 10:46:34 回复(0)

基础DFS,JS版本

let readline = require('readline');
let rl = readline.createInterface({
  input: process.stdin,
  output: process.stdout
});
let n, level, grid;
let row = 1;
rl.on('line', (line) => {
  if (row === 1) {
    n = +line;
    grid = new Array(n + 1).fill('').map(() => { return new Array() });
  } else if (row === 2) {
    level = line.split(' ').map((item, index) => { return +item });
  } else if (row > 2 && row <= 2 + n - 1) {
    let [x, y] = line.split(' ').map((item, index) => { return +item });
    grid[x].push(y);
    grid[y].push(x);
  }
  if (row === 2 + n - 1) {
    let res = leastCost(n);
    console.log(res);
    rl.close();
  }
  row++;

  function leastCost(n) {
    let root, ans = Infinity;
    // console.log(grid);
    for (let i = 1; i <= n; i++) {
      root = i;
      dfs(i, 0, 0);
    }
    if(ans===Infinity){
      return -1;
    }
    return ans;

    /**
     * 
     * @param {*} r 当前dfs的出发点
     * @param {*} f 上一个访问过的点
     * @param {*} deep 此时的遍历深度
     */
    function dfs(r, f, deep) {
      if (r !== root && level[r - 1] === level[root - 1]) {
        ans = Math.min(ans, deep);
      }
      for (let i = 0; i < grid[r].length; i++) {
        if (grid[r][i] !== f) {
          dfs(grid[r][i], r, deep + 1);
        }
      }
    }
  }
});
发表于 2021-09-05 14:20:08 回复(0)
import java.util.*;

public class Main {
    
    private static int N, root, res;
    private static int[] A;
    private static Map<Integer, List<Integer>> g;
    private static boolean[] vis;
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        // cite
        N = sc.nextInt();

        // level
        A = new int[N + 1];
        for (int i = 1; i <= N; i++) {
            A[i] = sc.nextInt();
        }

        // graph
        g = new HashMap<>();
        for (int i = 1; i < N; i++) {
            int a = sc.nextInt(), b = sc.nextInt();
            g.computeIfAbsent(a, k -> new ArrayList<>()).add(b);
            g.computeIfAbsent(b, k -> new ArrayList<>()).add(a);
        }

        // calculate
        res = 0x3f3f3f3f;
        for (int i = 1; i <= N; i++) {
            root = i;
            vis = new boolean[N + 1];
            dfs(i, 0);
        }

        System.out.println(res == 0x3f3f3f3f ? -1 : res);
    }
    
    public static void dfs(int cur, int level) {
        // pruning
        if (vis[cur]) return;

        // exit
        if (cur != root && A[cur] == A[root]) {
            res = Math.min(res, level);
            return;
        }

        vis[cur] = true;

        // dfs
        for (int next : g.get(cur)) {
            dfs(next, level + 1);
        }
    }
}

发表于 2021-08-13 15:11:52 回复(0)
import sys
sys.setrecursionlimit(100000)
 
 
n = int(input())
ar = [int(i) for i in input().split(" ")]
 
 
global ans
 
def dfs(root, fa):
    global ans
    tmp = {ar[root-1]:0}
    for u in tree[root]:
        if u == fa: continue
        son_data = dfs(u, root)
        for key in son_data:
            if key in tmp:
                ans = min(ans, son_data[key] + tmp[key])
                tmp[key] = min(tmp[key], son_data[key])
            else: tmp[key] = son_data[key]
    for key in tmp: tmp[key]+=1
    return tmp
 
 
tree = {}
for _ in range(n-1):
    [a, b] = [int(i) for i in input().split(" ")]
    if not (a in tree): tree[a] = []
    tree[a].append(b)
    if not (b in tree): tree[b] = []
    tree[b].append(a)
ans = 50001
dfs(1, -1)
if ans == 50001: print(-1)
else: print(ans)

发表于 2021-04-28 10:51:16 回复(0)

热门推荐

通过挑战的用户