洛谷P3384 【模板】轻重链剖分/树链剖分+线段树

题目链接
题目大意:
已知一棵包含 N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
1、 x y z,表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z。
2、 x y,表示求树从 x 到 y 结点最短路径上所有节点的值之和。
3、 x z,表示将以 x 为根节点的子树内所有节点值都加上 z。
4、 x 表示求以 x 为根节点的子树内所有节点值之和。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,r,p;
int ch,x,y,z;
int a[100005];
int h[100005],e[100005*2],ne[100005*2];
int inx,idx;
int d[100005],size[100005],top[100005],f[100005],son[100005];
//d数组表示一个点的深度,size[i]表示以i为根节点的子树的点的数量
//top[i]求i所在重链的第一个节点,f表示一个节点的父节点
//son记录以子节点为根的所有子树中节点数量最多的子节点
int s[100005],dfsxu[100005],ni[100005];//DFS序列的三个相关数组
struct ty{
    int l,r;
    ll lazy;
    ll sum;
    ty():l(0),r(0),lazy(0),sum(0){ }
};
ty tr[100005*4];
void add(int x,int y){//链式前向星
    ne[inx]=h[x];e[inx]=y;h[x]=inx++;
    return ;
}
void dfs1(int u,int fa){
    d[u]=d[fa]+1; size[u]=1;
    f[u]=fa; son[u]=0;
    for(int i=h[u];i!=-1;i=ne[i]){
        int to=e[i];
        if(to==fa) continue;
        dfs1(to,u);
        size[u]+=size[to];
        if(size[son[u]]<size[to]) son[u]=to;
    }
    return ;
}
void dfs2(int u,int tops){//求top数组
    top[u]=tops;
    s[u]=++idx;
    dfsxu[idx]=u;
    if(son[u]!=0) dfs2(son[u],tops);
    for(int i=h[u];i!=-1;i=ne[i]){
        int to=e[i];
        if(to!=f[u]&&to!=son[u]){
            dfs2(to,to);
        }
    }
    ni[u]=idx;
    return ;
}
//线段树基本操作
void pushup(int u){
    tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
    return ;
}
void build(int u,int l,int r){
    tr[u].l=l; tr[u].r=r;
    tr[u].lazy=0;
    if(l==r) {
        tr[u].sum=a[dfsxu[l]];
        return ;
    }
    int mind=(l+r)>>1;
    build(u<<1,l,mind);
    build(u<<1|1,mind+1,r);
    pushup(u);
    return ;
}
void pushdown(int u){
    if(tr[u].lazy){
        tr[u<<1].sum+=(tr[u<<1].r-tr[u<<1].l+1)*tr[u].lazy;
        tr[u<<1|1].sum+=(tr[u<<1|1].r-tr[u<<1|1].l+1)*tr[u].lazy;
        tr[u<<1].lazy+=tr[u].lazy;
        tr[u<<1|1].lazy+=tr[u].lazy;
        tr[u].lazy=0;
    }
    return ;
}
void modify(int u,int l,int r,int da){
    if(tr[u].l>=l&&tr[u].r<=r){
        tr[u].sum+=(tr[u].r-tr[u].l+1)*da;
        tr[u].lazy+=da;
        return ;
    }
    pushdown(u);
    int mind=(tr[u].l+tr[u].r)>>1;
    if(l<=mind) modify(u<<1,l,r,da);
    if(r>mind) modify(u<<1|1,l,r,da);
    pushup(u);
    return ;
}
ll query(int u,int l,int r){
    if(tr[u].l>=l&&tr[u].r<=r){
        return tr[u].sum;
    }
    pushdown(u);
    int mind=(tr[u].l+tr[u].r)>>1;
    ll ans=0;
    if(l<=mind) ans+=query(u<<1,l,r);
    if(r>mind) ans+=query(u<<1|1,l,r);
    return ans;
}
void update1(int x,int y,int z){
    int l,r;
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]]) swap(x,y);
        l=s[top[x]]; r=s[x];
        modify(1,l,r,z);
        x=f[top[x]];
        //l=s[f[top[x]]]; r=s[top[x]];
        //modify(1,l,r,z);
    }
    if(d[x]>d[y]) swap(x,y);
    l=s[x]; r=s[y];
    modify(1,l,r,z);
    return ;
}
ll merge1(int x,int y){
    int l,r;
    ll ans=0;
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]]) swap(x,y);
        l=s[top[x]]; r=s[x];
        ans+=query(1,l,r);
        //l=s[f[top[x]]]; r=s[top[x]];
        //ans+=query(1,l,r);注意注释掉的这一部分轻边是不需要算的,不然会重复计算一些点
        x=f[top[x]];
    }
    if(d[x]>d[y]) swap(x,y);
    l=s[x]; r=s[y];
    //cout<<"l="<<l<<"r="<<r<<"\n";
    ans+=query(1,l,r);
    //cout<<"ans="<<query(1,l,r)<<"\n";
    return ans;
}
void update2(int x,int z){
    int l,r;
    l=s[x]; r=ni[x];
    modify(1,l,r,z);
    return ;
}
ll merge2(int x){
    int l,r;
    l=s[x]; r=ni[x];
    return query(1,l,r);
}
int main(){
    memset(h,-1,sizeof(h));
    scanf(" %d %d %d %d",&n,&m,&r,&p);
    for(int i=1;i<=n;i++) scanf(" %d",&a[i]);
    for(int i=1;i<n;i++){
        scanf(" %d %d",&x,&y);
        add(x,y); add(y,x);
    }
    dfs1(r,0);
    dfs2(r,r);
    build(1,1,n);
    for(int i=1;i<=m;i++){
        scanf(" %d",&ch);
        if(ch==1){
            scanf(" %d %d %d",&x,&y,&z);
            update1(x,y,z);
        }
        else if(ch==2){
            scanf(" %d %d",&x,&y);
            ll ans=merge1(x,y);
            printf("%d\n",ans%p);
        }
        else if(ch==3){
            scanf(" %d %d",&x,&z);
            update2(x,z);
        }
        else{
            scanf(" %d",&x);
            ll ans=merge2(x);
            printf("%d\n",ans%p);
        }
    }
    return 0;
}
全部评论

相关推荐

1 收藏 评论
分享
牛客网
牛客企业服务