2024美团春季笔试第二场 T4 区间众数求和 线段树做法

(以下讲解的下标均从1开始)

题目大意

给出一个数组 a[1~n],a[i] = 1或2, 求每个区间的众数之和。

思路

写公式

我们假设 two[i] 为 a[1~i] 的“2”的数量,one[i] 为 a[1~i] 的“1”的数量,假设m为众数为2的区间数量,那么有下面的公式:

其中[...]表示如果括号内的条件满足,则=1,否则=0.

这个式子可以转换为:

设一个数组 diff[i] = two[i] - one[i]。那么式子再写成:

可以看到这个式子非常的熟悉,可能有人会想到一维偏序什么的。

首先,-n <= diff[i] <= n。我们不喜欢负数的存在,由于比较大小时两边都加同一个数没啥毛病,考虑每个diff[i]都 +(n+1),这样就都是正数了。

回到正题,如果暴力求解这个式子的话,显然要用O(n^2)的查询。我们希望遍历r时,可以马上获取数组中在r左边,且 < diff[r]的数的数量,这里,我当时就用的线段树。

请出线段树

(不知道线段树的朋友可以先去了解一下)

建立一个线段树tree,tree[i] 表示树的第i号节点,我们先假设其表示的区间为[left,right]。则tree[i]的值就是diff[1~r-1]的所有值处于[left,right]的元素数量。比如diff[] = {3,5,4,7,1, ... },[left,right] = [1,4],假设r此时为5.则tree[i]就是3,5,4,7(下标<r)中位于[1,4]的数:3,4,它们数量的和:2。则tree[i] = 2.

假设我们处理了diff[1~r-1],把它们写到了线段树里,现在轮到diff[r]了。

我们要判断diff[1~r-1]中有多少<diff[r]的,怎么办?我们就去线段树中去找在区间[left,right] = [0,diff[r]-1]内的tree值。这个[0,diff[r]-1]是可以用线段树的性质拼接出来的,对每个组成其一部分的子区间的tree值求和即可得到答案,对应的增加众数为2的区间数量的贡献。

查找完之后我们还要将diff[r]的值加到线段树里,这样在继续找diff[r+1,n]的答案时可以把diff[r]的贡献算进去。

最后,由于总区间个数为sum = n*(n+1)/2,众数为2的区间个数为m,则众数为1的区间个数为p = sum - m,答案为m*2 + p。

复杂度

时间复杂度O(nlogn),其中线段树查询和单点修改为O(logn),做n次

空间复杂度O(n)

#include <iostream>
#include <set>
#include <algorithm>
#include <vector>
#include <cstring>
using namespace std;
using ll = long long;
const int N = (int)2e5+5;
ll n;
ll a[200005];
ll tree[N << 3]; // 开大点没关系
ll diff[200005];
// 将数值x放入线段树,此时所在的区间为[l,r],线段树下标为id
void push(ll x,int l,int r,int id){
    if(l >= r){
        tree[id]++;
        return;
    }
    int mid = (l+r)>>1;
    if(x <= mid){
        push(x,l,mid,id<<1);
    }
    else{
        push(x,mid+1,r,(id<<1)+1);
    }
    tree[id] = tree[id<<1] + tree[(id<<1)+1];
}
// 求<x的diff数量
ll getsum(ll x,int l,int r,int id){
    if(l >= r){
        if(x <= l) return 0;
        return tree[id];
    }
    ll res = 0;
    int mid = (l+r)>>1;
    if(x > mid){
        res += tree[id<<1];
        res += getsum(x,mid+1,r,(id<<1)+1);
    }
    else{
        res += getsum(x,l,mid,(id<<1));
    }
    return res;
}
void solve(){
    ll m = 0;
    for(int i = 1;i <= n;++i){
        cin >> a[i]; // 1 <= a[i] <= 2
        diff[i] = diff[i-1] + (a[i] == 2 ? 1 : -1);
    }
    for(int i = 0;i <= n;++i){
        diff[i] += n+1; // 保证所有数为正数
    }
    push(diff[0],0,2*n+1,1); // 由于l-1可以为0,因此先把diff[0]放进去
    for(int i = 1;i <= n;++i){
        m += getsum(diff[i],0,2*n+1,1); // 从最大的区间[0,2n+1]开始递归查询
        push(diff[i],0,2*n+1,1); // 单点修改,其所在的区间的值都要+1
    }
    ll p = n*(n+1)/2 - m;
    ll ans = m*2 + p;
    cout << ans << "\n";
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    while(cin >> n){
        solve();
    }
    return 0;
}

全部评论
确实没看懂
1 回复
分享
发布于 03-19 09:16 浙江

相关推荐

2 5 评论
分享
牛客网
牛客企业服务