题解 | #中位数之和#

中位数之和

https://www.nowcoder.com/practice/4bc8c3535b8e488eb608c73f8946d9cb

1. 问题分析

问题的核心在于计算所有长度为 的子序列的中位数之和。数组 是二进制数组(仅包含 0 和 1),这是一个极其关键的性质。

  • 中位数的定义:对于长度为奇数 的数组,排序后的第 个元素即为中位数。
  • 二进制特性:在一个仅包含 0 和 1 的排序数组中,形式必然为
    • 如果数组中 1 的数量足够多,覆盖了中位数的位置,则中位数为 1。
    • 否则,中位数为 0。

由此可得出一个确定性的阈值判定条件: 令 。在一个长度为 的二进制序列中,中位数为 1 当且仅当该序列中 1 的个数至少为 。否则,中位数为 0。

2. 转化问题

因为中位数只能是 0 或 1,所有子序列的中位数之和等价于中位数为 1 的子序列的数量。 原本需要遍历所有子序列(数量级 )的问题,现在转化为了一个组合计数问题

  • 统计原数组 中 1 的总数(记为 )和 0 的总数(记为 )。
  • 我们需要构建长度为 的子序列,且其中包含 个 1,满足
  • 对于每一个合法的 ,我们需要计算有多少种方式从原数组中选出 个 1 和 个 0。

3. 算法选择

数学推导 + 组合计数。 由于 的规模可达 ,且我们需要快速回答多个查询,传统的动态规划或暴力枚举不可行。利用组合数学公式可以直接在 级别计算出特定条件下的方案数,是解决此类计数问题的最优范式。

数学推导

设:

  • 为数组总长度。
  • 为数组中 1 的总个数。
  • 为数组中 0 的总个数。
  • 为使中位数为 1 所需的最少 1 的个数。

我们枚举子序列中 1 的个数 ,其中 的取值范围是 。 对于每一个固定的

  1. 选 1 的方案数:从 个 1 中取出 个,方案数为
  2. 选 0 的方案数:剩余的位置 必须由 0 填充。从 个 0 中取出 个,方案数为
  3. 乘法原理:包含恰好 个 1 的长度为 的子序列总数为

最终答案为所有满足条件的 的方案数之和:

注意隐含约束:我们在计算组合数 时,若 ,则结果为 0。

代码实现

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

static constexpr int N = 2e5 + 5;
static constexpr ll MOD = 1000000007;

array<ll, N> fact{};
array<ll, N> invFact{};

constexpr ll power(ll base, ll exp) {
    ll res = 1;
    base %= MOD;

    while (exp > 0) {
        if (exp % 2 == 1) {
            res = res * base % MOD;
        }
        base = base * base % MOD;
        exp >>= 1;
    }

    return res;
}

void precompute() {
    fact[0] = 1;
    for (int i = 1; i < N; i++) {
        fact[i] = fact[i - 1] * i % MOD;
    }

    invFact[N - 1] = power(fact[N - 1], MOD - 2);

    for (int i = N - 1; i > 0; i--) {
        invFact[i - 1] = invFact[i] * i % MOD;
    }
}

ll Comb(ll n, ll m) {
    if (m < 0 || m > n)
        return 0;
    ll res = fact[n] % MOD;
    res = res * invFact[m] % MOD;
    res = res * invFact[n - m] % MOD;
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    precompute();

    int t;
    cin >> t;

    while (t--) {
        int n, k;
        cin >> n >> k;

        int a;
        int cntOne = 0;
        for (int i = 0; i < n; i++) {
            cin >> a;
            if (a == 1)
                cntOne++;
        }
        int cntZero = n - cntOne;
        int limit = (k + 1) / 2;

        ll ans = 0;
        for (int i = limit; i <= k; i++) {
            if (i > cntOne || k - i > cntZero)
                continue;
            ll term = Comb(cntOne, i);
            term = term * Comb(cntZero, k - i) % MOD;
            ans = (ans + term) % MOD;
        }

        cout << ans << "\n";
    }
}
#牛友的春节生活#
每日一题@牛客网 文章被收录于专栏

牛客网每日一题

全部评论

相关推荐

KKorz:是这样的,还会定期默写抽查
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务