牛客练习赛84 E 牛客推荐系统开发之标签重复度
牛客推荐系统开发之标签重复度
https://ac.nowcoder.com/acm/contest/11174/E
考虑使用点分树维护,每次维护经过点分中心的权值。我们考虑计算每个点到点分中心的最小值和最大值,然后合并两个到点分中心的路径。将这样的最大最小值对按最小值排序后,计算贡献。
考虑到后计算到的点最小值一定更大,则只有两种情况。
1.最大值比之前的大,则贡献是之前的最小值乘现在的最大值。
2.最大值比之前的小,则贡献是之前的最小值乘之前的最大值。
因此我们只要按照最大值为下标插入权值即可。
因为维护的是过点分中心的路径,要删除掉同一子树内点的贡献。
复杂度。
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
#define lson rt * 2
#define rson rt * 2 + 1
#define MP make_pair
typedef long long ll;
void read(int &x) {
x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
}
const int N = 1e5 + 100;
const int MOD = 998244353;
int add(int a, int b) { return a + b >= MOD ? a + b - MOD : a + b; }
int mul(int a, int b) { return 1LL * a * b % MOD; }
void upd(int &a, int b) {
a += b;
if (a >= MOD) a -= MOD;
}
int tp;
struct node {
int tree[N];
void insert(int id, int x) {
for (; id <= tp; id += id & -id) upd(tree[id], x);
}
int query(int id) {
int sum = 0;
for (; id > 0; id -= id & -id) upd(sum, tree[id]);
return sum;
}
int query(int l, int r) {
return add(query(r), MOD - query(l - 1));
}
}t1, t2;
int R, Mn, SZ, ans;
int siz[N], mn[N];
bool vis[N];
vector<int> V[N];
void getroot(int u, int fa) {
siz[u] = 1; mn[u] = 0;
for (int v : V[u]) {
if (vis[v] || v == fa) continue;
getroot(v, u);
siz[u] += siz[v];
mn[u] = max(mn[u], siz[v]);
}
mn[u] = max(mn[u], SZ - siz[u]);
if (mn[u] < Mn) R = u, Mn = mn[u];
}
int n, tot;
int sa[N], has[N];
pair<int, int> res[N];
void dfs(int u, int fa, int mi, int ma) {
mi = min(mi, sa[u]); ma = max(ma, sa[u]);
res[++tot] = MP(mi, ma);
siz[u] = 1;
for (int v : V[u]) {
if (vis[v] || v == fa) continue;
dfs(v, u, mi, ma);
siz[u] += siz[v];
}
}
void cal(int op) {
int sum = 0;
sort(res + 1, res + tot + 1);
for (int i = 1; i <= tot; i++) {
int mi = res[i].first, ma = res[i].second;
if (ma > 1) upd(sum, mul(has[ma], t1.query(1, ma - 1)));
upd(sum, t2.query(ma, n));
t1.insert(ma, has[mi]);
t2.insert(ma, mul(has[mi], has[ma]));
}
for (int i = 1; i <= tot; i++) {
int mi = res[i].first, ma = res[i].second;
t1.insert(ma, MOD - has[mi]);
t2.insert(ma, MOD - mul(has[mi], has[ma]));
}
if (op > 0) upd(ans, sum);
else upd(ans, MOD - sum);
}
void solve(int u) {
tot = 0;
upd(ans, mul(has[sa[u]], has[sa[u]]));
res[++tot] = MP(sa[u], sa[u]);
for (int v : V[u]) if (!vis[v]) dfs(v, u, sa[u], sa[u]);
cal(1);
for (int v : V[u]) {
if (vis[v]) continue;
tot = 0;
dfs(v, u, sa[u], sa[u]);
cal(-1);
}
vis[u] = true;
for (int v : V[u]) {
if (vis[v]) continue;
SZ = siz[v]; Mn = 1e9; getroot(v, u);
solve(R);
}
}
int main() {
//freopen("0.txt", "r", stdin);
read(n);
for (int i = 1; i <= n; i++) read(sa[i]), has[i] = sa[i];
for (int i = 1, a, b; i < n; i++) {
read(a); read(b);
V[a].push_back(b);
V[b].push_back(a);
}
sort(has + 1, has + n + 1);
tp = unique(has + 1, has + n + 1) - has - 1;
for (int i = 1; i <= n; i++) sa[i] = lower_bound(has + 1, has + tp + 1, sa[i]) - has;
SZ = n; Mn = 1e9; getroot(1, 0);
solve(R);
printf("%d\n", ans);
return 0;
}</int,>
