洛谷P5002 专心OI - 找祖先

题目背景
Imakf是一个小蒟蒻,他最近刚学了LCA,他在手机APP里看到一个游戏也叫做LCA就下载了下来。

题目描述
这个游戏会给出你一棵树,这棵树有NN个节点,根结点是RR,系统会选中MM个点P_1,P_2…P_M ,要Imakf回答有多少组点对(u_i,v_i)的最近公共祖先是P_i 。Imakf是个小蒟蒻,他就算学了LCA也做不出,于是只好求助您了。Imakf毕竟学过一点OI,所以他允许您把答案模 (10^9+7)(
9 +7)

输入格式
第一行 N , R , M
此后N-1行 每行两个数a,ba,b 表示a,ba,b之间有一条边
此后1行 M个数 表示P_i
输出格式
M行,每行一个数,第ii行的数表示有多少组点对(u_i,v_i)的最近公共祖先是P_i

题解: 本来想写点求lca的题,看见这个标签有Lca就进来了

然而这题和求lca一点关系没有

简单地分析了一下题目,两个两个求lca再判断方法必然会tle,于是我们考虑一些其他的做法。废话

首先,两个点的lca如果是p的话,我们不难发现,如果有一个点不是p,那么这两个点必然位于p的两个不同子树中,否则就会有一个深度更深的爸爸取代他。

于是我们可以考虑,对于p的每一个子树,必有(size[p]-size[son])*size[son]个合法解,这个式子代表的是p的这棵子树中的每一个点和其他子树中的每一个点构成一对,因为不在一棵子树中,所以p必然是他们两个的lca。把这些加起来之后,p的子树中的任何一个点和p都可以构成一对,而由样例可以知道,p和p自己也是一对,所以再在上面的式子基础上加上size[p]。此题就结束了。 因为(u,v)(v,u)算两组解,所以这样必然是对的,上代码 (m比n大,所以可以先全算出来啦)

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cmath>
#define fo(i) for(int i=1;i<n;i++)
#define Fo(j) for(int j=1;j<=m;j++)
using namespace std;
const int N=500005;
const int M=2000003;
const int modd=1e9+7;
int n,m,s,tot;
int head[N];
int dep[N],ans[N];
struct NODE{
   
	int to,nex;
}bian[M];
struct Node{
   
    int son,dep,fa;
}dian[N];
inline int read(){
   
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
   if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){
   x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline void add(int x,int y){
   
	++tot;
	bian[tot].nex=head[x];
	bian[tot].to=y;
	head[x]=tot;
}
void dfs(int xx,int fa){
   
    dian[xx].son=1;
    for(int i=head[xx];i;i=bian[i].nex){
   
        if(bian[i].to!=fa){
   //这个儿子不是爹
             dian[bian[i].to].fa=xx;
             dian[bian[i].to].dep=dian[xx].dep+1;
             dfs(bian[i].to,xx);
             dian[xx].son+=(dian[bian[i].to].son%modd);
        }
    }
}
inline int solve(int xx){
   
     int ans=dian[xx].son;
     for(int i=head[xx];i;i=bian[i].nex){
   	
          if(dian[bian[i].to].dep>dian[xx].dep){
           	   
               ans=ans+(dian[xx].son-dian[bian[i].to].son)*dian[bian[i].to].son;
          }
     }
     return ans;
     
}
int main(){
   
    n=read();s=read();m=read();
    fo(i){
   
       int a,b;
       a=read();b=read();
       add(a,b); add(b,a);
    }
    dian[0].dep=0;
    dfs(s,0);
    for(int i=1;i<=n;i++) ans[i]=solve(i);
    Fo(j){
   
        int p;
        p=read();
        printf("%d\n",ans[p]);
    }
	return 0;
}

全部评论

相关推荐

点赞 评论 收藏
转发
点赞 收藏 评论
分享
牛客网
牛客企业服务