题解 | #[SCOI2010]序列操作#

[SCOI2010]序列操作

https://ac.nowcoder.com/acm/problem/20279

没有思路直接上代码

#include <iostream>
#include <cstdio>

#define lson u << 1
#define rson u << 1 | 1

using namespace std;

const int N = 1e5 + 10;

int n, m;
int a[N];
struct Node
{
    int l, r;
    int lazy;//0表示全变0 1表示全变1 2表示全变取反 -1表示什么都不干
    int sum[2], ma[2], pre[2], suf[2];
    //sum[i]总共用多少个i   ma[i]最多多少个连续的i 
    //pre[i]前缀中最多多少个连续的i   suf[i]后缀中最多多少个连续的i
}tr[N << 2];

template <class T>
inline void read(T & res)
{
    char ch; bool flag = false;
    while ((ch = getchar()) < '0' || ch > '9')
        if (ch == '-') flag = true;
    res = ch ^ 48;
    while ((ch = getchar()) >= '0' && ch <= '9')
    res = (res << 3) + (res << 1) + (ch ^ 48);
    if (flag) res = ~res + 1;
}

void pushup(Node &u, Node &x, Node &y)
{
    //考虑到下面会使用mid - u.l + 1以此来求区间长度
    //但是这里的mid千万不能写成 int mid = u.l + u.r >> 1;
    //因为这里的u节点里面不一定有东西,query的else里面pushup的是一个空的u
    //所以直接使用x.r - x.l + 1来求区间长度
    for (int i = 0; i < 2; i ++)
    {
        u.sum[i] = x.sum[i] + y.sum[i];
        u.ma[i] = max(max(x.ma[i], y.ma[i]), x.suf[i] + y.pre[i]);
        //这里直接压行
        u.pre[i] = x.pre[i] + y.pre[i] * (x.pre[i] == x.r - x.l + 1);
        u.suf[i] = y.suf[i] + x.suf[i] * (y.suf[i] == y.r - y.l + 1);
        /*不压行的话
        u.pre[i] = x.pre[i];
        if (x.pre[i] == x.r - x.l + 1) u.pre[i] = x.pre[i] + y.pre[i];
        u.suf[i] = y.suf[i];
        if (y.suf[i] == y.r - y.l + 1) u.suf[i] = y.suf[i] + x.suf[i];
        */
    }
}

void pushup(int u) {pushup(tr[u], tr[lson], tr[rson]);}

void build(int u, int l, int r)
{
    tr[u].l = l, tr[u].r = r, tr[u].lazy = -1;//一开始懒标记要标记成-1,表示啥也不干
    if (l == r)
    {        
        tr[u].sum[a[l]] = tr[u].ma[a[l]] = tr[u].pre[a[l]] = tr[u].suf[a[l]] = 1;
        return ;
    }
    int mid = l + r >> 1;
    build(lson, l, mid), build(rson, mid + 1, r);
    pushup(u);
}

//将u节点所代表的区间内所有数全变v
void change1(int u, int v)
{
    tr[u].sum[v] = tr[u].ma[v] = tr[u].pre[v] = tr[u].suf[v] = tr[u].r - tr[u].l + 1;
    tr[u].sum[v ^ 1] = tr[u].ma[v ^ 1] = tr[u].pre[v ^ 1] = tr[u].suf[v ^ 1] = 0;
    tr[u].lazy = v;
}

//将u节点所代表的区间内所有数全部取反
void change2(int u)
{
    swap(tr[u].sum[0], tr[u].sum[1]);
    swap(tr[u].ma[0], tr[u].ma[1]);
    swap(tr[u].pre[0], tr[u].pre[1]);
    swap(tr[u].suf[0], tr[u].suf[1]);

    if (tr[u].lazy == -1) tr[u].lazy = 2;//什么都没标记的话,标记取反
    else if (tr[u].lazy == 2) tr[u].lazy = -1;//标记取反的话,两个取反就抵消了
    else tr[u].lazy ^= 1;//否则0变1,1变0
}

void pushdown(int u)
{    
    if (tr[u].lazy == -1) return ;//啥也不干
    else if (tr[u].lazy == 2)
    {
        change2(lson), change2(rson);
        tr[u].lazy = -1;//清空懒标记
    }
    else//全变1或者全变0
    {
        change1(lson, tr[u].lazy), change1(rson, tr[u].lazy);
        tr[u].lazy = -1;//清空懒标记
    }
}

//将[x,y]区间内所有数变为v
void modify(int u, int x, int y, int v)
{
    if (x <= tr[u].l && tr[u].r <= y)
    {
        change1(u, v);
        return ;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (x <= mid) modify(lson, x, y, v);
    if (y > mid) modify(rson, x, y, v);
    pushup(u);
}

//将[x,y]区间内所有数全部取反
void rev(int u, int x, int y)
{
    if (x <= tr[u].l && tr[u].r <= y)
    {
        change2(u);
        return ;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (x <= mid) rev(lson, x, y);
    if (y > mid) rev(rson, x, y);
    pushup(u);
}

Node query(int u, int x, int y)
{    
    if (x <= tr[u].l && tr[u].r <= y) return tr[u];
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (y <= mid) return query(lson, x, y);
    else if (x > mid) return query(rson, x, y);
    else
    {
        Node l, r, res;
        l = query(lson, x, y), r = query(rson, x, y);        
        pushup(res, l, r);//这里的res是空的,所以pushup里面mid不能再像往常一样定义
        return res;
    }
}

int main()
{
    read(n), read(m);
    for (int i = 1; i <= n; i ++) read(a[i]);
    build(1, 1, n);
    
    for (int i = 1; i <= m; i ++)
    {
        int op, a, b;
        read(op), read(a), read(b);
        a ++, b ++;//题目下标从0开始,所以全部+1
        if (op == 0) modify(1, a, b, 0);
        else if (op == 1) modify(1, a, b, 1);
        else if (op == 2) rev(1, a, b);
        else if (op == 3) printf("%d\n", query(1, a, b).sum[1]);
        else printf("%d\n", query(1, a, b).ma[1]);
    }

    return 0;
}
全部评论

相关推荐

1 收藏 评论
分享
牛客网
牛客企业服务