两棵树的问题题解
这是一道比较困难的数据结构题。
【暴力法】
暴力的方法,是直接枚举点对,求两次lca,但这复杂度是O(n^2logn)的,必然超时。
【正解】
将暴力的思路反过来,我们枚举树1中的lca,设为z,问题变成了求所有在树1中lca为z的点对在树2中的lca深度的最大值。
首先,满足树1中lca为z的点对,两点在树1中必然分别在z的不同儿子的子树中。这是显然的。
对子树分开处理,假设z的孩子们分别为p[1],p[2],p[3]...假设我们处理到p[i],暴力地求解,则是枚举子树p[i]的所有点,另外枚举p[1]~p[i-1]的子树中的点,再在树2中求点对的lca,这样看下来复杂度变成了O(n^3logn),不减反增。
但上述方法显然是可以优化的。
假设我们现在要求树1中点x和若干其他点在树2中lca深度的最大值,有一个非常简单的结论:
将树2的dfs序处理出来,初始时每个位置为白点。这些其他点在树2中的对应点已经在dfs序上被染黑。
这个最大值必然是a[x]与a[x]在dfs序上往前第一个黑点或往后第一个黑点(下面简称前驱和后继)的lca的深度。
根据dfs序的特点可以知道,这个结论是正确的。
根据这个结论,枚举p[i]的所有点之后,用一个数据结构来维护树2的dfs序上点的黑或白,要求支持查询前驱和后继,同时在枚举p[i]的所有点时用该数据结构进行查询得到答案(当然p[1]是不用查询的)。用set或者线段树就可以轻松搞定。复杂度为O( ),极限是O(n^2logn),依旧不够优秀。
当然还可以继续优化。
当然在我们要解决的问题里,子树p[i]是依次枚举的,而p[1]由于是第一个插入的子树,不必对p[1]的所有点进行查询操作,不妨将子树大小最大的儿子(下称重儿子)作为p[1]。
每当到一个点z的时候,优先处理重儿子,处理完毕之后,可以发现这个重儿子对应子树的点在树2的dfs序已经全部被染黑,这样我们就可以不用枚举这个重儿子进行染色查询等操作了,复杂度产生了一点点变化,设s[x]为点x的重儿子,复杂度为O( ),极限是O(nlog^2n),证明的话,考虑树链剖分或者启发式合并的复杂度,都是可以证明出来的,网上有许多优秀的资料,我讲的可能没有它们好...
【小结】
这题虽然比较难,但我想还是可以给大家提供一些很好的思路,如将问题反转看待、快速求某点与部分特殊点的lca最大深度、枚举子树问题的时间优化......这些都是非常实用的方法或结论,而且可能稍微有一点点冷门,是一些人(至少我)的盲区。希望能对大家有所帮助。
参考代码(我是使用线段树的做法(因为我不是很会用stl的说)):
如果对代码的某些部分存在疑问,可以私信我,我会及时回复。
class Solution { public: /** * * @param n int整型 * @param a int整型一维数组 * @param aLen int a数组长度 * @param b int整型一维数组 * @param bLen int b数组长度 * @param c int整型一维数组 * @param cLen int c数组长度 * @return int整型 */ int cnt=0,p[100010],ans,nn; int tot1=0,he1[100010],ne1[100010],t1[100010]; void link1(int x,int y) { tot1++; ne1[tot1]=he1[x]; he1[x]=tot1; t1[tot1]=y; } int tot2=0,he2[100010],ne2[100010],t2[100010]; void link2(int x,int y) { tot2++; ne2[tot2]=he2[x]; he2[x]=tot2; t2[tot2]=y; } int dfn1[100010],fr1[100010],sz[100010],dep1[100010],g[100010]; void dfs1(int x) { cnt++; dfn1[x]=cnt; fr1[cnt]=x; sz[x]=1; for (int i=he1[x];i;i=ne1[i]) { dep1[t1[i]]=dep1[x]+1; dfs1(t1[i]); sz[x]=sz[x]+sz[t1[i]]; if (sz[t1[i]]>sz[g[x]]) g[x]=t1[i]; } } int dfn2[100010],fr2[100010],dep2[100010],f[100010][18]; void dfs2(int x) { cnt++; dfn2[x]=cnt; fr2[cnt]=x; for (int i=he2[x];i;i=ne2[i]) { dep2[t2[i]]=dep2[x]+1; f[t2[i]][0]=x; for (int j=1;j<=17;j++) f[t2[i]][j]=f[f[t2[i]][j-1]][j-1]; dfs2(t2[i]); } } int lca(int x,int y) { if (dep2[x]<=dep2[y]) swap(x,y); for (int i=17;i>=0;i--) if (dep2[f[x][i]]>=dep2[y]) x=f[x][i]; if (x==y) return x; for (int i=17;i>=0;i--) if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i]; return f[x][0]; } int mx[100010*4],mn[100010*4]; void modify(int rt,int l,int r,int x,int y) { if (l==r) { if (y==-1) mx[rt]=0,mn[rt]=nn+1; else mx[rt]=mn[rt]=y; return; } int mid=(l+r)>>1; if (x<=mid) modify(rt<<1,l,mid,x,y);else modify(rt<<1|1,mid+1,r,x,y); mx[rt]=max(mx[rt<<1],mx[rt<<1|1]); mn[rt]=min(mn[rt<<1],mn[rt<<1|1]); } int get_max(int rt,int l,int r,int x,int y) { if (l>=x&&r<=y) return mx[rt]; int mid=(l+r)>>1; int ret=0; if (x<=mid) ret=max(ret,get_max(rt<<1,l,mid,x,y)); if (y>mid) ret=max(ret,get_max(rt<<1|1,mid+1,r,x,y)); return ret; } int get_min(int rt,int l,int r,int x,int y) { if (l>=x&&r<=y) return mn[rt]; int mid=(l+r)>>1; int ret=nn+1; if (x<=mid) ret=min(ret,get_min(rt<<1,l,mid,x,y)); if (y>mid) ret=min(ret,get_min(rt<<1|1,mid+1,r,x,y)); return ret; } int calc(int x,int y) { if (!x) return -nn; return dep2[lca(x,y)]; } void clean(int x) { for (int i=dfn1[x];i<=dfn1[x]+sz[x]-1;i++) modify(1,1,nn,dfn2[p[fr1[i]]],-1); } void ins(int x) { for (int i=dfn1[x];i<=dfn1[x]+sz[x]-1;i++) { int u=get_max(1,1,nn,1,dfn2[p[fr1[i]]]),v=get_min(1,1,nn,dfn2[p[fr1[i]]],nn); modify(1,1,nn,dfn2[p[fr1[i]]],dfn2[p[fr1[i]]]); ans=max(ans,dep1[x]-1+calc(fr2[u],p[fr1[i]])); ans=max(ans,dep1[x]-1+calc(fr2[v],p[fr1[i]])); } } void solve(int x) { for (int i=he1[x];i;i=ne1[i]) if (t1[i]!=g[x]) { solve(t1[i]); clean(t1[i]); } if (g[x]) solve(g[x]); int u=get_max(1,1,nn,1,dfn2[p[x]]),v=get_min(1,1,nn,dfn2[p[x]],nn); modify(1,1,nn,dfn2[p[x]],dfn2[p[x]]); ans=max(ans,dep1[x]+calc(fr2[u],p[x])); ans=max(ans,dep1[x]+calc(fr2[v],p[x])); for (int i=he1[x];i;i=ne1[i]) if (t1[i]!=g[x]) ins(t1[i]); } int wwork(int n, int* a, int aLen, int* b, int bLen, int* c, int cLen) { // write code here nn=n; for (int i=1;i<=n;i++) p[i]=a[i]; for (int i=2;i<=n;i++) { link1(b[i],i); link2(c[i],i); } dfs1(1); cnt=0; dfs2(1); ans=0; for (int i=1;i<=n;i++) modify(1,1,n,i,-1); solve(1); return ans; } };