树上点分治
点分治
树上点分治 其实就是把序列的分治方法移到了树上操作。序列上每个点的后继只有一个,树上可以有很多,我们找一个分支最多的出去,在不考虑常数的的情况下,这种分治方法是非常划算的。
静态分治
关于点分治的思想我不再赘述,很多博客已经讲得很清楚了。
我们实现上面代码,主要为下面几个函数。
寻找重心
在树上找重心的操作其实相当于在序列上找中点。
int mid = (l + r) >> 1; |
而我们在树上是这样操作的。
void getroot(int u, int f) { size[u] = 1; mxs[u] = 0; for (int it = head[u]; it != -1; it = edge[it].nxt) { int v = edge[it].v; if (vis[v] || v == f) continue; getroot(v, u); size[u] += size[v]; mxs[u] = max(mxs[u], size[v]); } mxs[u] = max(mxs[u], tsize - size[u]); if (mxs[u] < mxs[root]) root = u; } |
(其实就是树上跑个dp)
分治
得到重心之后,我们分治重心的每个分支,最后相当于pushup上来。
void solve(int u, int f) { vis[u] = 1; ans += get(u, 0); //治 for (int it = head[u]; it != -1; it = edge[it].nxt) //分 { int v = edge[it].v; if (v == f || vis[v]) continue; ans -= get(v, edge[it].d); //治 root = 0; tsize = size[v]; getroot(v, 0); solve(root, u); } }
这一部分就是我们随意发挥的地方了,我们可以在这里进行一些复杂度约为O(n)计算,达到和序列分治类似甚至更优的效果。
POJ1741
以此题为例,我们来搞一波树上点分治。
在大部分博客中,分治部分代码对小子树大小的计算方法我觉得有问题,于是后来自己在原来我的AC代码上做了一番修改,貌似还是有几十毫秒的改进的。
#include<cstdio> #include<algorithm> #include<iostream> using namespace std; typedef long long ll; const int N = 1e4 + 5; struct node { int v, nxt, d; }edge[2 * N]; int head[N], tot; int size[N], root, mxs[N], vis[N], tsize; int cnt[N], k; void init(int n) { for (int i = 1; i <= n; i++) head[i] = -1, vis[i] = 0; root = 0; tot = 0; tsize = n; mxs[0] = 0x3f3f3f3f; } void addedge(int u, int v, int d) { edge[++tot].v = v; edge[tot].nxt = head[u]; edge[tot].d = d; head[u] = tot; } void getroot(int u, int f) { size[u] = 1; mxs[u] = 0; for (int it = head[u]; it != -1; it = edge[it].nxt) { int v = edge[it].v; if (vis[v] || v == f) continue; getroot(v, u); size[u] += size[v]; mxs[u] = max(mxs[u], size[v]); } mxs[u] = max(mxs[u], tsize - size[u]); if (mxs[u] < mxs[root]) root = u; } int cur = 0; ll ans = 0; void calc(int u, int f, int d) { cnt[++cur] = d; size[u] = 1; for (int it = head[u]; it != -1; it = edge[it].nxt) { int v = edge[it].v; if (v == f || vis[v]) continue; calc(v, u, d + edge[it].d); size[u] += size[v]; } } int get(int u, int d) { cur = 0; calc(u, 0, d); sort(cnt + 1, cnt + cur + 1); int head = 1, tail = cur; int ret = 0; while (head < tail) { if (cnt[head] + cnt[tail] <= k) { ret += tail - head; head++; } else tail--; } return ret; } void solve(int u, int f) { // dbg(u); vis[u] = 1; ans += get(u, 0); for (int it = head[u]; it != -1; it = edge[it].nxt) { int v = edge[it].v; if (v == f || vis[v]) continue; ans -= get(v, edge[it].d); root = 0; tsize = size[v]; // puts("getroot"); // dbg(v, tsize); getroot(v, 0); solve(root, u); } } template<class T> void read(T& ret) { ret = 0; char c; while ((c = getchar()) > '9' || c < '0'); while (c >= '0' && c <= '9') { ret = ret * 10 + c - '0'; c = getchar(); } } int main() { int n; while (true) { read(n); read(k); if (n == 0 && k == 0) break; init(n); for (int i = 1; i < n; i++) { int u, v, d; read(u); read(v); read(d); addedge(u, v, d); addedge(v, u, d); } ans = 0; root = 0; getroot(1, 0); solve(root, 0); printf("%lld\n", ans); } return 0; } |
动态点分治
有了前面的基础,我们来了解一下动态点分治。
动态在哪里
我们之所以将这种点分治和之前加以区分,是因为这一类题目会要求可以做一些修改,而显而易见我们不能每次暴力修改然后查询的时候用O(nlog(n))时间去查询,这是十分爆炸的。
解决方案
还是考虑分治的时候,如果我们要求修改某一个点,那么我们可以像线段树那样每次取一半,看看要修改的点在哪个区间,然后向上层push。
我们这样只去修改会受到影响的log(n)级别个区间,查询类似。
为了方便快捷地知道我们后面该往哪里走,先预处理出一棵点分树,在这棵树上是我们寻找重心的顺序,后面将会用到这棵树。
点分树
这棵树是由我们原本的树建出来的,但是树形又和原来不同。
点分树的父子关系,是由原树的分治顺序决定的,也就是说,我们寻找重心的时候的顺序,其实应该是点分树上的遍历顺序。既然这样,那我们的数其实也是比较好建立了。
inline void get_tree(int u, int f) { //dbg(u); vis[u] = 1; fa[u] = f; for (int it = head[u]; it != -1; it = edge[it].nxt) { int v = edge[it].v; if (vis[v]) continue; root = 0; tsz = size[v]; get_root(v, u, 1); get_tree(root, u); } } |
使用点分树
我们要这样一棵树有什么用呢?其实理解树上点分治的话,就会觉得这棵树其实是很有必要的。我们要维护这棵树上的一些父子关系。
分治的过程,我们要知道当前这一部分要往哪个点合并,知道方向我们才好处理。为了不用每次都去get_root,我们预先把所有点父子关系存下来。
就像我们在归并的时候向上返回,总要做些处理,如果当前区间已经覆盖了要修改或查询的点,直接返回,否则分半去修改或查询。
使用方法需要随机应变的。
BZOJ3730 震波
#include<bits/stdc++.h> using namespace std; #define ll long long #ifndef ONLINE_JUDGE #define dbg(x...) do{cout << "\033[33;1m" << #x << "->" ; err(x);} while (0) void err(){cout << "\033[39;0m" << endl;} template<template<typename...> class T, typename t, typename... A> void err(T<t> a, A... x){for (auto v: a) cout << v << ' '; err(x...);} template<typename T, typename... A> void err(T a, A... x){cout << a << ' '; err(x...);} #else #define dbg(...) #endif #define inf 1ll << 50 #define lowbit(x) ((x)&-(x)) const int N = 1e5 + 5; struct node
{ int v, nxt;
}edge[2 * N]; int tot, head[N], val[N], vis[N]; int dep[N]; inline void add_edge(int u, int v) {
edge[++tot].v = v;
edge[tot].nxt = head[u];
head[u] = tot;
} int root, tsz, mxs[N], size[N], fa[N], dis[N]; vector<ll> sum[2][N]; inline void init(int n) { for (int i = 1; i <= n; i++)
{
head[i] = -1;
vis[i] = 0;
sum[0][i].clear();
sum[1][i].clear();
}
root = 0;
mxs[0] = 0x3f3f3f3f;
tsz = n;
tot = 0;
} /*
int Dep[N], son[N], Top[N], pa[N];
inline void dfs1(int u, int f, int d)
{
Dep[u] = d;
son[u] = -1;
size[u] = 1;
pa[u] = f;
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (v == f)
continue;
dfs1(v, u, d + 1);
size[u] += size[v];
if (son[u] == -1 || size[v] > size[u])
son[u] = v;
}
}
inline void dfs2(int u, int f, int t)
{
Top[u] = t;
if (son[u] != -1)
dfs2(son[u], u, t);
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (v == f || v == son[u])
continue;
dfs2(v, u, v);
}
}
*/ int st[2 * N], cur, Dep[N]; int first[N]; inline void dfs(int u, int f) {
st[++cur] = u;
first[u] = cur;
Dep[u] = Dep[f] + 1; for (int it = head[u]; it != -1; it = edge[it].nxt)
{ int v = edge[it].v; if (v == f) continue;
dfs(v, u);
st[++cur] = u;
}
} int dp[N * 2][20]; void rmq_pre() { for (int i = 1; i <= cur; i++)
dp[i][0] = st[i]; for (int j = 1; (1 << j) <= cur; j++) for (int i = 1; i + (1 << j) <= cur + 1; i++)
{ int d1 = Dep[dp[i][j - 1]], d2 = Dep[dp[i + (1 << (j - 1))][j - 1]]; if (d1 < d2)
dp[i][j] = dp[i][j - 1]; else dp[i][j] = dp[i + (1 << (j - 1))][j - 1];
}
} inline int rmq(int l, int r) { int k = 31 - __builtin_clz(r - l + 1); int d1 = dp[l][k], d2 = dp[r - (1 << k) + 1][k]; if (d1 < d2) return dp[l][k]; return dp[r - (1 << k) + 1][k];
} inline int lca(int u, int v) { /*
while (Top[u] != Top[v])
{
if (Dep[Top[u]] < Dep[Top[v]])
swap(u, v);
u = pa[Top[u]];
}
if (Dep[u] < Dep[v])
return u;
else
return v;
*/ if (first[u] > first[v])
swap(u, v); return rmq(first[u], first[v]);
} inline int Dis(int u, int v) { int ca = lca(u, v); return Dep[u] + Dep[v] - 2 * Dep[ca];
} inline void get_root(int u, int f, int d) { // dbg(u, d); size[u] = 1;
mxs[u] = 0;
dep[u] = d; for (int it = head[u]; it != -1; it = edge[it].nxt)
{ int v = edge[it].v; if (vis[v] || v == f) continue;
get_root(v, u, d + 1);
size[u] += size[v];
mxs[u] = max(mxs[u], size[v]);
}
mxs[u] = max(mxs[u], tsz - size[u]); if (mxs[u] < mxs[root])
root = u;
} /*
inline int lowbit(int x)
{
return x & (-x);
}
*/ inline void update(int bla, int id, int x, int val) { while (x < sum[bla][id].size())
{
sum[bla][id][x] += val;
x += lowbit(x);
}
} inline ll query(int bla, int id, int x) {
ll ans = 0; if (x >= sum[bla][id].size())
x = sum[bla][id].size() - 1; // dbg(id, x); while (x > 0)
{
ans += sum[bla][id][x];
x -= lowbit(x);
} return ans;
} inline pair<int, int> get_dep(int u, int f, int d)
{
pair<int, int> ans = make_pair(d, dep[u]); for (int it = head[u]; it != -1; it = edge[it].nxt)
{ int v = edge[it].v; if (vis[v] || v == f) continue;
pair<int, int> tmp = get_dep(v, u, d + 1);
ans.first = max(ans.first, tmp.first);
ans.second = max(ans.second, tmp.second);
} return ans;
} inline void calc(int u, int f, int top, int d) {
size[u] = 1;
update(0, top, d, val[u]); //dbg(top, d, val[u], u); for (int it = head[u]; it != -1; it = edge[it].nxt)
{ int v = edge[it].v; if (vis[v] || v == f) continue;
calc(v, u, top, d + 1);
size[u] += size[v];
}
} inline void calc2(int u, int f, int top) {
update(1, top, dep[u], val[u]); for (int it = head[u]; it != -1; it = edge[it].nxt)
{ int v = edge[it].v; if (vis[v] || v == f) continue;
calc2(v, u, top);
}
} inline void get_tree(int u, int f) { //dbg(u); vis[u] = 1;
fa[u] = f;
pair<int, int> mxd = get_dep(u, 0, 0);
sum[0][u].resize(mxd.first + 1);
sum[1][u].resize(mxd.second + 1);
calc2(u, f, u); // dbg(u, ans); for (int it = head[u]; it != -1; it = edge[it].nxt)
{ int v = edge[it].v; if (vis[v]) continue;
calc(v, u, u, 1); //dbg(u, v, ans); root = 0;
tsz = size[v];
get_root(v, u, 1);
get_tree(root, u);
} // update(1, u, dep[u], val[u]); } inline ll s_que(int u, int k) {
ll ans = val[u] + query(0, u, k); //dbg(u, k, query(0, u, k), ans); int tmp = u; while (fa[tmp])
{ int d1 = Dis(fa[tmp], u); if (k - d1 >= 0)
ans += query(0, fa[tmp], k - d1) - query(1, tmp, k - d1) + (k - d1 >= 0) * val[fa[tmp]]; //dbg(tmp, fa[tmp], d1, k - d1, k, query(0, fa[tmp], k - d1), query(1, tmp, k - d1), ans); tmp = fa[tmp];
} return ans;
} inline void s_update(int u, int v) {
ll del = v - val[u]; int tmp = u; while (fa[tmp])
{ int dis = Dis(u, fa[tmp]);
update(0, fa[tmp], dis, del);
update(1, tmp, dis, del);
tmp = fa[tmp];
}
val[u] = v;
} template<class T> void read(T& ret) {
ret = 0; char c; while ((c = getchar()) > '9' | c < '0'); while (c >= '0' && c <= '9')
{
ret = ret * 10 + c - '0';
c = getchar();
}
} int main() { int n, m;
read(n), read(m);
init(n); //puts("init over"); for (int i = 1; i <= n; i++)
read(val[i]); for (int i = 1; i < n; i++)
{ int u, v;
read(u);
read(v);
add_edge(u, v);
add_edge(v, u);
} /*
dfs1(1, 0, 1);
dfs2(1, 0, 1);
*/ Dep[0] = 0;
dfs(1, 0);
rmq_pre();
get_root(1, 0, 1);
get_tree(root, 0);
ll ans = 0; while (m--)
{ int type;
read(type); if (type == 0)
{ int u, k;
read(u);
read(k);
u ^= ans;
k ^= ans; printf("%lld\n", ans = s_que(u, k));
} else { int u, v;
read(u);
read(v);
u ^= ans;
v ^= ans;
s_update(u, v);
}
} return 0;
} |