树剖re求助
#include<map> #include<cstring> #include<vector> #include<cstdio> using namespace std; int n,m,r,p,a[800101],vis[800101],father[800101],w[800101],d[800101],top[800101],c[800101],cnt1[800101],cnt2,ctop,tree[800001],lazy[800001],flag,x,y,z; vector<int>g[800101]; map<int,map<int,int> >vis2; bool check(int x,int y){ for(int i=0;i<g[x].size();i++){ if(!vis2[x][i]&&!vis[g[x][i]]&&w[g[x][i]]>w[g[x][y]])return false; } return true; } void lazytag(int o,int l,int r,int k){lazy[o]+=k;tree[o]=((long long)tree[o]+(long long)(r-l+1)*k)%p;} void pushup(int o){tree[o]=((long long)tree[o<<1]+tree[o<<1|1])%p;} void pushdown(int o,int l,int r){ if(!lazy[o])return; lazy[o<<1]=((long long)lazy[o<<1]+lazy[o])%p; lazy[o<<1|1]=((long long)lazy[o<<1|1]+lazy[o])%p; int mid=(l+r)>>1; tree[o<<1]=((long long)tree[o<<1]+(long long)lazy[o]*(mid-l+1))%p; tree[o<<1|1]=((long long)tree[o<<1|1]+(long long)lazy[o]*(r-mid))%p; lazy[o]=0; } void build(int o,int l,int r){ if(l==r){ tree[o]=a[c[l]];return; } int mid=(l+r)>>1; build(o<<1,l,mid); build(o<<1|1,mid+1,r); pushup(o); } void add(int o,int l,int r,int ql,int qr,int k){ if(ql<=l&&r<=qr){lazytag(o,l,r,k);return;} int mid=(l+r)>>1; pushdown(o,l,r); if(mid>=ql)add(o<<1,l,mid,ql,qr,k); if(mid<qr)add(o<<1|1,mid+1,r,ql,qr,k); pushup(o); } int sum(int o,int l,int r,int ql,int qr){ if(ql<=l&r<=qr)return tree[o]; int ans=0;int mid=(l+r)>>1; pushdown(o,l,r); if(mid>=ql)ans=((long long)ans+sum(o<<1,l,mid,ql,qr))%p; if(mid<qr)ans=((long long)ans+sum(o<<1|1,mid+1,r,ql,qr))%p; return ans; } void dfs1(int x){ w[x]=1; for(int i=0;i<g[x].size();i++){ if(!vis[g[x][i]]&&father[x]!=g[x][i]){ vis[g[x][i]]=1;father[g[x][i]]=x;d[g[x][i]]=d[x]+1; dfs1(g[x][i]);w[x]+=w[g[x][i]]; } } } void dfs2(int x,int v){ vis[x]=1;cnt1[x]=++cnt2;c[++ctop]=x; for(int i=0;i<g[x].size();i++){ if(father[x]!=g[x][i]&&!vis[g[x][i]]&&!vis2[x][i]&&check(x,i)){ if(!v)top[g[x][i]]=top[x]; vis2[x][i]=1;dfs2(g[x][i],v);v=1; } } } void shupou(int x,int y,int z){ while(top[x]!=top[y]){ if(d[top[x]]<d[top[y]])swap(x,y); add(1,1,n,cnt1[top[x]],cnt1[x],z); x=father[top[x]]; } if(d[x]<d[y])swap(x,y); add(1,1,n,cnt1[y],cnt1[x],z); } int shupou1(int x,int y){ int ans=0; while(top[x]!=top[y]){ if(d[top[x]]<d[top[y]])swap(x,y); ans=((long long)ans+sum(1,1,n,cnt1[top[x]],cnt1[x]))%p; x=father[top[x]]; } if(d[x]<d[y])swap(x,y); ans=((long long)ans+sum(1,1,n,cnt1[y],cnt1[x]))%p; return ans; } int main(){ scanf("%lld%lld%lld%lld",&n,&m,&r,&p); for(int i=1;i<=n;i++)scanf("%lld",&a[i]),father[i]=top[i]=i; for(int i=1;i<=n-1;i++){ scanf("%lld%lld",&x,&y); g[x].push_back(y); g[y].push_back(x); } d[r]=1; dfs1(r);memset(vis,0,sizeof(vis));dfs2(r,0); build(1,1,ctop); while(m--){ scanf("%lld",&flag); if(flag==1)scanf("%lld%lld%lld",&x,&y,&z),shupou(x,y,z); if(flag==2)scanf("%lld%lld",&x,&y),printf("%d\n",shupou1(x,y)); if(flag==3)scanf("%lld%lld",&x,&z),add(1,1,n,cnt1[x],cnt1[x]+w[x]-1,z); if(flag==4)scanf("%lld",&x),printf("%d\n",sum(1,1,n,cnt1[x],cnt1[x]+w[x]-1)); } return 0; }
#C++工程师#