【牛客2025年1024程序员节娱乐赛】E 题解
【模板】多项式优化拉格朗日多点插值多点求值 - 题解
题目分析
这道题要求我们根据给定的 个点
,对于每次询问的横坐标
,找到一个纵坐标
,使得存在一个最高次不超过
次的多项式
满足所有给定的点。
正如题目名称所言,这是一道"模板"题,所以可以通过多项式快速多点插值和多项式快速多点求值的模板,来通过本道题。
解法一:多项式快速插值 + 多点求值
算法思路
- 使用多项式快速多点插值算法,根据给定的
个点构造出唯一的多项式
- 对于每个询问的
,使用多项式快速多点求值算法计算
得到
时间复杂度
- 构造多项式:
(快速插值)
- 总查询:
(快速多点求值)
- 总时间复杂度:
代码实现(简化版)
#include <bits/stdc++.h>
using namespace std;
using u64 = unsigned long long;
using u32 = unsigned;
using i64 = long long;
using i32 = signed;
template <i32 MOD>
u32 down(u32 x) { return x >= MOD ? x - MOD : x; }
template <i32 MOD>
struct MInt {
u32 x;
MInt () : x(0) {}
MInt (u32 x) : x(x) {}
friend MInt operator+(MInt a, MInt b) { return down<MOD>(a.x + b.x); }
friend MInt operator-(MInt a, MInt b) { return down<MOD>(a.x - b.x + MOD); }
friend MInt operator*(MInt a, MInt b) { return 1ULL * a.x * b.x % MOD; }
friend MInt operator/(MInt a, MInt b) { return a * ~b; }
friend MInt operator^(MInt a, i64 b) {
MInt ans = 1;
while (b) {
if (b & 1) ans = ans * a;
a = a * a;
b /= 2;
}
return ans;
}
friend MInt operator~(MInt a) { return a ^ (MOD - 2); }
// friend MInt operator~(MInt a) { return MInt(INV::inv(a.x)); }
friend std::istream &operator>>(std::istream &in, MInt &a) { return in >> a.x; }
friend std::ostream &operator<<(std::ostream &out, MInt a) { return out << a.x; }
friend MInt operator-(MInt a) { return down<MOD>(MOD - a.x); }
friend MInt &operator+=(MInt &a, MInt b) { return a = a + b; }
friend MInt &operator-=(MInt &a, MInt b) { return a = a - b; }
friend MInt &operator*=(MInt &a, MInt b) { return a = a * b; }
friend MInt &operator/=(MInt &a, MInt b) { return a = a / b; }
friend MInt &operator^=(MInt &a, long long b) { return a = a ^ b; }
friend bool operator==(MInt a, MInt b) { return a.x == b.x; }
friend bool operator!=(MInt a, MInt b) { return !(a == b); }
friend bool operator<(MInt a, MInt b) { return a.x < b.x; }
};
const i32 MOD = 998244353;
using Z = MInt<MOD>;
std::vector<Z> roots{0, 1};
std::vector<i32> rev;
void dft(std::vector<Z> &a){
i32 n = a.size();
if(rev.size() != n){
rev.resize(n);
i32 k = __builtin_ctz(n) - 1;
for (i32 i = 0; i < n; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
}
for (i32 i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
if (roots.size() < n) {
i32 k = __builtin_ctz(roots.size());
roots.resize(n);
while((1 << k) < n){
Z e = Z(3) ^ ((MOD - 1) >> (k + 1));
for(i32 i = 1 << (k - 1); i < 1 << k; i++) {
roots[2 * i] = roots[i];
roots[2 * i + 1] = roots[i] * e;
}
k++;
}
}
for (i32 k = 1; k < n; k *= 2) {
for (i32 i = 0; i < n; i += 2 * k) {
for (i32 j = 0; j < k; j++) {
Z u = a[i + j], v = a[i + j + k] * roots[k + j];
a[i + j] = u + v;
a[i + j + k] = u - v;
}
}
}
}
void idft(std::vector<Z> &a) {
std::reverse(a.begin() + 1, a.end());
dft(a);
i32 n = a.size();
Z inv = Z((1 - MOD) / n + MOD);
for (i32 i = 0; i < n; i++) a[i] = a[i] * inv;
}
struct Poly : std::vector<Z> {
Poly() {}
Poly(const i32 n) : std::vector<Z>(n) {}
Poly(const std::vector<Z> &a) : std::vector<Z>(a) {}
Poly(const i32 n, const Z m) : std::vector<Z>(n, m) {}
template<class It>
explicit Poly(It first, It last) : std::vector<Z>(first, last) {}
friend Poly operator*(Poly a, Poly b){
if (a.empty() || b.empty()) return Poly();
if (a.size() > b.size()) std::swap(a, b);
if (a.size() < 64) {
Poly c(a.size() + b.size() - 1);
for(i32 i = 0; i < a.size(); i++) {
for(i32 j = 0; j < b.size(); j++) {
c[i + j] = c[i + j] + a[i] * b[j];
}
}
return c;
}
i32 sz = 1, tot = a.size() + b.size() - 1;
while (sz < tot) sz *= 2;
a.resize(sz);
b.resize(sz);
dft(a);
dft(b);
for(i32 i = 0; i < sz; i++) a[i] = a[i] * b[i];
idft(a);
a.resize(tot);
return a;
}
friend Poly operator*(Poly a, Z b) {
for (i32 i = 0; i < a.size(); i++) {
a[i] *= b;
}
return a;
}
friend Poly operator+(Poly a, Poly b) {
a.resize(std::max(a.size(), b.size()));
for (i32 i = 0; i < b.size(); i++) {
a[i] += b[i];
}
return a;
}
friend Poly operator-(Poly a, Poly b) {
a.resize(std::max(a.size(), b.size()));
for (i32 i = 0; i < b.size(); i++) {
a[i] -= b[i];
}
return a;
}
Poly trunc(i32 k) {
if (k <= size()) {
return Poly(begin(), begin() + k);
}
Poly res(*this);
res.resize(k);
return res;
}
Poly inv(i32 n) {
Poly res(1);
res[0] = ~((*this)[0]);
i32 k = 1;
while (k < n) {
k *= 2;
res = res * (Poly(1, 2) - trunc(k) * res);
res.resize(k);
}
res.resize(n);
return res;
}
Poly deriv() {
if (empty()) {
return Poly();
}
Poly res(size() - 1);
for (i32 i = 0; i < size() - 1; i++) {
res[i] = (i + 1) * (*this)[i + 1];
}
return res;
}
Poly integr() {
Poly res(size() + 1);
for (i32 i = 0; i < size(); i++) {
res[i + 1] = (*this)[i] / (i + 1);
}
return res;
}
Poly log(i32 n) {
return (deriv() * inv(n)).integr().trunc(n);
}
Poly exp(i32 n) {
Poly res(1, 1);
i32 k = 1;
while (k < n) {
k *= 2;
res = res * (Poly(1, 1) - res.log(k) + trunc(k));
res.resize(k);
}
res.resize(n);
return res;
}
Poly sqrt(i32 n) {
Poly res(1, 1);
i32 k = 1;
while (k < n) {
k *= 2;
res = trunc(k) * res.inv(k) + res;
res.resize(k);
res = res * (~Z(2));
}
res.resize(n);
return res;
}
// 多项式求值
Z evaluate(Z x) const {
Z result = 0;
Z power = 1;
for (i32 i = 0; i < size(); i++) {
result = result + (*this)[i] * power;
power = power * x;
}
return result;
}
};
Poly MulT(Poly a, Poly b) {
i32 n = a.size(), m = b.size();
std::reverse(b.begin(), b.end());
b = a * b;
for (i32 i = 0; i < n; i++) {
a[i] = b[i + m - 1];
}
return a;
}
std::vector<Z> multipoints(Poly f, std::vector<Z> a) {
i32 n = std::max(f.size(), a.size());
f.resize(n);
a.resize(n);
std::vector<Z> v(n);
std::vector<Poly> Q;
Q.resize(n << 2);
auto MPinit = [&](auto &&self, std::vector<Z> &a, i32 u, i32 cl, i32 cr) -> void {
if (cl == cr) {
Q[u].resize(2);
Q[u][0] = 1;
Q[u][1] = MOD - a[cl];
return;
}
i32 mid = (cl + cr) >> 1;
self(self, a, u << 1, cl, mid);
self(self, a, u << 1 | 1, mid + 1, cr);
Q[u] = Q[u << 1] * Q[u << 1 | 1];
};
auto MPcal = [&](auto &&self, i32 u, i32 cl, i32 cr, Poly f, std::vector<Z> &g) -> void {
f.resize(cr - cl + 1);
if (cl == cr) {
g[cl] = f[0];
return;
}
i32 mid = (cl + cr) >> 1;
self(self, u << 1, cl, mid, MulT(f, Q[u << 1 | 1]), g);
self(self, u << 1 | 1, mid + 1, cr, MulT(f, Q[u << 1]), g);
};
MPinit(MPinit, a, 1, 0, n - 1);
MPcal(MPcal, 1, 0, n - 1, MulT(f, Q[1].inv(n + 1)), v);
return v;
}
Poly interpolate(std::vector<Z> x, std::vector<Z> y) {
i32 n = x.size();
std::vector<Poly> up(n * 2);
for (i32 i = 0; i < n; i++) {
up[i + n] = Poly(2, 0);
up[i + n][0] = -Z(x[i]);
up[i + n][1] = 1;
}
for (i32 i = n - 1; i > 0; i--) {
up[i] = up[i << 1] * up[i << 1 | 1];
}
Poly a = multipoints(up[1].deriv(), x);
for (i32 i = 0; i < n; i++) {
a[i] = y[i] / a[i];
}
std::vector<Poly> down(n * 2);
for (i32 i = 0; i < n; i++) {
down[i + n] = Poly(1, a[i]);
}
for (i32 i = n - 1; i > 0; i--) {
down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
}
return down[1];
}
signed main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
i32 n;
std::cin >> n;
std::vector<Z> x(n), y(n);
for (i32 i = 0; i < n; i++) {
std::cin >> x[i] >> y[i];
}
Poly f = interpolate(x, y);
i32 q;
std::cin >> q;
std::vector<Z> x2(q);
for (i32 i = 0; i < q; i++) {
std::cin >> x2[i];
}
std::vector<Z> y2 = multipoints(f, x2);
for (i32 i = 0; i < q; i++) {
std::cout << y2[i] << "\n";
}
std::cout << "\n";
return 0;
}
解法二:利用唯一性性质
关键观察
重要性质:给定 个横坐标不同的点,恰好可以确定一个最高次不超过
次的多项式。这意味着:
- 如果询问的
不在给定的
个点中,那么存在唯一的多项式
满足所有条件
- 如果询问的
在给定的点中,那么
就是该点对应的纵坐标
算法思路
根据多项式插值的唯一性,给定 个点恰好确定一个最高次不超过
次的多项式。因此:
- 如果询问的
在给定的点中,直接返回对应的
- 如果询问的
不在给定的点中,可以返回任意值(比如 0)
时间复杂度
- 预处理:
建立坐标到纵坐标的映射
- 单次查询:
- 总时间复杂度:
代码实现
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
map<int, int> point_map;
for (int i = 0; i < n; i++) {
int x, y;
cin >> x >> y;
point_map[x] = y;
}
int q;
cin >> q;
while (q--) {
int x0;
cin >> x0;
if (point_map.count(x0)) {
cout << point_map[x0] << '\n';
} else {
cout << 0 << '\n'; // 任意值都可以
}
}
return 0;
}
为什么解法二正确?
根据多项式插值的唯一性, 个横坐标不同的点恰好可以确定一个最高次不超过
次的多项式。因此:
- 如果询问的
在给定的点中,直接返回对应的纵坐标
- 如果询问的
不在给定的点中,可以返回任意值,因为我们可以构造一个满足条件的多项式
补充说明
关于 unordered_map 的陷阱:如果使用 unordered_map 而没有对桶大小或哈希函数进行特殊处理,在本题中可能会遇到特意构造的哈希碰撞数据,导致 unordered_map 的时间复杂度退化到 ,无法通过本题。因此推荐使用
map 来保证稳定的 查询时间。