现在小强想从一个城市走到另一个不同的城市,并且每条边经过至多一次,同时他还有一个要求,起点和终点城市可以任意选择,但是等级必须是相同的。
但是小强不喜欢走特别远的道路,所以他想知道时间花费最小是多少。
进阶:时间复杂度 $O(n^2 \log n)$,空间复杂度 $O(n)$
第一行一个正整数 $n$,含义如题面所述。
第二行 $n$ 个正整数 $A_i$,代表每个城市的等级。
接下来 $n-1$ 行每行两个正整数 $u, v$,代表一条无向边。
保证给出的图是一棵树。
$1 \le n \le 5000$。
$1 \le u, v \le n$。
$1 \le A_i \le 10^9$。
仅一行一个整数代表答案,如果无法满足要求,输出 $-1$。
输入例子:
3
1 2 1
1 2
2 3
输出例子:
2
import java.util.*;
import java.math.*;
/**
 * Finds the smallest number of edges between two distinct cities of equal
 * level in a tree.
 *
 * <p>Input (stdin): {@code n}, then {@code n} levels, then {@code n - 1}
 * undirected edges given as 1-based city pairs. Output (stdout): the minimum
 * distance, or {@code -1} when no two cities share a level.
 */
public class Main {
    /** Level of each city, 0-based. */
    static int[] level;
    /** Adjacency lists of the tree, 0-based. */
    static ArrayList<Integer>[] lists;
    /** Best distance found so far; {@code Integer.MAX_VALUE} means none yet. */
    static int res = Integer.MAX_VALUE;

    @SuppressWarnings("unchecked") // generic array creation for ArrayList<Integer>[] cannot be checked
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        level = new int[n];
        lists = new ArrayList[n];
        // Reset the shared result so that a second invocation starts clean.
        res = Integer.MAX_VALUE;
        for (int i = 0; i < n; i++) {
            level[i] = in.nextInt();
            lists[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; i++) {
            int x = in.nextInt() - 1;
            int y = in.nextInt() - 1;
            lists[x].add(y);
            lists[y].add(x);
        }
        // A distance of 1 cannot be improved, so stop as soon as it is reached.
        for (int start = 0; start < n && res > 1; start++) {
            res = Math.min(res, nearestSameLevel(start, n));
        }
        if (res == Integer.MAX_VALUE) {
            res = -1;
        }
        System.out.println(res);
    }

    /**
     * Breadth-first search from {@code start}.
     *
     * @return the distance to the closest other city with the same level, or
     *         {@code Integer.MAX_VALUE} when there is none
     */
    private static int nearestSameLevel(int start, int n) {
        int[] dist = new int[n];
        Arrays.fill(dist, -1);
        Deque<Integer> queue = new ArrayDeque<>();
        dist[start] = 0;
        queue.add(start);
        while (!queue.isEmpty()) {
            int cur = queue.poll();
            for (int next : lists[cur]) {
                if (dist[next] != -1) {
                    continue;
                }
                dist[next] = dist[cur] + 1;
                // BFS discovers cities in non-decreasing distance, so the first match is the nearest.
                if (level[next] == level[start]) {
                    return dist[next];
                }
                queue.add(next);
            }
        }
        return Integer.MAX_VALUE;
    }
}
// Caption carried over from the page for the snippet that follows: "BFS, accepted".
#include <bits/stdc++.h>
using namespace std;
#define ull unsigned long long;
#define pi 3.14;
typedef long long LL;
typedef pair<int,int> PII;
const int mod=1e9+7;
const double E=1e-8;
const int N=5001;
// Adjacency lists of the tree; cities are numbered from 1.
vector<vector<int>> arr;
// n: number of cities. res: smallest distance found so far (INT_MAX = none yet).
int n,res=INT_MAX;
// City number -> level.
unordered_map<int,int> mp_2;
// Empty placeholder left over from the author's template.
void solution(){
}
// Breadth-first search from `start`. The first other city reached whose level
// equals that of `start` is the nearest one; its distance is folded into `res`.
// Does nothing once `res` is already 1, the smallest possible answer.
void bfs(int start){
    if(res==1) return;
    vector<bool> seen(n+1);
    vector<int> steps(n+1);
    queue<int> pending;
    seen[start]=true;
    steps[start]=0;
    pending.push(start);
    while(!pending.empty()){
        const int cur=pending.front();
        pending.pop();
        if(cur!=start&&mp_2[cur]==mp_2[start]){
            res=min(res,steps[cur]);
            return;
        }
        for(int nb:arr[cur]){
            if(seen[nb]) continue;
            seen[nb]=true;
            steps[nb]=steps[cur]+1;
            pending.push(nb);
        }
    }
}
// Reads the tree from stdin, runs one BFS per city and prints the smallest
// distance between two same-level cities, or -1 when there is none.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin>>n;
    arr=vector<vector<int>>(n+1);
    for(int city=1;city<=n;++city){
        int lvl;
        cin>>lvl;
        mp_2[city]=lvl;
    }
    for(int k=1;k<n;++k){
        int u,v;
        cin>>u>>v;
        arr[u].push_back(v);
        arr[v].push_back(u);
    }
    for(int start=1;start<=n;++start){
        bfs(start);
    }
    // INT_MAX is the "nothing found" sentinel set at declaration.
    if(res==INT_MAX){
        res=-1;
    }
    cout<<res<<endl;
    return 0;
}
//开静态数组不给过
#include <bits/stdc++.h>
using namespace std;
#define ull unsigned long long;
#define pi 3.14;
typedef long long LL;
typedef pair<int,int> PII;
const int mod=1e9+7;
const double E=1e-8;
const int N=5001;
// Forward-star adjacency storage. Every undirected edge is stored twice (once
// per direction), so the edge arrays need 2*N slots; with only N slots they
// overflow as soon as n exceeds about N/2.
// e[i]: target of edge i; ne[i]: next edge out of the same city; h[v]: first
// edge out of city v (-1 = none); dist[v]: BFS distance; idx: next free slot.
int e[2*N],ne[2*N],h[N],dist[N],idx;
// st[v]: city v already queued in the current BFS.
bool st[N];
// n: number of cities. res: smallest distance found so far (INT_MAX = none yet).
int n,res=INT_MAX;
// City number -> level.
unordered_map<int,int> mp_2;
// Appends the directed edge a -> b to the front of a's edge list.
void add(int a,int b){
    e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
// Empty placeholder left over from the author's template.
void solution(){
}
// Breadth-first search from `start`; folds the distance to the nearest other
// city of the same level into `res`. Does nothing once `res` is already 1.
void bfs(int start){
    // Skip the two memsets entirely when the answer cannot improve.
    if(res==1) return;
    memset(st,0,sizeof st);
    memset(dist,0,sizeof dist);
    queue<int> q;
    q.push(start);
    dist[start]=0;
    st[start]=true;
    while(q.size()){
        auto t=q.front();
        q.pop();
        if(t!=start&&mp_2[t]==mp_2[start]){
            res=min(res,dist[t]);
            break;
        }
        for(int i=h[t];~i;i=ne[i]){
            int j=e[i];
            if(!st[j]){
                dist[j]=dist[t]+1;
                q.push(j);
                st[j]=true;
            }
        }
    }
}
// Reads the tree from stdin, runs one BFS per city and prints the smallest
// distance between two same-level cities, or -1 when there is none.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    idx=0;
    memset(h,-1,sizeof h);
    cin>>n;
    for(int i=1;i<=n;i++){
        int c;
        cin>>c;
        mp_2[i]=c;
    }
    for(int i=0;i<n-1;i++){
        int a,b;
        cin>>a>>b;
        add(a,b);
        add(b,a);
    }
    for(int i=1;i<=n;i++){
        bfs(i);
    }
    // Without this the INT_MAX sentinel itself was printed when no two cities
    // share a level; the problem asks for -1 in that case.
    if(res==INT_MAX) res=-1;
    cout<<res<<endl;
    return 0;
}
\\堆优化版的最短路
#include <bits/stdc++.h>
using namespace std;
#define ull unsigned long long;
#define pi 3.14;
typedef long long LL;
typedef pair<int,int> PII;
const int mod=1e9+7;
const double E=1e-8;
const int N=5001;
// Adjacency lists of the tree; cities are numbered from 1.
vector<vector<int>> arr;
// n: number of cities. res: smallest distance found so far (INT_MAX = none yet).
int n,res=INT_MAX;
// Level -> cities having that level.
unordered_map<int,vector<int>> mp;
// City number -> level.
unordered_map<int,int> mp_2;
// Empty placeholder left over from the author's template.
void solution(){
}
// Heap-based shortest paths from `start` with unit edge weights, then folds
// the distance to every other city of the same level into `res`.
// Does nothing once `res` is already 1.
void dijstra(int start){
    if(res==1) return;
    vector<int> best(n+1,0x3f3f3f3f);
    vector<bool> done(n+1);
    priority_queue<PII,vector<PII>,greater<PII>> heap;
    heap.push({0,start});
    while(!heap.empty()){
        const PII top=heap.top();
        heap.pop();
        const int cur=top.second;
        const int d=top.first;
        if(done[cur]) continue;
        done[cur]=true;
        for(int nb:arr[cur]){
            if(best[nb]<=d+1) continue;
            best[nb]=d+1;
            heap.push({best[nb],nb});
        }
    }
    const int lvl=mp_2[start];
    for(int other:mp[lvl]){
        if(other!=start) res=min(res,best[other]);
        if(res==1) return;
    }
}
// Reads the tree from stdin, runs one shortest-path pass per city and prints
// the smallest distance between two same-level cities, or -1 when none exists.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin>>n;
    arr=vector<vector<int>>(n+1);
    for(int city=1;city<=n;++city){
        int lvl;
        cin>>lvl;
        mp[lvl].push_back(city);
        mp_2[city]=lvl;
    }
    for(int k=1;k<n;++k){
        int u,v;
        cin>>u>>v;
        arr[u].push_back(v);
        arr[v].push_back(u);
    }
    for(int start=1;start<=n;++start){
        dijstra(start);
    }
    // Both sentinels (initial INT_MAX and the 0x3f3f3f3f "infinite" distance) mean no pair.
    const bool nothingFound = (res==INT_MAX) || (res==0x3f3f3f3f);
    if(nothingFound) res=-1;
    cout<<res<<endl;
    return 0;
}
//极小连通子图=生成树 任意两点之间有且只有一条路径 且图中没有环
import java.util.*;
public class Main{
/**
 * Reads the tree from stdin, runs a full BFS from every city and prints the
 * smallest distance between two distinct cities of equal grade, or -1 when
 * no such pair exists.
 */
public static void main(String[]args){
    Scanner sc=new Scanner(System.in);
    int n=sc.nextInt();
    int[] grade=new int[n+1];
    for(int city=1;city<=n;city++){
        grade[city]=sc.nextInt();
    }
    // Forward-star storage: edge[k][0] is the target of edge k and edge[k][1]
    // the next edge leaving the same city; head[v] is the first edge of city v.
    int[][] edge=new int[2*n-2][2];
    int[] head=new int[n+1];
    Arrays.fill(head,-1);
    int cnt=0; // number of directed edges stored so far
    for(int k=1;k<n;k++){
        int a=sc.nextInt();
        int b=sc.nextInt();
        // a -> b
        edge[cnt][0]=b;
        edge[cnt][1]=head[a];
        head[a]=cnt++;
        // b -> a
        edge[cnt][0]=a;
        edge[cnt][1]=head[b];
        head[b]=cnt++;
    }
    final int INF=0x3fffffff;
    int ans=INF;
    for(int src=1;src<=n;src++){ // distances from src to every other city
        boolean[] seen=new boolean[n+1];
        Queue<int[]> pending=new LinkedList<>();
        pending.offer(new int[]{src,0});
        seen[src]=true;
        while(!pending.isEmpty()){
            int[] front=pending.poll();
            int node=front[0];
            int steps=front[1];
            if(node!=src && grade[node]==grade[src]){
                ans=Math.min(ans,steps);
            }
            for(int t=head[node];t!=-1;t=edge[t][1]){ // every edge leaving node
                int next=edge[t][0];
                if(seen[next]){
                    continue;
                }
                seen[next]=true;
                pending.offer(new int[]{next,steps+1});
            }
        }
    }
    System.out.println(ans!=INF ? ans : -1);
}
} import collections n = int(input()) edge = [[] for i in range(n)] a=list(map(int, input().split())) for i in range(n-1): x,y = map(int, input().split()) edge[x-1].append(y-1) edge[y-1].append(x-1) ans = n d = collections.defaultdict(list) for i in range(n): d[a[i]].append(i) if len(d)==len(a): print(-1) else: for k,v in d.items(): if len(v)==1: continue else: for root in v: stack = [root] deep = 0 visited=set() while stack: tmp = [] while stack: node = stack.pop() visited.add(node) if a[node]==a[root] and node!=root: ans = min(ans, deep) break for i in edge[node]: if i not in visited: tmp.append(i) deep+=1 stack = tmp.copy() if deep>=ans: break print(ans)
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
// n: number of cities; a[i]: level of city i (1-based); root: city the current
// traversal started from; ans: smallest distance found so far (1e9 = none yet).
int n,a[5005],root,ans=1e9;
// Adjacency lists of the tree, 1-based.
vector<int> e[5005];
// Walks the whole tree from `root`. r: current city, f: city we arrived from,
// deep: number of edges between root and r. Every other city whose level
// equals the root's level is folded into `ans`.
void dfs(int r,int f,int deep)
{
    const bool sameLevel = (a[r]==a[root]);
    if(r!=root&&sameLevel)
        ans=min(ans,deep);
    for(int next:e[r])
    {
        if(next==f)
            continue;
        dfs(next,r,deep+1);
    }
}
int main()
{
ios::sync_with_stdio(0),cin.tie(0);
int i,j,x,y;
cin>>n;
for(i=1; i<=n; i++)
cin>>a[i];
for(i=1; i<n; i++)
{
cin>>x>>y;
e[x].push_back(y);
e[y].push_back(x);
}
for(i=1; i<=n; i++)
{
root=i;
dfs(i,0,0);
}
cout<< (ans==1e9?-1:ans);
return 0;
}
//
// Created by SANZONG on 2020/8/10.
//
//lca模板
#include "bits/stdc++.h"
using namespace std;
const int N = 500005;
// Number of binary-lifting levels. Depths are below N < 2^19, so the largest
// level index used is 18; 20 leaves headroom. The previous width of 50 reserved
// about 100 MB for a table of which more than half could never be touched.
const int MAX_LOG = 20;
// cnt: number of directed edges stored; head[v]: first edge of v (0 = none).
int cnt, head[N * 2];
// lg[d]: floor(log2(d)); filled in by the caller before dfs()/LCA() are used.
int lg[N];
// f[v][k]: the 2^k-th ancestor of v (0 = above the root).
int f[N][MAX_LOG];
// depth[v]: depth of v, with the root at depth 1 and depth[0] == 0.
int depth[N];
// Forward-star edge record: target city and index of the next edge.
struct node {
    int to, next;
} a[2 * N];
// lev[v]: level of city v.
int lev[N];
// Appends the directed edge u -> v to the front of u's edge list.
void add(int u, int v) {
    a[++cnt].next = head[u];
    a[cnt].to = v;
    head[u] = cnt;
}
// Fills depth[] and the ancestor table for the subtree of u, whose parent is fa.
void dfs(int u, int fa) {
    f[u][0] = fa;
    depth[u] = depth[fa] + 1;
    // The 2^i-th ancestor is the 2^(i-1)-th ancestor of the 2^(i-1)-th ancestor.
    for (int i = 1; i <= lg[depth[u]]; ++i)
        f[u][i] = f[f[u][i - 1]][i - 1];
    for (int i = head[u]; i; i = a[i].next) {
        int v = a[i].to;
        if (v == fa) continue;
        dfs(v, u);
    }
}
// Returns the lowest common ancestor of x and y.
int LCA(int x, int y) {
    // Bring both to the same depth first, then lift them together.
    if (depth[x] < depth[y]) {
        swap(x, y);
    }
    while (depth[x] > depth[y]) {
        x = f[x][lg[depth[x] - depth[y]]];
    }
    if (x == y)
        return x;
    for (int k = lg[depth[x]]; k >= 0; --k) {
        // Jump only while the ancestors differ, so both stay below the LCA.
        if (f[x][k] != f[y][k]) {
            x = f[x][k];
            y = f[y][k];
        }
    }
    return f[x][0];
}
// Reads the tree from stdin, roots it at city 1 and prints the smallest
// distance over all pairs of same-level cities, or -1 when none exists.
int main() {
    int n;
    cin >> n;
    for (int city = 1; city <= n; ++city) {
        cin >> lev[city];
    }
    for (int k = 1; k < n; ++k) {
        int u, v;
        cin >> u >> v;
        add(u, v);
        add(v, u);
    }
    // lg[d] = floor(log2(d)), with lg[0] = -1 as the recurrence seed.
    lg[0] = -1;
    for (int d = 1; d <= n; ++d) {
        lg[d] = lg[d >> 1] + 1;
    }
    dfs(1, 0);
    int mi = 1e9;
    for (int i = 2; i <= n; ++i) {
        for (int j = 1; j < i; ++j) {
            if (lev[i] != lev[j]) continue;
            // Path length between i and j through their lowest common ancestor.
            const int through = depth[LCA(i, j)];
            mi = min(mi, depth[i] + depth[j] - 2 * through);
        }
    }
    cout << (mi>=1e9?-1:mi) << endl;
}
#include<iostream>
#include<bits/stdc++.h>
using namespace std;
//其实就是求深度(因为节点的长度是1)图论中的深度
// Adjacency lists of the tree, 1-based.
vector<int> g[5005];
// rank_d[i]: level of city i.
int rank_d[5005];
// Smallest distance found so far; INT_MAX means none yet.
int ans = INT_MAX;
// City the current traversal started from.
int root = 0;
// Walks the whole tree from `root`. cur: current city, before: city we arrived
// from, deep: number of edges between root and cur. Every other city whose
// level equals the root's level is folded into `ans`.
void dfs(int cur,int before,int deep){
    const bool isOtherCity = (cur!=root);
    if(isOtherCity&&rank_d[cur]==rank_d[root]){
        ans = min(ans,deep);
    }
    // In a tree it is enough to avoid stepping straight back to the previous city.
    for(int next:g[cur]){
        if(next==before){
            continue;
        }
        dfs(next,cur,deep+1);
    }
}
// Reads the tree from stdin, runs one traversal per city and prints the
// smallest distance between two same-level cities, or -1 when none exists.
int main(){
int n;
cin>>n;
// levels of the cities, 1-based
for(int i=1;i<=n;i++){
int temp = 0;
cin>>temp;
rank_d[i] = temp;
}
// build the adjacency lists; every edge is stored in both directions
for(int j=1;j<n;j++){
int x,y;
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
}
// depth-first traversal starting from every city in turn
for(int i=1;i<=n;i++){
root = i;
dfs(i,0,0);
}
// INT_MAX is the "nothing found" sentinel; note that no newline is printed
cout<<(ans==INT_MAX?-1:ans);
return 0;
} import java.util.*; public class Niuke7 { public static void main(String[] args) { Scanner s = new Scanner(System.in); int n = s.nextInt(); //节点 int[] nodes = new int[n + 1]; for (int i = 1; i <= n; i++) { nodes[i] = s.nextInt(); } //边,用邻接表表示图 List<List<Integer>> edgs = new ArrayList<>(); for (int i = 0; i <= n; i++) { edgs.add(new ArrayList<>()); } for(int i = 0; i < n - 1; i++){ int u = s.nextInt(); int v = s.nextInt(); //构造无向图,应该双向表示 edgs.get(u).add(v); edgs.get(v).add(u); } int min = Integer.MAX_VALUE; for (int i = 1; i <= n; i++) { min = bfs(edgs, nodes, i) == -1? min : Math.min(min, bfs(edgs, nodes, i)); } System.out.println(min == Integer.MAX_VALUE ? -1 : min); } public static int bfs(List<List<Integer>> edgs, int[] nodes, int i){ boolean[] used = new boolean[nodes.length]; Queue<int[]> que = new LinkedList<>(); que.add(new int[]{i, 0}); used[i] = true; while(!que.isEmpty()){ int[] node = que.poll(); int index = node[0]; int path = node[1]; for(int v : edgs.get(index)){ if(used[v]){ continue; } if(nodes[v] == nodes[i]){ return path + 1; } used[v] = true; que.add(new int[]{v, path + 1}); } } return -1; } }
package codeTest.alibaba;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Scanner;
import java.util.Stack;

/*
 Problem statement:
 A region has n cities and n - 1 undirected edges, each taking 1 unit of time
 to traverse, and the cities are connected, i.e. the region forms a tree.
 Every city has a level.
 Xiaoqiang wants to travel from one city to a different city, using each edge
 at most once. Start and end may be chosen freely, but must have the same level.
 He dislikes long trips, so he wants to know the minimum time needed.
 Advanced: time complexity O(n^2 log n), space complexity O(n).
 Input:
   line 1: a positive integer n;
   line 2: n positive integers Ai, the level of each city;
   next n - 1 lines: two positive integers u, v, an undirected edge.
   The graph is guaranteed to be a tree.
   1 <= n <= 5000, 1 <= u, v <= n, 1 <= Ai <= 10^9.
 Output:
   a single integer, the answer, or -1 if the requirement cannot be met.
 Example input:
   3
   1 2 1
   1 2
   2 3
 Example output:
   2
 */

/** A city in the rooted tree: its number, its level, its parent and its children. */
class TreeNode {
    int val;
    int rank;
    TreeNode parent;
    List<TreeNode> children;

    public TreeNode() {
    }

    public TreeNode(int val, int rank) {
        this.val = val;
        this.rank = rank;
        this.parent = null;
        this.children = new ArrayList<>();
    }

    /** Prints this node and, recursively, every node below it (debug helper). */
    public void printChildren() {
        StringBuilder sb = new StringBuilder();
        sb.append("node: ").append(val)
                .append(" rank: ").append(rank)
                .append(" parent: ").append(parent == null ? "null" : parent.val)
                .append(" children:");
        for (TreeNode child : children) {
            sb.append(" ").append(child.val);
        }
        System.out.println(sb.toString());
        for (TreeNode child : children) {
            child.printChildren();
        }
    }
}

/** Lowest-common-ancestor and distance helpers over {@link TreeNode}. */
class Tree {
    TreeNode root;

    public Tree() {
    }

    public Tree(TreeNode root) {
        this.root = root;
    }

    public void printTree() {
        root.printChildren();
    }

    /** Finds the lowest common ancestor of two nodes using child links only. */
    public static TreeNode findLCAWithoutParent(TreeNode root, TreeNode node1, TreeNode node2) {
        if (root == null || root == node1 || root == node2) {
            return root;
        }
        TreeNode found = null;
        int count = 0;
        for (TreeNode child : root.children) {
            TreeNode temp = findLCAWithoutParent(child, node1, node2);
            if (temp != null) {
                found = temp;
                count++;
            }
        }
        // Both targets were found in different subtrees: this node is the LCA.
        if (count == 2) {
            return root;
        }
        return found;
    }

    /** Returns {@code depth} plus the distance from {@code root} down to {@code target}, or -1 if absent. */
    public static int distanceToRootWithoutParent(TreeNode root, TreeNode target, int depth) {
        if (root == null) {
            return -1;
        }
        if (root == target) {
            return depth;
        }
        for (TreeNode child : root.children) {
            int dist = distanceToRootWithoutParent(child, target, depth + 1);
            if (dist != -1) {
                return dist;
            }
        }
        return -1;
    }

    public static int shortestPathWithoutParent(TreeNode root, TreeNode node1, TreeNode node2) {
        TreeNode lca = findLCAWithoutParent(root, node1, node2);
        int dist1 = distanceToRootWithoutParent(lca, node1, 0);
        int dist2 = distanceToRootWithoutParent(lca, node2, 0);
        return dist1 + dist2;
    }

    /** Finds the lowest common ancestor of two nodes by comparing their root paths. */
    public static TreeNode findLCA(TreeNode node1, TreeNode node2) {
        Deque<TreeNode> path1 = new ArrayDeque<>();
        Deque<TreeNode> path2 = new ArrayDeque<>();
        // Path from node1 up to the root; the root ends up on top of the stack.
        for (TreeNode temp = node1; temp != null; temp = temp.parent) {
            path1.push(temp);
        }
        // Path from node2 up to the root.
        for (TreeNode temp = node2; temp != null; temp = temp.parent) {
            path2.push(temp);
        }
        TreeNode lca = null;
        // Walk down from the root while both paths agree.
        while (!path1.isEmpty() && !path2.isEmpty() && path1.peek() == path2.peek()) {
            lca = path1.pop();
            path2.pop();
        }
        return lca;
    }

    /** Counts the parent links between {@code node} and its ancestor {@code root}. */
    public static int distanceToRoot(TreeNode node, TreeNode root) {
        int dist = 0;
        TreeNode temp = node;
        while (temp != root) {
            dist++;
            temp = temp.parent;
        }
        return dist;
    }

    public static int shortestPath(TreeNode node1, TreeNode node2) {
        TreeNode lca = findLCA(node1, node2);
        int dist1 = distanceToRoot(node1, lca);
        int dist2 = distanceToRoot(node2, lca);
        return dist1 + dist2;
    }
}

public class ShortestChainInTree {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();

        TreeNode[] nodes = new TreeNode[n + 1];
        List<List<Integer>> adjacency = new ArrayList<>(n + 1);
        adjacency.add(new ArrayList<>()); // index 0 is unused
        for (int i = 1; i <= n; i++) {
            int rank = sc.nextInt();
            nodes[i] = new TreeNode(i, rank);
            adjacency.add(new ArrayList<>());
        }

        // The edges are undirected: "u v" says nothing about which end is the
        // parent. Collect them first and orient the tree afterwards.
        for (int i = 1; i <= n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adjacency.get(u).add(v);
            adjacency.get(v).add(u);
        }
        if (n >= 1) {
            orientFromRoot(nodes, adjacency, 1);
        }

        int minDist = Integer.MAX_VALUE;
        for (int i = 1; i <= n; i++) {
            for (int j = i + 1; j <= n; j++) {
                if (nodes[i].rank == nodes[j].rank) {
                    int dist = Tree.shortestPath(nodes[i], nodes[j]);
                    minDist = Math.min(minDist, dist);
                }
            }
        }

        if (minDist == Integer.MAX_VALUE) {
            System.out.println(-1);
        } else {
            System.out.println(minDist);
        }
    }

    /** Assigns parent/children links by a breadth-first walk from {@code rootIndex}. */
    private static void orientFromRoot(TreeNode[] nodes, List<List<Integer>> adjacency, int rootIndex) {
        boolean[] seen = new boolean[nodes.length];
        Deque<Integer> queue = new ArrayDeque<>();
        seen[rootIndex] = true;
        queue.add(rootIndex);
        while (!queue.isEmpty()) {
            int u = queue.poll();
            for (int v : adjacency.get(u)) {
                if (seen[v]) {
                    continue;
                }
                seen[v] = true;
                nodes[v].parent = nodes[u];
                nodes[u].children.add(nodes[v]);
                queue.add(v);
            }
        }
    }
}
我这个方法为啥过不了所有的样例?
#include <bits/stdc++.h>
#include <climits>
#include <vector>
using namespace std;
// A[i]: level of city i (1-based).
vector<int> A;
// Adjacency lists of the tree, 1-based.
vector<vector<int>> G;
// Smallest distance found so far; INT_MAX means none yet.
int mintime=INT_MAX;
// Starting at u (arrived from par), returns the number of edges to the nearest
// city other than `root` whose level equals root's level, searching away from
// par only. Returns INT_MAX when that part of the tree holds no such city.
int dfs(int u,int par,int root){
    if(u!=root&&A[u]==A[root]) return 0;
    int best=INT_MAX;
    for(size_t k=0;k<G[u].size();++k){
        const int v=G[u][k];
        if(v!=par){
            best=min(best,dfs(v,u,root));
        }
    }
    // Keep the sentinel as it is: adding 1 to INT_MAX would overflow.
    return best==INT_MAX ? INT_MAX : best+1;
}
// Reads the tree from stdin, searches from every city and prints the smallest
// distance between two same-level cities, or -1 when none exists.
int main() {
    int n;
    cin>>n;
    A.resize(n+1,0);
    G.resize(n+1,vector<int> (0));
    for(int city=1;city<=n;++city){
        cin>>A[city];
    }
    for(int k=1;k<n;++k){
        int u,v;
        cin>>u>>v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    for(int start=1;start<=n;++start){
        const int nearest=dfs(start,0,start);
        mintime=min(mintime,nearest);
    }
    // INT_MAX is the "nothing found" sentinel.
    cout<<(mintime==INT_MAX ? -1 : mintime)<<endl;
}
// 64 位输出请用 printf("%lld")
#include <vector>
#include <limits.h>
#include <iostream>
#include <queue>
std::vector<int> T[5000];
std::vector<int> V;
int n;
// Reads the tree from stdin (cities are converted to 0-based indices), runs a
// BFS from the cities and prints the smallest distance between two cities of
// equal level, or -1 when no such pair exists.
int main()
{
std::cin >> n;
// V[i]: level of city i
for (int i = 0; i < n; i++)
{
int x;
std::cin >> x;
V.push_back(x);
}
// T[i]: neighbours of city i; every edge is stored in both directions
for (int i = 0; i < n-1; i++)
{
int x,y;
std::cin >> x >> y;
T[x - 1].push_back(y - 1);
T[y - 1].push_back(x - 1);
}
int min = INT_MAX;
// NOTE(review): the bound is n-1, so the last city is never used as a BFS
// start. Any pair that involves it should still be found from the other
// endpoint, so this looks harmless -- confirm before relying on it.
for (int i = 0; i < n-1; i++)
{
int target_v = V[i];
std::vector<int> vis(n);
// NOTE(review): the fill below resets every entry, including the mark set
// on this line, so the start city is not actually marked as visited. It can
// be queued again later, but the `node.first != i` test keeps it from being
// counted as a match.
vis[i] = true;
std::fill(vis.begin(), vis.end(), false);
std::queue<std::pair<int, int> > Q; // <node, step>
Q.push(std::pair<int, int>(i, 0));
while (!Q.empty())
{
std::pair<int, int> node = Q.front();
Q.pop();
// the first other city with the same level that leaves the queue is the nearest one
if (V[node.first] == target_v && node.first != i)
{
min = std::min(min, node.second);
break;
}
for (size_t j = 0; j < T[node.first].size(); j++)
{
if (!vis[T[node.first][j]])
{
vis[T[node.first][j]] = true;
Q.push(std::pair<int, int>(T[node.first][j], node.second + 1));
}
}
}
}
// INT_MAX is the "nothing found" sentinel; note that no newline is printed
if (min == INT_MAX)
std::cout << -1;
else
std::cout << min;
return 0;
} def bfs(city:int, djs:list, bian:dict, n: int, dj:int):
# 从城市city出发,广度优先搜索寻找dj的城市
visited = [False] * n
length = 0
q = [(city, 0)] # (城市,长度)
visited[city] = True
while len(q) > 0:
length += 1 #
city, s = q.pop(0)
for i in bian[city]:
if visited[i]: # 已经访问过了
continue
if djs[i] == dj:
return s+1
visited[i] = True
q.append((i, s+1))
return n
def func(dengji:dict, djs:list, bian:dict, n:int):
    """Return the smallest distance between two cities that share a level.

    Args:
        dengji: mapping level -> list of cities having that level.
        djs: level of every city, indexed by city.
        bian: adjacency mapping, city -> list of neighbouring cities.
        n: number of cities.

    Returns:
        The minimum over all searches, or ``n`` when no level occurs twice.
    """
    length = n
    for dj, cities in dengji.items():
        if len(cities) <= 1:
            # Only one city has this level: nothing to pair it with.
            continue
        # Several cities share this level: search outwards from each of them.
        for city in cities:
            length = min(bfs(city, djs, bian, n, dj), length)
    return length
def main():
    """Read the tree from stdin and return the answer (-1 when no pair exists)."""
    # City numbers in the input start at 1; they are shifted to 0-based here.
    n = int(input())  # number of cities
    dengji = list(map(int, input().split()))  # level of each city
    bian = {i: [] for i in range(n)}  # city -> list of neighbouring cities
    for _ in range(n-1):
        u, v = list(map(int, input().split()))
        bian[u-1].append(v-1)
        bian[v-1].append(u-1)
    dengji_dict = {}  # level -> list of cities having that level
    for i, dj in enumerate(dengji):
        if dj in dengji_dict:
            dengji_dict[dj].append(i)
        else:
            dengji_dict[dj] = [i]
    if len(dengji_dict) == n:
        # n cities with n distinct levels: no two cities share a level.
        return -1
    return func(dengji_dict, dengji, bian, n)
print(main())
结合以上大佬们给出的题解,遍历每一个节点,使用 BFS 寻找与它等级相同的最近节点,取所有结果中的最小值。题目中增加了限制条件:起点和终点的等级必须相同;图用邻接表存储每个节点相连的节点。
import java.util.*;
/**
 * Prints the smallest distance between two distinct cities of equal grade in
 * a tree, or -1 when no two cities share a grade.
 *
 * <p>Input (stdin): {@code n}, then {@code n} grades, then {@code n - 1}
 * undirected edges given as 1-based city pairs.
 */
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int[] grade = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            grade[i] = in.nextInt();
        }
        // The generic type arguments were lost when the snippet was pasted
        // ("List> edges"); they are restored here so the class compiles.
        List<List<Integer>> edges = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            edges.add(new ArrayList<>());
        }
        for (int i = 0; i < n - 1; i++) {
            int u = in.nextInt();
            int v = in.nextInt();
            // Undirected edge: record both directions.
            edges.get(u).add(v);
            edges.get(v).add(u);
        }
        int min = Integer.MAX_VALUE;
        for (int i = 1; i <= n; i++) {
            min = Math.min(min, bfs(edges, grade, i));
        }
        System.out.println(min == Integer.MAX_VALUE ? -1 : min);
    }

    /**
     * Level-order search from city {@code st}.
     *
     * @param edges adjacency lists, indexed by city number
     * @param grade grade of each city, indexed by city number
     * @param st    start city
     * @return distance to the nearest other city with the same grade, or
     *         {@code Integer.MAX_VALUE} when there is none
     */
    private static int bfs(List<List<Integer>> edges, int[] grade, int st) {
        Queue<Integer> queue = new ArrayDeque<>();
        boolean[] marked = new boolean[grade.length];
        queue.add(st);
        marked[st] = true;
        int cur = 0; // distance of the cities discovered in the current round
        while (!queue.isEmpty()) {
            int size = queue.size();
            cur++;
            while (size-- > 0) {
                int t = queue.poll();
                for (int x : edges.get(t)) {
                    if (marked[x]) {
                        continue;
                    }
                    if (grade[x] == grade[st]) {
                        return cur;
                    }
                    queue.add(x);
                    marked[x] = true;
                }
            }
        }
        return Integer.MAX_VALUE;
    }
}
每次 DFS 为 O(n),共 DFS n 次,总时间复杂度 O(n^2)
#include <bits/stdc++.h>
using namespace std;
const int N = 5010;
// A[i]: level of city i (1-based).
int A[N];
// Adjacency lists of the tree, 1-based.
vector<int> g[N];
// Smallest distance found so far; INT_MAX means none yet.
int ans = INT_MAX;
// Walks the whole tree from `root`. u: current city, fa: city we arrived from,
// dep: number of edges between root and u. Every other city whose level equals
// the root's level is folded into `ans`.
void dfs(int u, int fa, int root, int dep){
    const bool matches = (u != root) && (A[u] == A[root]);
    if (matches) {
        ans = min(ans, dep);
    }
    for (size_t k = 0; k < g[u].size(); ++k) {
        const int v = g[u][k];
        if (v != fa) {
            dfs(v, u, root, dep + 1);
        }
    }
}
// Reads one tree from stdin, runs one traversal per city and prints the
// smallest distance between two same-level cities, or -1 when none exists.
void solve(){
    int n;
    cin >> n;
    for (int city = 1; city <= n; ++city) {
        cin >> A[city];
    }
    for (int k = 1; k < n; ++k) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int start = 1; start <= n; ++start) {
        dfs(start, -1, start, 0);
    }
    // INT_MAX is the "nothing found" sentinel set at declaration.
    if (ans == INT_MAX) {
        ans = -1;
    }
    cout << ans << endl;
}
// Entry point: speeds up iostream and solves a single test case.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    // The input holds exactly one test case, so the count is fixed rather than read.
    int t = 1;
    for (; t > 0; --t) {
        solve();
    }
    return 0;
}
基础DFS,JS版本
// Line-by-line reader over stdin.
let readline = require('readline');
let rl = readline.createInterface({
    input: process.stdin,
    output: process.stdout
});
// n: number of cities; level[i]: level of city i + 1; grid[c]: neighbours of city c (1-based).
let n, level, grid;
// 1-based number of the input line currently being handled.
let row = 1;

/**
 * Runs a depth-first traversal from every city and returns the smallest
 * distance between two distinct cities of equal level, or -1 when none exists.
 */
function leastCost(n) {
    let root;
    let ans = Infinity;

    /**
     * @param {*} r    city being visited
     * @param {*} f    city visited immediately before r
     * @param {*} deep number of edges between root and r
     */
    const dfs = (r, f, deep) => {
        if (r !== root && level[r - 1] === level[root - 1]) {
            ans = Math.min(ans, deep);
        }
        for (const next of grid[r]) {
            if (next !== f) {
                dfs(next, r, deep + 1);
            }
        }
    };

    for (let start = 1; start <= n; start++) {
        root = start;
        dfs(start, 0, 0);
    }
    return ans === Infinity ? -1 : ans;
}

rl.on('line', (line) => {
    const lastRow = () => 2 + n - 1; // number of the final input line
    if (row === 1) {
        n = +line;
        grid = new Array(n + 1).fill('').map(() => []);
    } else if (row === 2) {
        level = line.split(' ').map((item) => +item);
    } else if (row > 2 && row <= lastRow()) {
        const [x, y] = line.split(' ').map((item) => +item);
        grid[x].push(y);
        grid[y].push(x);
    }
    // Once the last edge (or, for n === 1, the level line) is in, answer and stop reading.
    if (row === lastRow()) {
        const res = leastCost(n);
        console.log(res);
        rl.close();
    }
    row++;
});
import java.util.*;
public class Main {
private static int N, root, res;
private static int[] A;
private static Map<Integer, List<Integer>> g;
private static boolean[] vis;
/**
 * Reads the tree from stdin, runs one depth-first search per city and prints
 * the smallest distance between two cities of equal level, or -1 if none exists.
 */
public static void main(String[] args) {
    Scanner sc = new Scanner(System.in);
    // number of cities
    N = sc.nextInt();
    // level of each city, 1-based
    A = new int[N + 1];
    for (int city = 1; city <= N; city++) {
        A[city] = sc.nextInt();
    }
    // adjacency lists, created lazily per city
    g = new HashMap<>();
    for (int k = 1; k < N; k++) {
        int a = sc.nextInt();
        int b = sc.nextInt();
        g.computeIfAbsent(a, key -> new ArrayList<>()).add(b);
        g.computeIfAbsent(b, key -> new ArrayList<>()).add(a);
    }
    // search from every city; 0x3f3f3f3f marks "nothing found yet"
    final int inf = 0x3f3f3f3f;
    res = inf;
    for (int start = 1; start <= N; start++) {
        root = start;
        vis = new boolean[N + 1];
        dfs(start, 0);
    }
    if (res == inf) {
        System.out.println(-1);
    } else {
        System.out.println(res);
    }
}
/**
 * Depth-first search away from {@code root}. Stops descending at the first
 * city on each branch whose level equals the root's level and folds its
 * distance into {@code res}.
 *
 * @param cur   city being visited
 * @param level number of edges between {@code root} and {@code cur}
 */
public static void dfs(int cur, int level) {
    // pruning: each city is expanded at most once per root
    if (vis[cur]) return;
    // exit: a matching city ends this branch
    if (cur != root && A[cur] == A[root]) {
        res = Math.min(res, level);
        return;
    }
    vis[cur] = true;
    // A city without any edge (the single-city input) has no entry in the map;
    // treat that as an empty neighbour list instead of dereferencing null.
    for (int next : g.getOrDefault(cur, Collections.emptyList())) {
        dfs(next, level + 1);
    }
}
} import sys
sys.setrecursionlimit(100000)
n = int(input())
ar = [int(i) for i in input().split(" ")]
global ans
def dfs(root, fa):
global ans
tmp = {ar[root-1]:0}
for u in tree[root]:
if u == fa: continue
son_data = dfs(u, root)
for key in son_data:
if key in tmp:
ans = min(ans, son_data[key] + tmp[key])
tmp[key] = min(tmp[key], son_data[key])
else: tmp[key] = son_data[key]
for key in tmp: tmp[key]+=1
return tmp
tree = {}
for _ in range(n-1):
[a, b] = [int(i) for i in input().split(" ")]
if not (a in tree): tree[a] = []
tree[a].append(b)
if not (b in tree): tree[b] = []
tree[b].append(a)
ans = 50001
dfs(1, -1)
if ans == 50001: print(-1)
else: print(ans)