题解 | #柠檬树#
柠檬树
https://ac.nowcoder.com/acm/problem/212478
基本思路:
按题解的,分两部来做。
第一步:按顺序对所有点到根的路径染上对应的颜色,这样区间查询就变成了查询 颜色号 >= 左端点号 的点的数量了,这里可以用 LCT (这个是静态问题,也可以用 轻重链剖分 来写)。
第二步:减去区间所有点的 LCA 到根节点的距离(可以用欧拉序加倍增来求,也可以用倍增加线段树思想来求)。
代码:(欧拉序+倍增+LCT) (代码若显示无缩进无高亮请去 这里)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define lowbit(x) ((x) & -(x))
int n, m;
struct LinkEdg {
int head[400009], nex[400009], to[400009], tot;
void add(int x, int y) {
nex[++ tot] = head[x], head[x] = tot, to[tot] = y;
}
} tr;
struct GetLca {
int oula[200009 * 2], tot; /// 欧拉序数组
int stMin[200009][18], stMax[200009][18]; /// 欧拉序数组区间中最大和最小值
int st[200009 * 2][19], dis[200009]; /// 欧拉序数组中区间最小深度位置和每个节点的深度
int lg[200009 * 2]; /// 预处理每个数的 log 值
GetLca() {
for(int i = 1, cnt = -1; i <= 400009; i ++) {
if(i == lowbit(i)) cnt ++;
lg[i] = cnt;
}
}
void build() {
for(int i = 1; i <= tot; i ++) {
st[i][0] = dis[oula[i]];
if(!stMin[oula[i]][0]) stMin[oula[i]][0] = stMax[oula[i]][0] = i;
else stMax[oula[i]][0] = i;
}
for(int i = 1; i <= lg[tot]; i ++)
for(int j = 1; j <= tot - (1 << i) + 1; j ++)
st[j][i] = min(st[j][i - 1], st[j + (1 << i - 1)][i - 1]);
for(int i = 1; i <= lg[n]; i ++)
for(int j = 1; j <= n - (1 << i) + 1; j ++) {
stMin[j][i] = min(stMin[j][i - 1], stMin[j + (1 << i - 1)][i - 1]);
stMax[j][i] = max(stMax[j][i - 1], stMax[j + (1 << i - 1)][i - 1]);
}
}
int ask(int l, int r) {
int len = lg[r - l + 1];
int lx = min(stMin[l][len], stMin[r - (1 << len) + 1][len]),
rx = max(stMax[l][len], stMax[r - (1 << len) + 1][len]);
len = lg[rx - lx + 1];
return min(st[lx][len], st[rx - (1 << len) + 1][len]) - 1;
}
} st;
struct Lct {
int son[200009][2], fa[200009], col[200009], sum[200009]; /// splay中的数据
int tr[200009]; /// 树状数组
#define isRoot(x) (son[fa[x]][0] != x && son[fa[x]][1] != x)
inline void pushUp(int x) {
sum[x] = sum[son[x][0]] + sum[son[x][1]] + 1;
}
void rotata(int x) {
int y = fa[x], z = fa[y], c = son[y][0] == x;
if(!isRoot(y)) son[z][son[z][1] == y] = x; fa[x] = z;
son[y][!c] = son[x][c], fa[son[x][c]] = y;
son[x][c] = y, fa[y] = x;
pushUp(y), pushUp(x);
}
inline void pushDown(int x) {
col[son[x][0]] = col[son[x][1]] = col[x];
}
void upData(int x) {
if(!isRoot(x)) upData(fa[x]);
pushDown(x);
}
void splay(int x) {
upData(x);
while(!isRoot(x)) {
int y = fa[x], z = fa[y];
if(!isRoot(y)) son[z][0] == y ^ son[y][0] == x ? rotata(x) : rotata(y);
rotata(x);
}
}
inline void add(int x, int w) {
if(x == 0) return ;
for(; x <= n; x += lowbit(x))
this->tr[x] += w;
}
void access(int x) {
int p = 0, pre = 0, y = x;
for(; x; p = x, x = fa[x]) {
splay(x);
son[x][1] = p;
pushUp(x);
add(col[x], -sum[x] + pre);
pre = sum[x];
}
add(y, sum[p]);
col[p] = y;
}
int ask(int l) {
int ans = 0;
for(int i = n; i | l; i -= lowbit(i), l -= lowbit(l))
ans += this->tr[i] - this->tr[l];
return ans - 1;
}
} lct;
void dfs(int x, int fa = 0) { /// 创建欧拉序和深度数组 初始化lct
lct.fa[x] = fa, lct.sum[x] = 1;
st.oula[++ st.tot] = x; st.dis[x] = st.dis[fa] + 1;
for(int i = tr.head[x], p; i; i = tr.nex[i]) {
p = tr.to[i]; if(p == fa) continue;
dfs(p, x);
st.oula[++ st.tot] = x;
}
}
vector<pair<int, int> > ve[200009];
int ans[200009];
int main() {
cin >> n >> m;
int x, y;
for(int i = 1; i < n; i ++) {
scanf("%d%d", &x, &y);
tr.add(x, y), tr.add(y, x);
}
dfs(1);
st.build();
for(int i = 1; i <= m; i ++) {
scanf("%d%d", &x, &y);
ve[y].push_back(make_pair(x, i));
}
for(int i = 1; i <= n; i ++) {
lct.access(i);
for(pair<int, int> &x : ve[i])
ans[x.second] = lct.ask(x.first - 1) - st.ask(x.first, i);
}
for(int i = 1; i <= m; i ++) printf("%d\n", ans[i]);
return 0;
}