两棵树的问题题解

这是一道比较困难的数据结构题。

【暴力法】
暴力的方法,是直接枚举点对,求两次lca,但这复杂度是O(n^2logn)的,必然超时。

【正解】
将暴力的思路反过来,我们枚举树1中的lca,设为z,问题变成了求所有在树1中lca为z的点对在树2中的lca深度的最大值。

首先,满足树1中lca为z的点对,两点在树1中必然分别在z的不同儿子的子树中。这是显然的。

对子树分开处理,假设z的孩子们分别为p[1],p[2],p[3]...假设我们处理到p[i],暴力地求解,则是枚举子树p[i]的所有点,另外枚举p[1]~p[i-1]的子树中的点,再在树2中求点对的lca,这样看下来复杂度变成了O(n^3logn),不减反增。

但上述方法显然是可以优化的。
假设我们现在要求树1中点x和若干其他点在树2中lca深度的最大值,有一个非常简单的结论:
将树2的dfs序处理出来,初始时每个位置为白点。这些其他点在树2中的对应点已经在dfs序上被染黑。
这个最大值必然是a[x]与a[x]在dfs序上往前第一个黑点或往后第一个黑点(下面简称前驱和后继)的lca的深度。

根据dfs序的特点可以知道,这个结论是正确的。

根据这个结论,枚举p[i]的所有点之后,用一个数据结构来维护树2的dfs序上点的黑或白,要求支持查询前驱和后继,同时在枚举p[i]的所有点时用该数据结构进行查询得到答案(当然p[1]是不用查询的)。用set或者线段树就可以轻松搞定。复杂度为O(图片说明 ),极限是O(n^2logn),依旧不够优秀。

当然还可以继续优化。
当然在我们要解决的问题里,子树p[i]是依次枚举的,而p[1]由于是第一个插入的子树,不必对p[1]的所有点进行查询操作,不妨将子树大小最大的儿子(下称重儿子)作为p[1]。
每当到一个点z的时候,优先处理重儿子,处理完毕之后,可以发现这个重儿子对应子树的点在树2的dfs序已经全部被染黑,这样我们就可以不用枚举这个重儿子进行染色查询等操作了,复杂度产生了一点点变化,设s[x]为点x的重儿子,复杂度为O(图片说明 ),极限是O(nlog^2n),证明的话,考虑树链剖分或者启发式合并的复杂度,都是可以证明出来的,网上有许多优秀的资料,我讲的可能没有它们好...

【小结】
这题虽然比较难,但我想还是可以给大家提供一些很好的思路,如将问题反转看待、快速求某点与部分特殊点的lca最大深度、枚举子树问题的时间优化......这些都是非常实用的方法或结论,而且可能稍微有一点点冷门,是一些人(至少我)的盲区。希望能对大家有所帮助。

参考代码(我是使用线段树的做法(因为我不是很会用stl的说)):
如果对代码的某些部分存在疑问,可以私信我,我会及时回复。

class Solution {
public:
    /**
     * 
     * @param n int整型 
     * @param a int整型一维数组 
     * @param aLen int a数组长度
     * @param b int整型一维数组 
     * @param bLen int b数组长度
     * @param c int整型一维数组 
     * @param cLen int c数组长度
     * @return int整型
     */

    int cnt=0,p[100010],ans,nn;

    int tot1=0,he1[100010],ne1[100010],t1[100010];
    void link1(int x,int y)
    {
        tot1++;
        ne1[tot1]=he1[x];
        he1[x]=tot1;
        t1[tot1]=y;
    }

    int tot2=0,he2[100010],ne2[100010],t2[100010];
    void link2(int x,int y)
    {
        tot2++;
        ne2[tot2]=he2[x];
        he2[x]=tot2;
        t2[tot2]=y;
    }

    int dfn1[100010],fr1[100010],sz[100010],dep1[100010],g[100010];
    void dfs1(int x)
    {
        cnt++;
        dfn1[x]=cnt;
        fr1[cnt]=x;
        sz[x]=1;
        for (int i=he1[x];i;i=ne1[i]) 
        {
            dep1[t1[i]]=dep1[x]+1;
            dfs1(t1[i]);
            sz[x]=sz[x]+sz[t1[i]];
            if (sz[t1[i]]>sz[g[x]]) g[x]=t1[i];
        }
    }

    int dfn2[100010],fr2[100010],dep2[100010],f[100010][18];
    void dfs2(int x)
    {
        cnt++;
        dfn2[x]=cnt;
        fr2[cnt]=x;
        for (int i=he2[x];i;i=ne2[i])
        {
            dep2[t2[i]]=dep2[x]+1;
            f[t2[i]][0]=x;
            for (int j=1;j<=17;j++) f[t2[i]][j]=f[f[t2[i]][j-1]][j-1];
            dfs2(t2[i]);
        }
    }

    int lca(int x,int y)
    {
        if (dep2[x]<=dep2[y]) swap(x,y);
        for (int i=17;i>=0;i--) if (dep2[f[x][i]]>=dep2[y]) x=f[x][i];
        if (x==y) return x;
        for (int i=17;i>=0;i--) if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
        return f[x][0]; 
    }

    int mx[100010*4],mn[100010*4];
    void modify(int rt,int l,int r,int x,int y)
    {
        if (l==r)
        {
            if (y==-1) mx[rt]=0,mn[rt]=nn+1;
            else mx[rt]=mn[rt]=y;
            return;
        }
        int mid=(l+r)>>1;
        if (x<=mid) modify(rt<<1,l,mid,x,y);else modify(rt<<1|1,mid+1,r,x,y);
        mx[rt]=max(mx[rt<<1],mx[rt<<1|1]);
        mn[rt]=min(mn[rt<<1],mn[rt<<1|1]); 
    }

    int get_max(int rt,int l,int r,int x,int y)
    {
        if (l>=x&&r<=y) return mx[rt];
        int mid=(l+r)>>1;
        int ret=0;
        if (x<=mid) ret=max(ret,get_max(rt<<1,l,mid,x,y));
        if (y>mid) ret=max(ret,get_max(rt<<1|1,mid+1,r,x,y));
        return ret;
    }

    int get_min(int rt,int l,int r,int x,int y)
    {
        if (l>=x&&r<=y) return mn[rt];
        int mid=(l+r)>>1;
        int ret=nn+1;
        if (x<=mid) ret=min(ret,get_min(rt<<1,l,mid,x,y));
        if (y>mid) ret=min(ret,get_min(rt<<1|1,mid+1,r,x,y));
        return ret;
    }

    int calc(int x,int y)
    {
        if (!x) return -nn;
        return dep2[lca(x,y)];
    }

    void clean(int x)
    {
        for (int i=dfn1[x];i<=dfn1[x]+sz[x]-1;i++) modify(1,1,nn,dfn2[p[fr1[i]]],-1);
    }

    void ins(int x)
    {
        for (int i=dfn1[x];i<=dfn1[x]+sz[x]-1;i++) 
        {
            int u=get_max(1,1,nn,1,dfn2[p[fr1[i]]]),v=get_min(1,1,nn,dfn2[p[fr1[i]]],nn);
            modify(1,1,nn,dfn2[p[fr1[i]]],dfn2[p[fr1[i]]]);
            ans=max(ans,dep1[x]-1+calc(fr2[u],p[fr1[i]]));
            ans=max(ans,dep1[x]-1+calc(fr2[v],p[fr1[i]]));
        }
    }

    void solve(int x)
    {
        for (int i=he1[x];i;i=ne1[i]) if (t1[i]!=g[x]) 
        {
            solve(t1[i]);
            clean(t1[i]);
        }
        if (g[x]) solve(g[x]);
        int u=get_max(1,1,nn,1,dfn2[p[x]]),v=get_min(1,1,nn,dfn2[p[x]],nn);
        modify(1,1,nn,dfn2[p[x]],dfn2[p[x]]);
        ans=max(ans,dep1[x]+calc(fr2[u],p[x]));
        ans=max(ans,dep1[x]+calc(fr2[v],p[x]));
        for (int i=he1[x];i;i=ne1[i]) if (t1[i]!=g[x]) ins(t1[i]);
    }


    int wwork(int n, int* a, int aLen, int* b, int bLen, int* c, int cLen) {
        // write code here
        nn=n;
        for (int i=1;i<=n;i++) p[i]=a[i];
        for (int i=2;i<=n;i++) 
        {
            link1(b[i],i);
            link2(c[i],i);
        }
        dfs1(1);
        cnt=0;
        dfs2(1);
        ans=0;
        for (int i=1;i<=n;i++) modify(1,1,n,i,-1);
        solve(1);
        return ans;
    }
};
全部评论

相关推荐

点赞 评论 收藏
转发
点赞 收藏 评论
分享
牛客网
牛客企业服务