bbb

#include <climits>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
// Widens every later `int` in this file to 64 bits; the rest of the file
// (e.g. `signed main`) relies on this. Note: #define-ing a keyword is formally
// non-portable, but it is kept because the code below depends on it.
#define int long long
using namespace std;
using ll = long long;

// Capacity bound for vertices; the edge and segment-tree arrays derive from it.
const int maxn = 100010;

// Input parameters: vertex count, operation count, root vertex, modulus.
ll n, m, r, p;

// One directed edge in a linked-list adjacency structure.
struct Edge {
	ll next;  // index of the next edge leaving the same vertex (-1 ends the list)
	ll to;    // target vertex of this edge
} e[maxn << 1];  // undirected edges are stored twice

// Segment-tree node over the dfs-order positions [l, r].
struct node {
	ll l, r;  // covered interval of positions
	ll lazy;  // pending addition not yet pushed down to the children
	ll sum;   // sum of the values in [l, r]
} tree[maxn << 2];

int head[maxn];  // first edge index of each vertex (filled with -1 in main)
ll cnt;          // next free slot in e[]
ll tim;          // running dfs timestamp
// Per-vertex decomposition data: depth, subtree size, heavy child, parent,
// vertex at a dfs position (rnk), weight, chain head, dfs position (dfn).
ll dep[maxn], siz[maxn], son[maxn], par[maxn], rnk[maxn], w[maxn], top[maxn], dfn[maxn];
// Insert the directed edge x -> y at the front of x's adjacency list.
void add(ll x,ll y){
	ll id=cnt++;           // claim the next free edge slot
	e[id].next=head[x];    // link to x's previous first edge
	e[id].to=y;
	head[x]=id;            // the new edge becomes x's first edge
}
// First decomposition pass over the subtree rooted at u (parent fa, depth de).
// Fills par, dep and siz, and sets son[u] to the child with the largest
// subtree (the first such child on ties), or -1 when u has no children.
void dfs1(ll u,ll fa,ll de){
	par[u]=fa;
	dep[u]=de;
	siz[u]=1;
	ll heavy=-1;
	ll best=-INT_MAX;
	for(ll i=head[u];i!=-1;i=e[i].next){
		ll v=e[i].to;
		if(v!=fa){
			dfs1(v,u,de+1);
			siz[u]+=siz[v];
			if(best<siz[v]){   // strict: an equal-sized later child does not replace
				best=siz[v];
				heavy=v;
			}
		}
	}
	son[u]=heavy;
}

// Second decomposition pass: give u the next dfs position, record its chain
// head t, and visit the heavy child first so each heavy chain gets a
// contiguous range of positions. Every light child starts its own chain.
void dfs2(ll u,ll t){
	top[u]=t;
	dfn[u]=++tim;
	rnk[tim]=u;           // inverse map: position -> vertex
	if(son[u]!=-1){
		dfs2(son[u],t);   // heavy child stays on the current chain
		for(ll i=head[u];~i;i=e[i].next){
			ll v=e[i].to;
			if(v!=par[u]&&v!=son[u]) dfs2(v,v);
		}
	}
}

void pushup(int rt){
	tree[rt].sum=tree[rt<<1].sum+tree[rt<<1|1].sum;
}

void pushdown(int rt){
	if(tree[rt].lazy){
		tree[rt<<1].lazy+=tree[rt].lazy;
		tree[rt<<1|1].lazy+=tree[rt].lazy;
		tree[rt<<1].sum+=(tree[rt<<1].r-tree[rt<<1].l+1)*tree[rt].lazy;
		tree[rt<<1|1].sum+=(tree[rt<<1|1].r-tree[rt<<1|1].l+1)*tree[rt].lazy;
		tree[rt].lazy=0;
	}
}

void build(ll l,ll r,ll rt){
	tree[rt].l=l,tree[rt].r=r;
	if(l==r){
		tree[rt].sum=w[rnk[l]];
		tree[rt].lazy=0;
		return;
	}
	ll m=(l+r)>>1;
	build(l,m,rt<<1);
	build(m+1,r,rt<<1|1);
	pushup(rt);
}

void add(ll L,ll R,ll rt,ll val){
	ll l=tree[rt].l,r=tree[rt].r;
	if(L<=l&&r<=R){
		tree[rt].lazy+=val;
		tree[rt].sum+=(tree[rt].r-tree[rt].l+1)*val;
		return;
	}
	pushdown(rt);
	int m=(l+r)>>1;
	if(m>=L) add(L,R,rt<<1,val);
	if(m<R) add(L,R,rt<<1|1,val);
	pushup(rt);
}

ll getsum(ll L,ll R,ll rt){
	ll l=tree[rt].l,r=tree[rt].r;
	if(L<=l&&r<=R){
		return tree[rt].sum;
	}
	pushdown(rt);
	ll ans=0;
	ll m=(l+r)>>1;
	if(m>=L) ans=(ans+getsum(L,R,rt<<1))%p;
	if(m<R) ans=(ans+getsum(L,R,rt<<1|1))%p;;
	return ans;
}

// Add val to every vertex in the subtree of x. In dfs order, that subtree is
// the contiguous range [dfn[x], dfn[x] + siz[x] - 1].
void psubtree(ll x,ll val){
	ll first=dfn[x];
	ll last=first+siz[x]-1;
	add(first,last,1,val);
}

// Return the sum of the values in the subtree of x, reduced modulo p.
ll getstree(int x){
	ll lo=dfn[x];
	ll hi=lo+siz[x]-1;
	ll total=getsum(lo,hi,1);
	return total%p;
}

// Add val to every vertex on the tree path between x and y.
// While x and y are on different chains, take the endpoint whose chain head
// is deeper, update the segment from that head down to the endpoint, and jump
// above the head. Once both are on one chain, update the segment between them.
void pchain(int x,int y,int val){
	for(;top[x]!=top[y];x=par[top[x]]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		add(dfn[top[x]],dfn[x],1,val);
	}
	if(dep[x]>dep[y]) swap(x,y);   // make x the shallower endpoint
	add(dfn[x],dfn[y],1,val);
}

// Return the sum of the values on the tree path between x and y, modulo p.
// Chains are climbed exactly as in the path update: always lift the endpoint
// whose chain head is deeper, then add the final same-chain segment.
ll gschain(int x,int y){
	ll total=0;
	for(;top[x]!=top[y];x=par[top[x]]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		total=(total+getsum(dfn[top[x]],dfn[x],1))%p;
	}
	if(dep[x]>dep[y]) swap(x,y);   // make x the shallower endpoint
	return (total+getsum(dfn[x],dfn[y],1))%p;
}

signed main()
{
	memset(head,-1,sizeof(head));
	//cin>>n>>m>>r>>p; 
	scanf("%lld %lld %lld %lld",&n,&m,&r,&p);
	for(int i=1;i<=n;i++) cin>>w[i],w[i]%=p;
	for(int i=1;i<n;i++){
		int x,y;
		//cin>>x>>y;
		scanf("%lld %lld",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(r,-1,0);
	dfs2(r,r);
	build(1,tim,1);
	while(m--){
		int cmd;
		//cin>>cmd;
		scanf("%lld",&cmd);
		if(cmd==1){
			ll x,y,z;
			//cin>>x>>y>>z;
			scanf("%lld %lld %lld",&x,&y,&z);
			pchain(x,y,z); 
		}
		if(cmd==2){
			ll x,y;
			//cin>>x>>y;
			scanf("%lld %lld",&x,&y);
			printf("%lld\n",gschain(x,y)%p);
		}
		if(cmd==3){
			ll x,z;
			//cin>>x>>z;
			scanf("%lld %lld",&x,&z);
			psubtree(x,z);
		}
		if(cmd==4){
			ll x;
			//cin>>x;
			scanf("%lld",&x);
			printf("%lld\n",getstree(x)%p);
		}
	}
	return 0;
}
全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客企业服务