CSP-S 2020 T3 函数调用

一不小心先开了T3,憋了好几天。题解一开始也看不懂,对我来说太难了。

题目链接:
https://www.luogu.com.cn/problem/P7077

题目大意:
给n个数a[i],给m个函数g[i],其中函数有3种类型:
type1:p v 给a[p]加上v
type2: v 对n个数同时乘上v
type3: 调用若干个函数(在1-m范围内),保证不会出现递归

给出一个序列Q是函数的调用序列,问调用完以后的a[i]数组是多少?
n,m<=10^5,调用的函数不会超过10^6

思路:
这题如果直接模拟的话,还是可以得到一个比较暴力的思路的。

先说说我的暴力做法:

假设只有type1,可以直接模拟,假设只有type2,可以计算一共乘了多少,然后对a[i]数组乘上去就好了。
如果同时存在type1和type2,直接模拟的话,我们需要考虑后面的乘法对前面的加法的影响。
例如:((a+v1) * v2+v3) * v4,显然我们发现多出来的有v1 * v2 * v3,v3 * v4。那么我们可以用一个变量k计算出乘的结果,然后对a * k,加法的部分我们可以在计算k的过程中用一个数组b[i]记录下当前已经加了多少了然后进行乘法。
这样做,每次遇到乘法的时候我们都要对数组a[i]进行一次遍历。显然会超时。
为了骗取更多的分数为了加速这个遍历,我考虑使用线段树维护a[i]数组,当遇到乘法的时候,我们记录k为已经乘了多少。当遇到加法的时候,我们在线段树中查找这个a[i],先把前一段乘法乘上去,然后再加。最后再打个标记表示已经乘了多少了。
具体的,假设当前全局乘k,当前对a[x]进行加y操作。先把前面一段的乘法乘上去,我们需要记录一个laz标记,表示前面已经乘了laz了,那么显然当前这一段要再乘inv[laz]*k,然后再加y,接着更新laz=k。这里要用到逆元。
这里的做法有一个坑点,就是数据中会有乘0的存在,此时逆元就变0了。因此遇到这种情况,我们就把数组a[i]全部变0,然后重新建树。
上述做法能骗70分。

来看正解:
题解看了好久
首先其实我们可以倒序处理刚刚的加法和乘法贡献问题,这样我们只要维护当前的乘积k,然后遇到加法就直接对加的数乘k就好了。
根据这个性质,我们可以对输入的Q序列倒序处理。
我们先开一个数组mul[i]表示调用函数i会对答案乘上多少。若:
type[i]=1 ,mul[i]=1
type[i]=2,mul[i]=v[i]
type[i]=3,mul[i]=调用子函数的mul乘积
对于type[i]=3的情况,我们可以跑一遍dfs记忆化搜索。

ll dfs(int x){
    if(mul[x]!=-1)return mul[x];   //记忆化,mul[x]有可能为0,因此设为-1
    ll ans=1;
    for(int i=head[x];~i;i=e[i].next){
        int y=e[i].to;
        ans=ans*dfs(y)%mod;
    }

    return mul[x]=ans%mod;
}

接下来我们对Q序列倒序处理,这里有一个重要的性质:乘k等价于让前面的加法再执行k次
我们开一个数组f[i]表示函数i被调用的次数。
我们先倒序Q序列,遇到type3不进行向下调用,只记录出现的次数。由于后出现的函数会影响前面的函数,那么我们记录一个now表示后面对前面的影响。

ll now=1;
    for(int i=Q;i>=1;i--){    //倒序处理序列
        int x=q[i];
        f[x]=(f[x]+now)%mod;    //加上之前的后面函数影响
        now=now*mul[x]%mod;     //更新后面函数的影响
    }

最后我们考虑type3函数的向下值的更新,显然对于函数i调用的所有函数j,我们需要从后往前进行处理。由于题目保证不出现递归,那么我们知道得到的图一定是一个DAG。这里我们可以先拓扑排序保证顺序,接着我们按照刚刚的方法进行同样的计算就好了。
假设从函数i调用到函数j,我们有:
f[j]=f[j]+f[i]
f[i]=f[i]*mul[j] //再往前看左边的子树的时候考虑右边已经产生的影响。

void getdown(){
    for(int i=1;i<=m;i++){
        int x=nb[i];
        ll now=1;
        for(int j=head[x];~j;j=e[j].next){
            int y=e[j].to;
            f[y]=(f[y]+f[x]%mod)%mod;
            f[x]=f[x]*mul[y]%mod;
        }

    }
}

最后我们看一下加法的函数计算一下最终结果就好了。
代码:

#include<bits/stdc++.h>
#define mod 998244353
#define ll long long
using namespace std;

int n,m;
ll a[100040],b[100040],k=1;
int tot=0,head[100040],q[100040],in[100040];
struct edge{
    int to,next;
}e[1000030];
struct node{
    int t,p,v;
}g[100400];
ll mul[100040],f[100040],nb[100040];
void init(){
    tot=0;
    memset(head,-1,sizeof(head));
    memset(in,0,sizeof(in));
    memset(mul,-1,sizeof(mul));
}
void add(int a,int b){
    tot++;
    e[tot].to=b;e[tot].next=head[a];head[a]=tot;
}
ll dfs(int x){
    if(mul[x]!=-1)return mul[x];
    ll ans=1;
    for(int i=head[x];~i;i=e[i].next){
        int y=e[i].to;
        ans=ans*dfs(y)%mod;
    }

    return mul[x]=ans%mod;
}
void toposort(){
    queue<int>qu;
    for(int i=1;i<=m;i++){
        if(in[i]==0)qu.push(i);
    }
    int t=0;
    while(!qu.empty()){
        int now=qu.front();
        nb[++t]=now;
        qu.pop();
        for(int i=head[now];~i;i=e[i].next){
            int y=e[i].to;
            in[y]--;
            if(in[y]==0)qu.push(y);
        }
    }
}
void getdown(){
    for(int i=1;i<=m;i++){
        int x=nb[i];
        ll now=1;
        for(int j=head[x];~j;j=e[j].next){
            int y=e[j].to;
            f[y]=(f[y]+f[x]%mod)%mod;
            f[x]=f[x]*mul[y]%mod;
        }

    }
}
int main()
{
    cin>>n;
    for(int i=1;i<=n;i++)cin>>a[i];
    cin>>m;
    init();
    for(int i=1;i<=m;i++){
        scanf("%d",&g[i].t);
        if(g[i].t==1){
            scanf("%d%d",&g[i].p,&g[i].v);
            mul[i]=1;
        }
        else if(g[i].t==2){
            scanf("%d",&g[i].v);
            mul[i]=g[i].v;
        }
        else {
            int x;
            scanf("%d",&x);
            for(int j=1;j<=x;j++){
                int y;
                scanf("%d",&y);
                in[y]++;
                add(i,y);
            }
        }
    }
    for(int i=1;i<=m;i++)dfs(i);
    int Q;
    cin>>Q;
    for(int i=1;i<=Q;i++){
        scanf("%d",&q[i]);
    }

    ll now=1;
    for(int i=Q;i>=1;i--){   //倒序处理序列统计函数出现次数
        int x=q[i];
        f[x]=(f[x]+now)%mod;
        now=now*mul[x]%mod;
    }

    toposort();    //拓扑排序
    for(int i=1;i<=n;i++)a[i]=a[i]*now%mod;   //全局乘法先计算上
    getdown();   //向下更新
    for(int i=1;i<=m;i++){
        if(g[i].t==1){
            int x=g[i].p;
            a[x]=a[x]+f[i]*g[i].v%mod;
            a[x]%=mod;
        }
    }
    for(int i=1;i<=n;i++)cout<<a[i]<<' ';
    cout<<endl;
    return 0;
}

总结:
倒序处理是个好方法。
中间的统计比较抽象,还是需要好好体会这种计算方法的使用。

全部评论

相关推荐

点赞 收藏 评论
分享
牛客网
牛客企业服务