title: Cotree 2019CCPC-江西省赛-A题
题目大意:
给你两棵树,在这两棵树上分别找一个点,将其连接,使得\sum_{i=1}^{n-1}\sum_{j=i+1}^{n}dis(i,j) 最小,其中dis(i,j)表示从节点 i 到节点 j 的边数。
Part1
首先我们需要判断我们找的这两个点应该是哪个点,对于两棵树,他们的 dis 和是固定的,因此我们需要讨论将两个点连接起来所增加的花费。
假设需要连接的两棵树A,B,两棵树上进行连接的点为 u,v ,
点 u,v 到其所在子树其他点的距离之和为Dis_u,Dis_v ,A,B 上点的个数为 P_A,P_B,
那么将其连接后增加的 dis 值为:
Dis_u*P_B+Dis_v*P_A+P_A*P_B
很容易理解:
对于树 A 上的任意一个点 w ,我们需要将其和 B 上的所有点进行一次连接,等同于需要将dis(w,u) 重复计算 P_B 次,其他点同理,因此 A 树上增加的 dis 值为 Dis_u*P_B ,B树同理。
而对于刚建立的通道 dis(u,v)=1 被使用了 P_A*P_B 因此总的增加量即为上式。
P_A,P_B为定值,所以我们只需要最小化 Dis_u,Dis_v 即可。
Part2
现在的问题已经简化成了如何求一棵树上的 Dis 的最小值。
首先我们需要一遍dfs将树的根节点的 Dis 值找出来,找出来之后,我们就使用换根dp计算出树上所有节点的Dis值。
然后找出 Dis 最小的点,进行连接,再次进行上述操作即可。
最后将所有节点的 Dis 全部求和,由于这个值是求的双向的,因此需要除以2。
代码如下(代码有些冗长):
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int,int>
const ll INF=0x3f3f3f3f3f3f3f3f;
const int maxn=1000000+10;
const ll mod=1e9+7;
vector<ll>vec[maxn];
ll n;
ll dp[maxn];
ll num[maxn];
ll vis[maxn];
ll vis2[maxn];
ll value[maxn];
ll flag,point;
void init(){//清空数组
memset(dp,0,sizeof(dp));
memset(num,0,sizeof(num));
memset(vis,0,sizeof(vis));
memset(vis2,0,sizeof(vis2));
memset(value,0,sizeof(value));
}
ll dfs(ll x){//dfs找所有的值,num表示以这个点为根节点下面有几个点(包括这个点,dp用),value是这个点到所有子树的距离之和
vis[x]=flag;
ll sum=0;
for(ll i=0;i<vec[x].size();i++){
ll y=vec[x][i];
if(!vis[y]){
int q=dfs(y);
sum+=q;
value[x]+=value[y]+q;
}
}
num[x]=sum+1;
return sum+1;
}
void Dp(ll x){//dp数组表示当前这个点到其他点的距离之和
vis2[x]=flag;
for(ll i=0;i<vec[x].size();i++){
ll y=vec[x][i];
if(!vis2[y]){
dp[y]=dp[x]-num[y]+point-num[y];//换根dp方程
Dp(y);
}
}
}
void solve(){
for(int i=1;i<=n;i++) vec[i].clear();
init();
ll u,v;
for(ll i=0;i<n-2;i++){
scanf("%lld %lld",&u,&v);
vec[u].push_back(v);
vec[v].push_back(u);
}
//使用flag的值对第一颗子树和第二棵子树进行区分,下同
flag=1; dfs(1); dp[1]=value[1]; flag++;
for(ll i=1;i<=n;i++)
if(!vis[i]){
dfs(i);
dp[i]=value[i];
break;
}
//point是当前两个子树之一的点的个数
point=0;
for(int i=1;i<=n;i++)
if(vis[i]==1)
point++;
flag=1;
Dp(1);
flag++;
point=n-point;
for(ll i=1;i<=n;i++)
if(!vis2[i]){
Dp(i);
break;
}
//找到两个树上到其他点距离最小的点
ll v1=INF,v2=INF,p1,p2;
for(ll i=1;i<=n;i++){
if(vis2[i]==1 && dp[i]<=v1){
p1=i;
v1=dp[i];
}
if(vis2[i]==2 && dp[i]<=v2){
p2=i;
v2=dp[i];
}
}
vec[p1].push_back(p2);
vec[p2].push_back(p1);
//重新进行dfs 和 dp 计算出这棵大树的所有dp值
init(); flag=1; dfs(1);
dp[1]=value[1]; point=n; Dp(1);
//a->b b->a计算两次,所以 /2
ll res=0;
for(ll i=1;i<=n;i++) res+=dp[i];
printf("%lld\n",res/2ll);
}
int main()
{
while(~scanf("%lld",&n))
solve();
return 0;
}