牛客练习赛71 E 神奇的迷宫
神奇的迷宫
https://ac.nowcoder.com/acm/contest/7745/E
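题目可以转化为：设 ans[i] 为所有距离为 i 的无序点对 (u, v) 的概率乘积 $p_u p_v$ 之和（i = 0 时为 $\sum_u p_u^2$），求出每个 ans[i] 后，答案即为 $w_0 \cdot ans[0] + 2\sum_{i \ge 1} w_i \cdot ans[i]$。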
考虑使用点分治分解整棵树：在每个分治中心处，把各子树按最大深度从小到大排序后依次合并（启发式合并，先合并深度小的子树），每次合并用 NTT 计算两个深度分布多项式的卷积来加速。
总复杂度为 $O(n \log^2 n)$。
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
// 64-bit integer alias used for modular products.
using ll = long long;
// Capacity of the global per-node / per-depth arrays: 2^18 plus slack.
constexpr int N = 262144 + 100;
// Prime modulus (119 * 2^23 + 1) for the answer.
constexpr int MOD = 998244353;
namespace NTT {
// Power of two; the argument is parenthesized so expression arguments expand safely.
#define pw(n) (1<<(n))
// N: maximum transform length (2^18); P: NTT prime 998244353 = 119 * 2^23 + 1;
// g: primitive root of P. (Alternative prime: P = 1004535809.)
const int N = 262144, P = 998244353, g = 3;
// n, m : degrees of the two operands of the current solve() call.
// bit  : transform length (a power of two); bitnum = log2(bit).
// a, b : transform buffers, reset to all-zero by solve() before it returns.
// rev  : bit-reversal permutation for the current transform length.
int n, m, bit, bitnum = 0, a[N + 5], b[N + 5], rev[N + 5];

// Fills rev[0 .. 2^l - 1] with the l-bit reversal of each index.
// For l == 0 only rev[0] is written; the ternary keeps that case from
// evaluating the undefined shift `x << -1`.
void getrev(int l) {
	for (int i = 0; i < pw(l); i++) {
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (1 << (l - 1)) : 0);
	}
}

// Returns a^b mod P by binary exponentiation (b >= 0).
int fastpow(int a, int b) {
	int ans = 1;
	for (; b; b >>= 1, a = 1LL * a * a % P) {
		if (b & 1) ans = 1LL * ans * a % P;
	}
	return ans;
}

// In-place iterative NTT of s[0 .. bit - 1], using the current `bit` and `rev`.
// op == 1: forward transform. op == -1: inverse transform, obtained by
// reversing s[1 .. bit - 1] and dividing every entry by `bit`.
void NTT(int *s, int op) {
	for (int i = 0; i < bit; i++) if (i < rev[i]) std::swap(s[i], s[rev[i]]);
	for (int i = 1; i < bit; i <<= 1) {
		int w = fastpow(g, (P - 1) / (i << 1));  // root of unity of order 2i
		for (int p = i << 1, j = 0; j < bit; j += p) {
			int wk = 1;
			for (int k = j; k < i + j; k++, wk = 1LL * wk * w % P) {
				int x = s[k], y = 1LL * s[k + i] * wk % P;
				s[k] = (x + y) % P;
				s[k + i] = (x - y + P) % P;
			}
		}
	}
	if (op == -1) {
		std::reverse(s + 1, s + bit);
		int inv = fastpow(bit, P - 2);
		for (int i = 0; i < bit; i++) s[i] = 1LL * s[i] * inv % P;
	}
}

// Multiplies aa (coefficients aa[0..nn]) by bb (coefficients bb[0..mm]) modulo P.
// Writes `bit` coefficients to c -- c[0..nn+mm] is the product, the remaining
// entries are 0 -- and returns bit, the smallest power of two greater than nn + mm.
// Requires nn + mm < N, and c must have room for `bit` entries. The inputs are
// copied into internal buffers before c is written, so aa/bb are not modified.
int solve(int *aa, int nn, int *bb, int mm, int *c) {
	n = nn; m = mm;
	bit = bitnum = 0;
	for (int i = 0; i <= n; i++) a[i] = aa[i];
	for (int i = 0; i <= m; i++) b[i] = bb[i];
	m += n;
	for (bit = 1; bit <= m; bit <<= 1) bitnum++;
	getrev(bitnum);
	NTT(a, 1);
	NTT(b, 1);
	for (int i = 0; i < bit; i++) a[i] = 1LL * a[i] * b[i] % P;
	NTT(a, -1);
	for (int i = 0; i < bit; i++) c[i] = a[i];
	// Leave the scratch buffers zeroed for the next call.
	for (int i = 0; i < bit; i++) a[i] = b[i] = 0;
	return bit;
}
}
// Binary exponentiation: returns x^n mod MOD (1 when n <= 0).
// Used with n = MOD - 2 to obtain a modular inverse via Fermat's little theorem.
ll qpow(ll x, ll n) {
	ll acc = 1;
	ll base = x;
	for (ll e = n; e > 0; e >>= 1) {
		if (e & 1) {
			acc = acc * base % MOD;
		}
		base = base * base % MOD;
	}
	return acc;
}
// n : number of tree nodes.
// MX: size of the component currently searched by getroot.
// R : best centroid candidate found by getroot (reset to 0 before each search;
//     ms[0] is set to a large sentinel in main so any real node replaces it).
int n, MX, R;
// sa : node weights, normalized in main to sa[i] / sum (mod MOD).
// ww : ww[i] is the weight applied to ans[i] (pairs at distance i) in main.
// siz: subtree sizes filled by getroot / dfs.
// ms : for a node, the size of the largest part left when it is removed
//      from the current component (computed by getroot).
int sa[N], ww[N], siz[N], ms[N];
// vis[u] is set once divide() has processed u as a centroid.
bool vis[N];
// Adjacency lists of the (undirected) tree.
vector<int> V[N];
// Centroid search over the unvisited component containing u (entered from fa).
// Fills siz[] for the component rooted at u and ms[x] = size of the largest
// part remaining if x were removed; keeps in R the first node (in post-order)
// with the smallest ms. Callers set MX to the component size and R to 0.
void getroot(int u, int fa) {
	int heaviest = 0;
	siz[u] = 1;
	for (int v : V[u]) {
		if (v != fa && !vis[v]) {
			getroot(v, u);
			siz[u] += siz[v];
			if (siz[v] > heaviest) heaviest = siz[v];
		}
	}
	// The part "above" u (toward fa) has MX - siz[u] nodes.
	ms[u] = max(heaviest, MX - siz[u]);
	if (ms[u] < ms[R]) R = u;
}
// Scratch state used by divide():
//   dep[x] - distance from the current centroid to x (set in dfs).
//   res[d] - sum of sa over the child subtree being merged, at depth d (dfs1).
//   now[d] - sum of sa over the centroid plus already-merged subtrees, at depth d.
//   ss[k]  - output buffer of NTT::solve(now, ..., res, ...).
//   md[x]  - maximum dep inside x's subtree (set in dfs).
//   ans[d] - sum of sa[u] * sa[v] over node pairs at distance d
//            (d == 0: u == v, filled in main; d >= 1: each unordered pair once).
int dep[N], res[N], now[N], ss[N], md[N], ans[N];
// In-place modular addition: a <- a + b (mod MOD). The single conditional
// subtraction fully reduces the result only when a and b are both in [0, MOD).
void upd(int &a, int b) {
	const int total = a + b;
	a = (total >= MOD) ? total - MOD : total;
}
// Traverses the part of the current component below u (entered from fa),
// skipping visited centroids. dep[u] must already be set by the caller.
// Fills dep[] for descendants, siz[u] (subtree size) and md[u] (deepest dep
// reached in u's subtree).
void dfs(int u, int fa) {
	int deepest = dep[u];
	int count = 1;
	for (int v : V[u]) {
		if (v == fa || vis[v]) continue;
		dep[v] = dep[u] + 1;
		dfs(v, u);
		count += siz[v];
		if (md[v] > deepest) deepest = md[v];
	}
	siz[u] = count;
	md[u] = deepest;
}
void dfs1(int u, int fa) {
upd(res[dep[u]], sa[u]);
for (int v : V[u]) {
if (vis[v] || v == fa) continue;
dfs1(v, u);
}
}
// id[1..tp]: child roots of the centroid currently handled by divide(),
// sorted there by md. Global scratch; it is fully consumed before divide recurses.
int id[N], tp;
// Sort predicate: orders nodes by increasing md (maximum subtree depth).
bool cmp(int a, int b) {
	const int depthOfA = md[a];
	const int depthOfB = md[b];
	return depthOfB > depthOfA;
}
// Handles the component whose centroid is u. For every pair of nodes whose tree
// path passes through u, it adds sa[x] * sa[y] to ans[distance]; each unordered
// pair is counted once. It then recurses into the remaining components.
//
// Child subtrees are merged in increasing order of md[v]. The running depth
// polynomial `now` therefore never has a higher degree than the subtree being
// merged, and each NTT has length proportional to md[v].
void divide(int u) {
	vis[u] = true;
	tp = 0;
	int mm = 0;  // current degree of the depth polynomial `now`
	for (int v : V[u]) {
		if (vis[v]) continue;
		dep[v] = 1;
		dfs(v, u);  // fills dep, siz and md for v's subtree
		id[++tp] = v;
	}
	sort(id + 1, id + tp + 1, cmp);
	now[0] = sa[u];  // the centroid itself sits at depth 0
	for (int k = 1; k <= tp; k++) {
		const int v = id[k];
		dfs1(v, u);  // res[d] = sum of sa over v's subtree at depth d
		NTT::solve(now, mm, res, md[v], ss);
		// Only coefficients 0..mm+md[v] of the product are meaningful; reading
		// exactly that range avoids depending on ss being zero past the output.
		// Index 0 is skipped: res[0] == 0 because v's subtree starts at depth 1.
		const int deg = mm + md[v];
		for (int d = 1; d <= deg; d++) upd(ans[d], ss[d]);
		// Merge v's depth profile into `now` and reset res for the next child.
		for (int d = 0; d <= md[v]; d++) {
			upd(now[d], res[d]);
			res[d] = 0;
		}
		mm = max(mm, md[v]);
	}
	for (int d = 0; d <= mm; d++) now[d] = 0;
	// Each child's subtree (size siz[v], computed by dfs above) is now a separate
	// component; locate its centroid and recurse.
	for (int v : V[u]) {
		if (vis[v]) continue;
		R = 0;
		MX = siz[v];
		getroot(v, u);
		divide(R);
	}
}
int main() {
	//freopen("0.txt", "r", stdin);
	int a, b;
	scanf("%d", &n);
	// Total weight modulo MOD. Reduce with % instead of a single conditional
	// subtraction: that only keeps the sum below MOD while every sa[i] < MOD.
	// An unreduced sum that grows past ~3e9 would overflow `x * x` inside qpow.
	ll sum = 0;
	for (int i = 1; i <= n; i++) {
		scanf("%d", sa + i);
		sum = (sum + sa[i]) % MOD;
	}
	// Normalize weights into probabilities mod MOD: sa[i] <- sa[i] / sum.
	ll RR = qpow(sum, MOD - 2);
	for (int i = 1; i <= n; i++) {
		sa[i] = RR * sa[i] % MOD;
		// Distance-0 pairs are (i, i), contributing sa[i]^2.
		ans[0] = (ans[0] + 1LL * sa[i] * sa[i]) % MOD;
	}
	for (int i = 0; i < n; i++) scanf("%d", ww + i);
	for (int i = 1; i < n; i++) {
		scanf("%d%d", &a, &b);
		V[a].push_back(b);
		V[b].push_back(a);
	}
	// Start the centroid decomposition at the centroid of the whole tree.
	// ms[0] is a sentinel, so any real node replaces the initial R = 0.
	ms[0] = 1e9;
	MX = n;
	getroot(1, 0);
	divide(R);
	// divide() adds each unordered pair at distance i >= 1 once; the factor 2
	// counts both orders (u, v) and (v, u).
	ll r = 1LL * ans[0] * ww[0] % MOD;
	for (int i = 1; i < n; i++) r = (r + 1LL * ans[i] * ww[i] * 2) % MOD;
	printf("%lld\n", r);
	return 0;
}
查看9道真题和解析