首页 > 试题广场 >

树上异或

[编程题]树上异或
  • 热度指数:61 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解

小团有一棵树,这棵树有n个节点,编号为1-n。每个节点上有一个值a_i。1号节点为整棵树的根。

现在,小团给小美一个难题:小美每次可以操作一个节点x,将a_x变为,保持x所有的儿子不变,将x所有儿子的儿子a_y变为,保持所有的儿子的儿子的儿子不变,以此类推。

代表位运算异或。

小团希望小美用尽可能少的次数,将所有的a_i变为b_i,请帮助小美计算这个最少的次数。

数据保证在有限步数内,能够将所有的a_i变为b_i


输入描述:

输入第一行包含一个整数n,代表节点数。

接下来n-1行,每行两个整数u_i, v_i,代表树上的一条边

接下来一行,一共n个数,第i个数代表a_i

接下来一行,一共n个数,第i个数代表b_i



输出描述:

输出包含一行一个数,即小美的最少操作次数。

示例1

输入

3
1 2
2 3
4 5 1
5 5 1

输出

2

说明

小美需要操作两次,第一次操作1号节点,三个节点的权值变为5 5 0,第二次操作3号节点,三个节点的权值变为5 5 1


备注:

对于40%的数据,

对于100%的数据,

深度优先搜索,如果对应的节点值不相等,进行一次操作,并将操作记录下来,反映到后续所有对应节点。
#include <iostream>
#include <vector>
#include <bits/stdc++.h>
using namespace std;
map<long long,vector<long long>> tree;
map<long long,long long> Visit;
long long ans=0;
void dfs(long long node,vector<long long> &a,vector<long long> &b,int flag1,int  flag2)
{
    if(flag2==1)
    {
        
        if(a[node]==b[node])
        {
            ans++;
            int flag3=flag2;
            flag2=flag1;
            flag1=(flag3+1)%2;
        }
        else 
            swap(flag1,flag2);
    }
    else 
    {
        if(a[node]!=b[node])
        {
            ans++;
            int flag3=flag2;
            flag2=flag1;
            flag1=(flag3+1)%2;
        }
        else 
            swap(flag1,flag2);
    }
    for(int i=0;i<tree[node].size();i++)
    {
        if(Visit.count(tree[node][i])==0)
        {
            Visit[tree[node][i]]=1;
            dfs(tree[node][i],a,b,flag1,flag2);
        }
    }
}
int main() {
    int n;
    cin>>n;
    vector<long long > a(n+1,0);
    vector<long long>  b(n+1,0);
    for(int i=0;i<n-1;i++)//首先输入n-1个整数,树上的一条边
    {
        long long u,v;
        cin>>u>>v;
        tree[u].push_back(v);
        tree[v].push_back(u);
    }
    for(int i=1;i<=n;i++)
        cin>>a[i];
    for(int i=1;i<=n;i++)
        cin>>b[i];
    Visit[1]=1;
    dfs(1,a,b,0,0);
    cout<<ans;
}


编辑于 2023-04-25 13:31:16 回复(0)
利用字典储存二代子节点遍历,然后计算操作次数。只过了一个,没返回错误例子暂时也改不下去,希望有大佬能指出
from collections import defaultdict,deque
n=int(input())
dic=defaultdict(lambda:[])
for _ in range(n-1):
    a,b=map(int, input().strip().split())
    dic[a].append(b)
an=[0]+list(map(int, input().strip().split()))
bn=[0]+list(map(int, input().strip().split()))
dic1=defaultdict(lambda:[])
que=deque([1])
while que:
    node=que.popleft()
    for nxt in dic[node]:
        que.append(nxt)
        for j in dic[nxt]:
            dic1[node].append(j)
res=0

for i in range(1, n+1):
    if an[i]==bn[i]: continue
    while an[i]!=bn[i]:
        an[i]^=1
        res+=1
        que1=deque([i])
        while que1:
            node=que1.popleft()
            for nxt in dic1[node]:
                que1.append(nxt)
                an[nxt]^=1
print(res)

发表于 2022-03-29 18:35:02 回复(0)