线段树总结
前言
写在前面:线段树是一种用于处理区间问题(区间查询、区间修改)的数据结构,本篇博客用来记录我学习线段树的刷题过程。
1.区间求和以及单点修改
由于这里只涉及单点修改操作,所以不需要使用 lazy 标记(懒标记)。
hdu1166
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll INF = -1e9;    // not used by this program
const ll N = 5e4 + 7;   // maximum number of elements (plus slack)
const ll mod = 1e9 + 7;

// Fast reader: skips non-digit characters (remembering a '-') and parses a
// signed integer via getchar(). Not used by main below.
inline ll read() {
    ll s = 0, w = 1;
    char ch = getchar();
    while (ch < 48 || ch > 57) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= 48 && ch <= 57) s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
    return s * w;
}

// Fast writer: prints x in decimal via putchar().
inline void write(ll x) {
    if (!x) {
        putchar('0');
        return;
    }
    char F[200];
    ll tmp = x > 0 ? x : -x;
    if (x < 0) putchar('-');
    int cnt = 0;
    while (tmp > 0) {
        F[cnt++] = tmp % 10 + '0';
        tmp /= 10;
    }
    while (cnt > 0) putchar(F[--cnt]);
}

inline ll gcd(ll x, ll y) { return y ? gcd(y, x % y) : x; }

// a^b by binary exponentiation, no modulus (may overflow for large results).
ll qpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans *= a;
        b >>= 1;
        a *= a;
    }
    return ans;
}

// a^b % mod by binary exponentiation.
ll qpow(ll a, ll b, ll mod) {
    ll ans = 1;
    while (b) {
        if (b & 1) (ans *= a) %= mod;
        b >>= 1;
        (a *= a) %= mod;
    }
    return ans % mod;
}

inline int lowbit(int x) { return x & (-x); }

// Modular inverse via Fermat's little theorem (mod = 1e9+7 is prime).
ll getinv(ll x) { return qpow(x, mod - 2, mod); }

ll sum[N << 2], a[N];

// Build the tree over a[l..r]. rt is the node index; its children are rt*2 and rt*2+1.
void build(int l, int r, int rt) {
    if (l == r) {  // leaf: copy the array value
        sum[rt] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Point update: a[i] += j. Every node on the root-to-leaf path covers i,
// so its sum is increased on the way down; no lazy tag is needed.
void Add(int i, int j, int l, int r, int rt) {
    sum[rt] += j;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (i <= mid)
        Add(i, j, l, mid, rt << 1);
    else
        Add(i, j, mid + 1, r, rt << 1 | 1);
}

// Sum of a[ql..qr]; the current node rt covers [l, r].
// (Parameters renamed from a/b so they no longer shadow the global array a.)
ll query(int ql, int qr, int l, int r, int rt) {
    if (ql <= l && qr >= r) return sum[rt];
    int mid = (l + r) >> 1;
    ll ans = 0;
    if (ql <= mid) ans += query(ql, qr, l, mid, rt << 1);
    if (qr > mid) ans += query(ql, qr, mid + 1, r, rt << 1 | 1);
    return ans;
}

int main() {
    int t, n, kase = 0;  // kase: test-case counter (previously shadowed by the inner j)
    scanf("%d", &t);
    while (t--) {
        cout << "Case " << ++kase << ":\n";
        scanf("%d", &n);
        // a[] is long long, so the conversion must be %lld (%ld is wrong where long is 32-bit).
        for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
        build(1, n, 1);
        string q;
        int i, j;
        // Stop on "End", and also on EOF instead of spinning forever.
        while (cin >> q) {
            if (q == "End") break;
            scanf("%d%d", &i, &j);
            if (q == "Add") {
                Add(i, j, 1, n, 1);
            } else if (q == "Query") {
                printf("%lld\n", query(i, j, 1, n, 1));
            } else if (q == "Sub") {
                Add(i, -j, 1, n, 1);
            }
        }
    }
    return 0;
}
2.区间修改
区间加法
1.将某区间每一个数加上k。
2.求出某区间每一个数的和。
注意:区间修改需要使用加法懒标记(lazy 标记),并且在递归进入子区间之前要先把标记下推(pushdown)。
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll INF = -1e9;    // not used by this program
const ll N = 1e5 + 7;   // maximum number of elements (plus slack)
const ll mod = 1e9 + 7;

// Fast reader: skips non-digit characters (remembering a '-') and parses a
// signed integer via getchar().
inline ll read() {
    ll s = 0, w = 1;
    char ch = getchar();
    while (ch < 48 || ch > 57) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= 48 && ch <= 57) s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
    return s * w;
}

// Fast writer: prints x in decimal via putchar().
inline void write(ll x) {
    if (!x) {
        putchar('0');
        return;
    }
    char F[200];
    ll tmp = x > 0 ? x : -x;
    if (x < 0) putchar('-');
    int cnt = 0;
    while (tmp > 0) {
        F[cnt++] = tmp % 10 + '0';
        tmp /= 10;
    }
    while (cnt > 0) putchar(F[--cnt]);
}

inline ll gcd(ll x, ll y) { return y ? gcd(y, x % y) : x; }

// a^b by binary exponentiation, no modulus (may overflow for large results).
ll qpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans *= a;
        b >>= 1;
        a *= a;
    }
    return ans;
}

// a^b % mod by binary exponentiation.
ll qpow(ll a, ll b, ll mod) {
    ll ans = 1;
    while (b) {
        if (b & 1) (ans *= a) %= mod;
        b >>= 1;
        (a *= a) %= mod;
    }
    return ans % mod;
}

inline int lowbit(int x) { return x & (-x); }

// Modular inverse via Fermat's little theorem (mod = 1e9+7 is prime).
ll getinv(ll x) { return qpow(x, mod - 2, mod); }

ll a[N], sum[N << 2], add[N << 2];  // add[]: pending "add to every element" lazy tag

// Build the tree over a[l..r]; node rt has children rt*2 and rt*2+1.
void build(int l, int r, int rt) {
    if (l == r) {
        sum[rt] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Push node rt's pending add tag to its children.
// m is the length of rt's interval; with mid = (l + r) >> 1 the left child
// holds m - m/2 elements and the right child m/2.
void pushdown(int rt, int m) {
    if (add[rt]) {
        add[rt << 1] += add[rt];
        add[rt << 1 | 1] += add[rt];
        sum[rt << 1] += (m - (m >> 1)) * add[rt];
        sum[rt << 1 | 1] += (m >> 1) * add[rt];
        add[rt] = 0;  // the tag now lives on the children
    }
}

// Add val to every element of [x, y]; the current node rt covers [l, r].
// val is long long (was int): both the addend and length*val can exceed int range.
void update(int l, int r, int rt, int x, int y, ll val) {
    if (x <= l && r <= y) {
        sum[rt] += (ll)(r - l + 1) * val;  // 64-bit product
        add[rt] += val;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, r - l + 1);
    if (x <= mid) update(l, mid, rt << 1, x, y, val);
    if (y > mid) update(mid + 1, r, rt << 1 | 1, x, y, val);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Sum of elements in [x, y]; the current node rt covers [l, r].
ll qsum(int rt, int l, int r, int x, int y) {
    if (x <= l && r <= y) return sum[rt];
    pushdown(rt, r - l + 1);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid) ans += qsum(rt << 1, l, mid, x, y);
    if (y > mid) ans += qsum(rt << 1 | 1, mid + 1, r, x, y);
    return ans;
}

// Input: n m, then n numbers, then m operations:
//   "1 x y k" adds k to [x, y];  "2 x y" prints the sum of [x, y].
int main() {
    int n, m, op, x, y;
    ll k;  // addend may exceed int range
    cin >> n >> m;
    for (int i = 1; i <= n; i++) a[i] = read();
    build(1, n, 1);
    for (int i = 1; i <= m; i++) {
        cin >> op;
        if (op == 2) {
            cin >> x >> y;
            cout << qsum(1, 1, n, x, y) << '\n';  // '\n' avoids a flush per query
        } else if (op == 1) {
            cin >> x >> y >> k;
            update(1, n, 1, x, y, k);
        }
    }
    return 0;
}
/* 5 1 2 3 4 5 1 3 */
区间乘法
P3373 【模板】线段树 2
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和(结果对输入给定的模数 p 取模)
#include<bits/stdc++.h> using namespace std; typedef long long ll; const ll INF = -1e9; const ll N =1e5+7; inline ll read() { ll s = 0, w = 1; char ch = getchar(); while (ch < 48 || ch > 57) { if (ch == '-') w = -1; ch = getchar(); } while (ch >= 48 && ch <= 57) s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar(); return s * w; } inline void write(ll x) { if (!x) { putchar('0'); return; } char F[200]; ll tmp = x > 0 ? x : -x; if (x < 0)putchar('-'); int cnt = 0; while (tmp > 0) { F[cnt++] = tmp % 10 + '0'; tmp /= 10; } while (cnt > 0)putchar(F[--cnt]); } ll a[N],sum[N<<2],add[N<<2],mul[N<<2]; ll mod; void pushdown(int rt,int l,int r){//这里要先乘后加 if(mul[rt]!=1){ mul[rt << 1]=(mul[rt << 1] * mul[rt]) % mod; mul[rt << 1| 1]=(mul[rt << 1| 1] * mul[rt]) % mod; add[rt<<1]=(add[rt<<1]*mul[rt])%mod; add[rt<<1|1]=(add[rt<<1|1]*mul[rt])%mod; sum[rt<<1]=(sum[rt<<1]*mul[rt])%mod; sum[rt<<1|1]=(sum[rt<<1|1]*mul[rt])%mod; mul[rt]=1; } int mid=l+r>>1; if(add[rt]){ sum[rt<<1]=(sum[rt<<1]+(mid-l+1)*add[rt])%mod; sum[rt<<1|1]=(sum[rt<<1|1]+(r-mid)*add[rt])%mod; add[rt<<1]=(add[rt<<1]+add[rt])%mod; add[rt<<1|1]=(add[rt<<1|1]+add[rt])%mod; add[rt]=0;//传递下去后就要取消本层标记 } } void build(int l,int r,int rt){ mul[rt]=1; if(l==r){ sum[rt]=a[l]; return; } int mid=l+r>>1; build(l,mid,rt<<1); build(mid+1,r,rt<<1|1); sum[rt]=(sum[rt<<1]+sum[rt<<1|1])%mod; } void update1(int rt,int l,int r,int x,int y,ll z){//乘法,乘z if(x<=l&&r<=y){ sum[rt]=sum[rt] * z % mod; add[rt]=add[rt] * z % mod; mul[rt]=mul[rt] * z % mod; return; } int mid=(l+r)>>1; pushdown(rt,l,r); if(x<=mid) update1(rt<<1,l,mid,x,y,z); if(y>mid) update1(rt<<1|1,mid+1,r,x,y,z); sum[rt]=(sum[rt<<1]+sum[rt<<1|1])%mod; } void update2(int rt,int l,int r,int x,int y,int val){//加法 if(x<=l&&r<=y){ sum[rt]=(sum[rt]+(r-l+1)*val)%mod; add[rt]=(add[rt]+val)%mod; return; } int mid=(l+r)>>1; pushdown(rt,l,r); if(x<=mid) update2(rt<<1,l,mid,x,y,val); if(y>mid) update2(rt<<1|1,mid+1,r,x,y,val); sum[rt]=(sum[rt<<1]+sum[rt<<1|1])%mod; } ll qsum(int rt,int l,int 
r,int x,int y){ if(x<=l&&r<=y) return sum[rt]; pushdown(rt,l,r); ll ans=0; int mid=l+r>>1; if(x<=mid) ans+=qsum(rt<<1,l,mid,x,y); ans%=mod; if(y>mid) ans+=qsum(rt<<1|1,mid+1,r,x,y); return ans%mod; } int main(){ ll n,m,op,x,y,k; cin>>n>>m>>mod; for(int i=1;i<=n;i++) a[i]=read(); build(1,n,1); for(int i=1;i<=m;i++){ op=read(); if(op==1){ cin>>x>>y>>k; update1(1,1,n,x,y,k); } else if(op==2){ cin>>x>>y>>k; update2(1,1,n,x,y,k); } else{ cin>>x>>y; cout<<qsum(1,1,n,x,y)<<endl; } } }
区间修改:赋值
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll INF = -1e9;    // not used by this program
const ll N = 1e5 + 7;   // maximum number of elements (plus slack)
const ll mod = 1e9 + 7;

// Fast reader: skips non-digit characters (remembering a '-') and parses a
// signed integer via getchar().
inline ll read() {
    ll s = 0, w = 1;
    char ch = getchar();
    while (ch < 48 || ch > 57) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= 48 && ch <= 57) s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
    return s * w;
}

// Fast writer: prints x in decimal via putchar().
inline void write(ll x) {
    if (!x) {
        putchar('0');
        return;
    }
    char F[200];
    ll tmp = x > 0 ? x : -x;
    if (x < 0) putchar('-');
    int cnt = 0;
    while (tmp > 0) {
        F[cnt++] = tmp % 10 + '0';
        tmp /= 10;
    }
    while (cnt > 0) putchar(F[--cnt]);
}

inline ll gcd(ll x, ll y) { return y ? gcd(y, x % y) : x; }

// a^b by binary exponentiation, no modulus (may overflow for large results).
ll qpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans *= a;
        b >>= 1;
        a *= a;
    }
    return ans;
}

// a^b % mod by binary exponentiation.
ll qpow(ll a, ll b, ll mod) {
    ll ans = 1;
    while (b) {
        if (b & 1) (ans *= a) %= mod;
        b >>= 1;
        (a *= a) %= mod;
    }
    return ans % mod;
}

inline int lowbit(int x) { return x & (-x); }

// Modular inverse via Fermat's little theorem (mod = 1e9+7 is prime).
ll getinv(ll x) { return qpow(x, mod - 2, mod); }

ll a[N], sum[N << 2], change[N << 2];  // change[]: pending "assign every element" value
// tagged[rt] says whether change[rt] is pending. A separate flag is required:
// using change[rt] == 0 as "no tag" silently dropped assignments of the value 0.
bool tagged[N << 2];

// Build the tree over a[l..r]; node rt has children rt*2 and rt*2+1.
void build(int l, int r, int rt) {
    if (l == r) {
        sum[rt] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Push node rt's pending assignment to its children. m is the length of rt's
// interval; the left child holds m - m/2 elements and the right child m/2.
void pushdown(int rt, int m) {
    if (tagged[rt]) {
        change[rt << 1] = change[rt];
        change[rt << 1 | 1] = change[rt];
        tagged[rt << 1] = tagged[rt << 1 | 1] = true;
        // An assignment overwrites, so the children's sums are replaced, not increased.
        sum[rt << 1] = (m - (m >> 1)) * change[rt];
        sum[rt << 1 | 1] = (m >> 1) * change[rt];
        tagged[rt] = false;  // the tag now lives on the children
    }
}

// Set every element of [x, y] to val; the current node rt covers [l, r].
void update(int l, int r, int rt, int x, int y, ll val) {
    if (x <= l && r <= y) {
        sum[rt] = (ll)(r - l + 1) * val;  // 64-bit product: length*val can exceed int
        change[rt] = val;
        tagged[rt] = true;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, r - l + 1);
    if (x <= mid) update(l, mid, rt << 1, x, y, val);
    if (y > mid) update(mid + 1, r, rt << 1 | 1, x, y, val);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Sum of elements in [x, y]; the current node rt covers [l, r].
ll qsum(int l, int r, int rt, int x, int y) {
    if (x <= l && r <= y) return sum[rt];
    pushdown(rt, r - l + 1);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid) ans += qsum(l, mid, rt << 1, x, y);
    if (y > mid) ans += qsum(mid + 1, r, rt << 1 | 1, x, y);
    return ans;
}

// Input: n, n numbers, m, then m operations:
//   "0 x y" prints the sum of [x, y];  "1 x y k" sets every element of [x, y] to k.
int main() {
    int n, m, op, x, y;
    ll k;
    cin >> n;
    for (int i = 1; i <= n; i++) a[i] = read();
    build(1, n, 1);
    cin >> m;
    for (int i = 1; i <= m; i++) {
        cin >> op;
        if (op == 0) {
            cin >> x >> y;
            cout << qsum(1, n, 1, x, y) << '\n';  // '\n' avoids a flush per query
        } else if (op == 1) {
            cin >> x >> y >> k;
            update(1, n, 1, x, y, k);
        }
    }
    return 0;
}
区间修改2:求平方和
数据结构
1 l r 询问区间[l,r]内的元素和
2 l r 询问区间[l,r]内元素的平方和
3 l r x 将区间[l,r]内的每一个元素都乘上x
4 l r x 将区间[l,r]内的每一个元素都加上x
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll INF = -1e9;    // not used by this program
const ll N = 1e5 + 7;   // maximum number of elements (plus slack)

// Fast reader: skips non-digit characters (remembering a '-') and parses a
// signed integer via getchar().
inline ll read() {
    ll s = 0, w = 1;
    char ch = getchar();
    while (ch < 48 || ch > 57) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= 48 && ch <= 57) s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
    return s * w;
}

// Fast writer: prints x in decimal via putchar().
inline void write(ll x) {
    if (!x) {
        putchar('0');
        return;
    }
    char F[200];
    ll tmp = x > 0 ? x : -x;
    if (x < 0) putchar('-');
    int cnt = 0;
    while (tmp > 0) {
        F[cnt++] = tmp % 10 + '0';
        tmp /= 10;
    }
    while (cnt > 0) putchar(F[--cnt]);
}

// sum[]: range sum; sum1[]: range sum of squares.
// Lazy tags: every element v under node rt is to become v * mul[rt] + add[rt].
// NOTE(review): no modulus is applied; assumes all results fit in long long.
ll a[N], sum[N << 2], add[N << 2], mul[N << 2], sum1[N << 2];

// Push rt's tags to its children: multiply first, then add.
void pushdown(int rt, int l, int r) {
    if (mul[rt] != 1) {
        mul[rt << 1] = (mul[rt << 1] * mul[rt]);
        mul[rt << 1 | 1] = (mul[rt << 1 | 1] * mul[rt]);
        // Bug fix: the children's pending add tags must be scaled as well, because
        // (v * m1 + a1) * m2 = v * (m1 * m2) + a1 * m2. Previously they were left
        // unscaled, giving wrong sums after a multiply over a pending add.
        add[rt << 1] = add[rt << 1] * mul[rt];
        add[rt << 1 | 1] = add[rt << 1 | 1] * mul[rt];
        sum[rt << 1] = (sum[rt << 1] * mul[rt]);
        sum[rt << 1 | 1] = (sum[rt << 1 | 1] * mul[rt]);
        sum1[rt << 1] *= mul[rt] * mul[rt];  // squares scale by mul^2
        sum1[rt << 1 | 1] *= mul[rt] * mul[rt];
        mul[rt] = 1;
    }
    int mid = (l + r) >> 1;
    if (add[rt]) {
        add[rt << 1] += add[rt];
        add[rt << 1 | 1] += add[rt];
        // Keep the pre-update sums: sum (v + d)^2 = sum v^2 + 2*d*sum v + len*d^2.
        ll x = sum[rt << 1];
        ll y = sum[rt << 1 | 1];
        sum[rt << 1] = (sum[rt << 1] + (mid - l + 1) * add[rt]);
        sum[rt << 1 | 1] = (sum[rt << 1 | 1] + (r - mid) * add[rt]);
        sum1[rt << 1] += 2 * x * add[rt] + (mid - l + 1) * add[rt] * add[rt];
        sum1[rt << 1 | 1] += 2 * y * add[rt] + (r - mid) * add[rt] * add[rt];
        add[rt] = 0;  // the tag now lives on the children
    }
}

// Build the tree over a[l..r], resetting tags to the identity (add 0, multiply 1).
void build(int l, int r, int rt) {
    add[rt] = 0;
    mul[rt] = 1;
    if (l == r) {
        sum[rt] = a[l];
        sum1[rt] = a[l] * a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]);
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Multiply every element of [x, y] by z; the current node rt covers [l, r].
void update1(int rt, int l, int r, int x, int y, ll z) {
    if (x <= l && r <= y) {
        sum[rt] = sum[rt] * z;
        sum1[rt] = sum1[rt] * z * z;
        mul[rt] = mul[rt] * z;
        add[rt] = add[rt] * z;  // a pending add is scaled too
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r);
    if (x <= mid) update1(rt << 1, l, mid, x, y, z);
    if (y > mid) update1(rt << 1 | 1, mid + 1, r, x, y, z);
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]);
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Add val to every element of [x, y]; the current node rt covers [l, r].
void update2(int rt, int l, int r, int x, int y, ll val) {
    if (x <= l && r <= y) {
        ll cnt = sum[rt];  // pre-update sum, needed for the square-sum formula
        sum[rt] = sum[rt] + (r - l + 1) * val;
        sum1[rt] += 2 * val * cnt + (r - l + 1) * val * val;
        add[rt] = (add[rt] + val);
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r);
    if (x <= mid) update2(rt << 1, l, mid, x, y, val);
    if (y > mid) update2(rt << 1 | 1, mid + 1, r, x, y, val);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Query [x, y]: c == 1 returns the sum, any other c returns the sum of squares.
ll qsum(int rt, int l, int r, int x, int y, int c) {
    if (x <= l && r <= y) {
        if (c == 1) return sum[rt];
        else return sum1[rt];
    }
    pushdown(rt, l, r);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid) ans += qsum(rt << 1, l, mid, x, y, c);
    if (y > mid) ans += qsum(rt << 1 | 1, mid + 1, r, x, y, c);
    return ans;
}

// Input: n m, n numbers, then m operations:
//   "1 l r" sum, "2 l r" sum of squares, "3 l r x" multiply by x, "4 l r x" add x.
int main() {
    ll n, m, op, x, y, k;
    cin >> n >> m;
    for (int i = 1; i <= n; i++) a[i] = read();
    build(1, n, 1);
    for (int i = 1; i <= m; i++) {
        op = read();
        if (op == 3) {
            scanf("%lld%lld%lld", &x, &y, &k);
            update1(1, 1, n, x, y, k);
        } else if (op == 4) {
            scanf("%lld%lld%lld", &x, &y, &k);
            update2(1, 1, n, x, y, k);
        } else if (op == 1 || op == 2) {
            cin >> x >> y;
            printf("%lld\n", qsum(1, 1, n, x, y, op));
        }
    }
    return 0;
}
代码2
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll INF = -1e9;    // not used by this program
const ll N = 1e5 + 7;   // maximum number of elements (plus slack)

// Parse one signed integer from stdin with getchar(), skipping any non-digit
// prefix and honouring a '-' seen in it.
inline ll read() {
    ll value = 0, sign = 1;
    char c = getchar();
    for (; c < 48 || c > 57; c = getchar())
        if (c == '-') sign = -1;
    for (; c >= 48 && c <= 57; c = getchar())
        value = (value << 1) + (value << 3) + (c ^ 48);
    return value * sign;
}

// Print x in decimal with putchar().
inline void write(ll x) {
    if (x == 0) {
        putchar('0');
        return;
    }
    if (x < 0) putchar('-');
    ll mag = x > 0 ? x : -x;
    char digits[200];
    int len = 0;
    for (; mag > 0; mag /= 10) digits[len++] = mag % 10 + '0';
    while (len > 0) putchar(digits[--len]);
}

// sum[]: range sum; sum1[]: range sum of squares.
// Tags: each element v below a node is pending to become v * mul + add.
// NOTE(review): no modulus is applied; assumes all results fit in long long.
ll a[N], sum[N << 2], add[N << 2], mul[N << 2], sum1[N << 2];

// Apply "multiply every element by z" to node p: rescale both sums and
// compose the multiplication into p's own (mul, add) tags.
inline void applyMul(int p, ll z) {
    sum[p] *= z;
    sum1[p] *= z * z;
    add[p] *= z;
    mul[p] *= z;
}

// Apply "add v to each of len elements" to node p. The square sum uses the
// sum from before the update: sum (x + v)^2 = sum x^2 + 2*v*sum x + len*v^2.
inline void applyAdd(int p, ll len, ll v) {
    sum1[p] += 2 * sum[p] * v + len * v * v;
    sum[p] += len * v;
    add[p] += v;
}

// Recompute node p's sums from its two children.
inline void pushup(int p) {
    sum[p] = sum[p << 1] + sum[p << 1 | 1];
    sum1[p] = sum1[p << 1] + sum1[p << 1 | 1];
}

// Hand rt's pending tags (covering [l, r]) to its children: the multiply tag
// goes first, then the add tag.
void pushdown(int rt, int l, int r) {
    const int lc = rt << 1, rc = rt << 1 | 1;
    if (mul[rt] != 1) {
        applyMul(lc, mul[rt]);
        applyMul(rc, mul[rt]);
        mul[rt] = 1;
    }
    if (add[rt]) {
        const int mid = (l + r) >> 1;
        applyAdd(lc, mid - l + 1, add[rt]);
        applyAdd(rc, r - mid, add[rt]);
        add[rt] = 0;
    }
}

// Build node rt over a[l..r] with identity tags (add 0, multiply 1).
void build(int l, int r, int rt) {
    add[rt] = 0;
    mul[rt] = 1;
    if (l == r) {
        sum[rt] = a[l];
        sum1[rt] = a[l] * a[l];
        return;
    }
    const int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    pushup(rt);
}

// Multiply each element of [x, y] by z; node rt spans [l, r].
void update1(int rt, int l, int r, int x, int y, ll z) {
    if (x <= l && r <= y) {
        applyMul(rt, z);
        return;
    }
    pushdown(rt, l, r);
    const int mid = (l + r) >> 1;
    if (x <= mid) update1(rt << 1, l, mid, x, y, z);
    if (y > mid) update1(rt << 1 | 1, mid + 1, r, x, y, z);
    pushup(rt);
}

// Add val to each element of [x, y]; node rt spans [l, r].
void update2(int rt, int l, int r, int x, int y, ll val) {
    if (x <= l && r <= y) {
        applyAdd(rt, r - l + 1, val);
        return;
    }
    pushdown(rt, l, r);
    const int mid = (l + r) >> 1;
    if (x <= mid) update2(rt << 1, l, mid, x, y, val);
    if (y > mid) update2(rt << 1 | 1, mid + 1, r, x, y, val);
    pushup(rt);
}

// Query [x, y] under node rt spanning [l, r]: c == 1 gives the plain sum,
// any other c gives the sum of squares.
ll qsum(int rt, int l, int r, int x, int y, int c) {
    if (x <= l && r <= y) return c == 1 ? sum[rt] : sum1[rt];
    pushdown(rt, l, r);
    const int mid = (l + r) >> 1;
    ll total = 0;
    if (x <= mid) total += qsum(rt << 1, l, mid, x, y, c);
    if (y > mid) total += qsum(rt << 1 | 1, mid + 1, r, x, y, c);
    return total;
}

// Input: n m, n numbers, then m operations:
//   "1 l r" sum, "2 l r" sum of squares, "3 l r x" multiply by x, "4 l r x" add x.
int main() {
    ll n, m, op, x, y, k;
    cin >> n >> m;
    for (int i = 1; i <= n; i++) a[i] = read();
    build(1, n, 1);
    for (ll q = 1; q <= m; ++q) {
        op = read();
        switch (op) {
        case 3:
            scanf("%lld%lld%lld", &x, &y, &k);
            update1(1, 1, n, x, y, k);
            break;
        case 4:
            scanf("%lld%lld%lld", &x, &y, &k);
            update2(1, 1, n, x, y, k);
            break;
        case 1:
        case 2:
            cin >> x >> y;
            printf("%lld\n", qsum(1, 1, n, x, y, op));
            break;
        default:
            break;  // unknown opcodes are ignored, as before
        }
    }
    return 0;
}