【牛客2025年1024程序员节娱乐赛】E 题解

【模板】多项式优化拉格朗日多点插值多点求值 - 题解

题目分析

这道题要求我们根据给定的 n 个点 (x_i, y_i),对于每次询问的横坐标 x_0,找到一个纵坐标 y_0,使得存在一个最高次不超过 n 次的多项式 f(x) 满足所有给定的点。

正如题目名称所言,这是一道"模板"题,所以可以通过多项式快速多点插值和多项式快速多点求值的模板,来通过本道题。

解法一:多项式快速插值 + 多点求值

算法思路

  1. 使用多项式快速多点插值算法,根据给定的 n 个点构造出唯一的多项式 f(x)
  2. 对于每个询问的 x_0,使用多项式快速多点求值算法计算 f(x_0) 得到 y_0

时间复杂度

  • 构造多项式:O(n \log^2 n)(快速插值)
  • 总查询:O(q\log n)(快速多点求值)
  • 总时间复杂度:O(n \log^2 n + q \log n)

代码实现(简化版)

#include <bits/stdc++.h>
using namespace std;

using u64 = unsigned long long;
using u32 = unsigned;
using i64 = long long;
using i32 = signed;

template <i32 MOD>
u32 down(u32 x) { return x >= MOD ? x - MOD : x; }

template <i32 MOD>
struct MInt {
    u32 x;
    MInt () : x(0) {}
    MInt (u32 x) : x(x) {}
    friend MInt operator+(MInt a, MInt b) { return down<MOD>(a.x + b.x); }
    friend MInt operator-(MInt a, MInt b) { return down<MOD>(a.x - b.x + MOD); }
    friend MInt operator*(MInt a, MInt b) { return 1ULL * a.x * b.x % MOD; }
    friend MInt operator/(MInt a, MInt b) { return a * ~b; }
    friend MInt operator^(MInt a, i64 b) {
        MInt ans = 1;
        while (b) {
            if (b & 1) ans = ans * a;
            a = a * a;
            b /= 2;
        }
        return ans;
    }
    friend MInt operator~(MInt a) { return a ^ (MOD - 2); }
    // friend MInt operator~(MInt a) { return MInt(INV::inv(a.x)); }
    friend std::istream &operator>>(std::istream &in, MInt &a) { return in >> a.x; }
    friend std::ostream &operator<<(std::ostream &out, MInt a) { return out << a.x; }
    friend MInt operator-(MInt a) { return down<MOD>(MOD - a.x); }
    friend MInt &operator+=(MInt &a, MInt b) { return a = a + b; }
    friend MInt &operator-=(MInt &a, MInt b) { return a = a - b; }
    friend MInt &operator*=(MInt &a, MInt b) { return a = a * b; }
    friend MInt &operator/=(MInt &a, MInt b) { return a = a / b; }
    friend MInt &operator^=(MInt &a, long long b) { return a = a ^ b; }
    friend bool operator==(MInt a, MInt b) { return a.x == b.x; }
    friend bool operator!=(MInt a, MInt b) { return !(a == b); }
    friend bool operator<(MInt a, MInt b) { return a.x < b.x; }
};

const i32 MOD = 998244353;
using Z = MInt<MOD>;

std::vector<Z> roots{0, 1};
std::vector<i32> rev;

void dft(std::vector<Z> &a){
    i32 n = a.size();
    if(rev.size() != n){
        rev.resize(n);
        i32 k = __builtin_ctz(n) - 1;
        for (i32 i = 0; i < n; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
    }
    for (i32 i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    if (roots.size() < n) {
        i32 k = __builtin_ctz(roots.size());
        roots.resize(n);
        while((1 << k) < n){
            Z e = Z(3) ^ ((MOD - 1) >> (k + 1));
            for(i32 i = 1 << (k - 1); i < 1 << k; i++) {
                roots[2 * i] = roots[i];
                roots[2 * i + 1] = roots[i] * e;
            }
            k++;
        }
    }
    for (i32 k = 1; k < n; k *= 2) {
        for (i32 i = 0; i < n; i += 2 * k) {
            for (i32 j = 0; j < k; j++) {
                Z u = a[i + j], v = a[i + j + k] * roots[k + j];
                a[i + j] = u + v;
                a[i + j + k] = u - v;
            }
        }
    }
}
void idft(std::vector<Z> &a) {
    std::reverse(a.begin() + 1, a.end());
    dft(a);
    i32 n = a.size();
    Z inv = Z((1 - MOD) / n + MOD);
    for (i32 i = 0; i < n; i++) a[i] = a[i] * inv;
}

struct Poly : std::vector<Z> {
    Poly() {}
    Poly(const i32 n) : std::vector<Z>(n) {}
    Poly(const std::vector<Z> &a) : std::vector<Z>(a) {}
    Poly(const i32 n, const Z m) : std::vector<Z>(n, m) {}
    template<class It>
    explicit Poly(It first, It last) : std::vector<Z>(first, last) {}


    friend Poly operator*(Poly a, Poly b){
        if (a.empty() || b.empty()) return Poly();
        if (a.size() > b.size()) std::swap(a, b);
        if (a.size() < 64) {
            Poly c(a.size() + b.size() - 1);
            for(i32 i = 0; i < a.size(); i++) {
                for(i32 j = 0; j < b.size(); j++) {
                    c[i + j] = c[i + j] + a[i] * b[j];
                }
            }
            return c;
        }
        i32 sz = 1, tot = a.size() + b.size() - 1;
        while (sz < tot) sz *= 2;
        a.resize(sz);
        b.resize(sz);
        dft(a);
        dft(b);
        for(i32 i = 0; i < sz; i++) a[i] = a[i] * b[i];
        idft(a);
        a.resize(tot);
        return a;
    }
    friend Poly operator*(Poly a, Z b) {
        for (i32 i = 0; i < a.size(); i++) {
            a[i] *= b;
        }
        return a;
    }
    friend Poly operator+(Poly a, Poly b) {
        a.resize(std::max(a.size(), b.size()));
        for (i32 i = 0; i < b.size(); i++) {
            a[i] += b[i];
        }
        return a;
    }
    friend Poly operator-(Poly a, Poly b) {
        a.resize(std::max(a.size(), b.size()));
        for (i32 i = 0; i < b.size(); i++) {
            a[i] -= b[i];
        }
        return a;
    }
    Poly trunc(i32 k) {
        if (k <= size()) {
            return Poly(begin(), begin() + k);
        }
        Poly res(*this);
        res.resize(k);
        return res;
    }
    Poly inv(i32 n) {
        Poly res(1);
        res[0] = ~((*this)[0]);
        i32 k = 1;
        while (k < n) {
            k *= 2;
            res = res * (Poly(1, 2) - trunc(k) * res);
            res.resize(k);
        }
        res.resize(n);
        return res;
    }
    Poly deriv() {
        if (empty()) {
            return Poly();
        }
        Poly res(size() - 1);
        for (i32 i = 0; i < size() - 1; i++) {
            res[i] = (i + 1) * (*this)[i + 1];
        }
        return res;
    }
    Poly integr() {
        Poly res(size() + 1);
        for (i32 i = 0; i < size(); i++) {
            res[i + 1] = (*this)[i] / (i + 1);
        }
        return res;
    }
    Poly log(i32 n) {
        return (deriv() * inv(n)).integr().trunc(n);
    }
    Poly exp(i32 n) {
        Poly res(1, 1);
        i32 k = 1;
        while (k < n) {
            k *= 2;
            res = res * (Poly(1, 1) - res.log(k) + trunc(k));
            res.resize(k);
        }
        res.resize(n);
        return res;
    }
    Poly sqrt(i32 n) {
        Poly res(1, 1);
        i32 k = 1;
        while (k < n) {
            k *= 2;
            res = trunc(k) * res.inv(k) + res;
            res.resize(k);
            res = res * (~Z(2));
        }
        res.resize(n);
        return res;
    }
    
    // 多项式求值
    Z evaluate(Z x) const {
        Z result = 0;
        Z power = 1;
        for (i32 i = 0; i < size(); i++) {
            result = result + (*this)[i] * power;
            power = power * x;
        }
        return result;
    }
};

Poly MulT(Poly a, Poly b) {
    i32 n = a.size(), m = b.size();
    std::reverse(b.begin(), b.end());
    b = a * b;
    for (i32 i = 0; i < n; i++) {
        a[i] = b[i + m - 1];
    }
    return a;
}


std::vector<Z> multipoints(Poly f, std::vector<Z> a) {
    i32 n = std::max(f.size(), a.size());
    f.resize(n);
    a.resize(n);
    std::vector<Z> v(n);
    std::vector<Poly> Q;
    Q.resize(n << 2);

    auto MPinit = [&](auto &&self, std::vector<Z> &a, i32 u, i32 cl, i32 cr) -> void {
        if (cl == cr) {
            Q[u].resize(2);
            Q[u][0] = 1;
            Q[u][1] = MOD - a[cl];
            return;
        }
        i32 mid = (cl + cr) >> 1;
        self(self, a, u << 1, cl, mid);
        self(self, a, u << 1 | 1, mid + 1, cr);
        Q[u] = Q[u << 1] * Q[u << 1 | 1];
    };

    auto MPcal = [&](auto &&self, i32 u, i32 cl, i32 cr, Poly f, std::vector<Z> &g) -> void {
        f.resize(cr - cl + 1);
        if (cl == cr) {
            g[cl] = f[0];
            return;
        }
        i32 mid = (cl + cr) >> 1;
        self(self, u << 1, cl, mid, MulT(f, Q[u << 1 | 1]), g);
        self(self, u << 1 | 1, mid + 1, cr, MulT(f, Q[u << 1]), g);
    };

    MPinit(MPinit, a, 1, 0, n - 1);
    MPcal(MPcal, 1, 0, n - 1, MulT(f, Q[1].inv(n + 1)), v);
    return v;
}

Poly interpolate(std::vector<Z> x, std::vector<Z> y) {
    i32 n = x.size();
    std::vector<Poly> up(n * 2);
    for (i32 i = 0; i < n; i++) {
        up[i + n] = Poly(2, 0);
        up[i + n][0] = -Z(x[i]);
        up[i + n][1] = 1;
    }
    for (i32 i = n - 1; i > 0; i--) {
        up[i] = up[i << 1] * up[i << 1 | 1];
    }
    Poly a = multipoints(up[1].deriv(), x);
    for (i32 i = 0; i < n; i++) {
        a[i] = y[i] / a[i];
    }
    std::vector<Poly> down(n * 2);
    for (i32 i = 0; i < n; i++) {
        down[i + n] = Poly(1, a[i]);
    }
    for (i32 i = n - 1; i > 0; i--) {
        down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
    }
    return down[1];
}

signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
    i32 n;
    std::cin >> n;

    std::vector<Z> x(n), y(n);
    for (i32 i = 0; i < n; i++) {
        std::cin >> x[i] >> y[i];
    }

    Poly f = interpolate(x, y);
    i32 q;
    std::cin >> q;

    std::vector<Z> x2(q);
    for (i32 i = 0; i < q; i++) {
        std::cin >> x2[i];
    }

    std::vector<Z> y2 = multipoints(f, x2);
    for (i32 i = 0; i < q; i++) {
        std::cout << y2[i] << "\n";
    }
    std::cout << "\n";
    
    return 0;
}

解法二:利用唯一性性质

关键观察

重要性质:给定 n 个横坐标不同的点,恰好可以确定一个最高次不超过 n-1 次的多项式。这意味着:

  1. 如果询问的 x_0 不在给定的 n 个点中,那么存在唯一的多项式 f(x) 满足所有条件
  2. 如果询问的 x_0 在给定的点中,那么 y_0 就是该点对应的纵坐标

算法思路

根据多项式插值的唯一性,给定 n 个点恰好确定一个最高次不超过 n-1 次的多项式。因此:

  1. 如果询问的 x_0 在给定的点中,直接返回对应的 y_i
  2. 如果询问的 x_0 不在给定的点中,可以返回任意值(比如 0)

时间复杂度

  • 预处理:O(n) 建立坐标到纵坐标的映射
  • 单次查询:O(\log n)
  • 总时间复杂度:O(n + q \log n)

代码实现

#include <bits/stdc++.h>
using namespace std;

const int MOD = 998244353;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    
    int n;
    cin >> n;
    
    map<int, int> point_map;
    
    for (int i = 0; i < n; i++) {
        int x, y;
        cin >> x >> y;
        point_map[x] = y;
    }
    
    int q;
    cin >> q;
    
    while (q--) {
        int x0;
        cin >> x0;
        
        if (point_map.count(x0)) {
            cout << point_map[x0] << '\n';
        } else {
            cout << 0 << '\n';  // 任意值都可以
        }
    }
    
    return 0;
}

为什么解法二正确?

根据多项式插值的唯一性,n+1 个横坐标不同的点恰好可以确定一个最高次不超过 n 次的多项式。因此:

  • 如果询问的 x_0 在给定的点中,直接返回对应的纵坐标
  • 如果询问的 x_0 不在给定的点中,可以返回任意值,因为我们可以构造一个满足条件的多项式

补充说明

关于 unordered_map 的陷阱:如果使用 unordered_map 而没有对桶大小或哈希函数进行特殊处理,在本题中可能会遇到特意构造的哈希碰撞数据,导致 unordered_map 的时间复杂度退化到 O(n^2 + qn),无法通过本题。因此推荐使用 map 来保证稳定的 O(\log n) 查询时间。

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务