树链剖分(重链剖分)+ 线段树——解决树上问题的一大法宝
题目链接:洛谷 P3384【模板】重链剖分/树链剖分。其实主要是想把我写的代码存下来,以后当板子用,毕竟完整敲一遍线段树真的挺难的。
注意一点:我写完后检查了好几遍,线段树、树链剖分以及各种修改和查询操作都没问题,可样例就是过不去。原来是我建线段树时用的是原始输入顺序,而节点在原始输入顺序中的位置在树上并不连续,所以结果肯定会出错。正确做法是按 dfn 序建树:先建立"dfs 序编号 → 节点编号"的映射(即 data[id[u]] = row_data[u]),再用映射后的数组建立线段树。因为在 dfn 序下,同一条重链上的节点、同一棵子树内的节点,其编号都是连续的一段区间,所以可以用线段树进行区间维护。
#include <bits/stdc++.h>
#define int long long
#define endl '\n'
#define LCHILD(x) ((x << 1) | 1)
#define RCHILD(x) ((x + 1) << 1)
using namespace std;
// Heavy-light (heavy-path) decomposition + lazy segment tree; all arithmetic
// is done modulo `mod`.
//
// Input (stdin):
//   n q root mod
//   n node values (nodes 1..n)
//   n-1 undirected edges "u v"
//   q operations:
//     1 x y z : add z to every node on the path x..y
//     2 x y   : print the sum of the nodes on the path x..y
//     3 x z   : add z to every node in the subtree of x
//     4 x     : print the sum of the nodes in the subtree of x
// Output (stdout): one line per type-2 / type-4 operation, in [0, mod).
//
// The block spells out std:: and long long itself, so it does not rely on the
// file-level `#define int long long`, `using namespace std;` or child macros.
void solve()
{
    using i64 = long long;

    i64 n, q, root, mod;
    std::cin >> n >> q >> root >> mod;

    // Reduce any input value into [0, mod).
    auto norm = [&](i64 x) -> i64 {
        x %= mod;
        return x < 0 ? x + mod : x;
    };

    std::vector<i64> row_data(n + 1), data(n + 1);
    for (i64 i = 1; i <= n; ++i)
        std::cin >> row_data[i];

    std::vector<std::vector<i64>> adj(n + 1);
    for (i64 i = 1; i < n; ++i)
    {
        i64 u, v;
        std::cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // ---------------- lazy segment tree over positions 1..n ----------------
    // Heap layout with root 1 and children 2p / 2p+1, so every index is < 4n.
    // (The 2p+1 / 2p+2 child formulas assume root 0; combined with root 1 they
    // overflow a 4n-sized array, e.g. n = 20 reaches index 88.)
    struct node
    {
        i64 val = 0;  // segment sum, kept in [0, mod)
        i64 ctag = 0; // pending add not yet pushed to the children, in [0, mod)
    };
    std::vector<node> tree(4 * n);

    auto pushup = [&](i64 p) {
        tree[p].val = (tree[2 * p].val + tree[2 * p + 1].val) % mod;
    };
    auto build = [&](auto &&self, i64 p, i64 left, i64 right) -> void {
        if (left == right)
        {
            tree[p].val = norm(data[left]);
            return;
        }
        const i64 mid = (left + right) / 2;
        self(self, 2 * p, left, mid);
        self(self, 2 * p + 1, mid + 1, right);
        pushup(p);
    };
    // Add `tag` to every position of node p's segment [left, right].
    // The tag is reduced mod `mod` so repeated adds can never overflow; the
    // product below assumes mod * n fits in a long long.
    auto mark = [&](i64 p, i64 tag, i64 left, i64 right) {
        tag = norm(tag);
        tree[p].ctag = (tree[p].ctag + tag) % mod;
        tree[p].val = (tree[p].val + tag * (right - left + 1)) % mod;
    };
    // Hand node p's pending tag down to its two children.
    auto push_down = [&](i64 p, i64 left, i64 right) {
        if (tree[p].ctag == 0)
            return;
        const i64 mid = (left + right) / 2;
        mark(2 * p, tree[p].ctag, left, mid);
        mark(2 * p + 1, tree[p].ctag, mid + 1, right);
        tree[p].ctag = 0;
    };
    // Sum of positions [ql, qr] intersected with node p's segment [left, right].
    auto query = [&](auto &&self, i64 p, i64 left, i64 right, i64 ql, i64 qr) -> i64 {
        if (qr < left || right < ql)
            return 0;
        if (ql <= left && right <= qr)
            return tree[p].val;
        push_down(p, left, right);
        const i64 mid = (left + right) / 2;
        return (self(self, 2 * p, left, mid, ql, qr) +
                self(self, 2 * p + 1, mid + 1, right, ql, qr)) % mod;
    };
    // Add k to positions [ul, ur] intersected with node p's segment [left, right].
    auto update = [&](auto &&self, i64 p, i64 left, i64 right, i64 ul, i64 ur, i64 k) -> void {
        if (ur < left || right < ul)
            return;
        if (ul <= left && right <= ur)
        {
            mark(p, k, left, right);
            return;
        }
        push_down(p, left, right);
        const i64 mid = (left + right) / 2;
        self(self, 2 * p, left, mid, ul, ur, k);
        self(self, 2 * p + 1, mid + 1, right, ul, ur, k);
        pushup(p);
    };

    // ---------------- heavy-light decomposition ----------------
    std::vector<i64> sz(n + 1), fa(n + 1), dep(n + 1), top(n + 1), son(n + 1), id(n + 1);
    // First pass: parent, depth, subtree size and heavy child (son == 0: none).
    // Recursive, so recursion depth equals the tree height; very deep trees
    // (long chains) may need a larger stack.
    auto dfs = [&](auto &&self, i64 u) -> void {
        sz[u] = 1;
        dep[u] = dep[fa[u]] + 1;
        for (i64 v : adj[u])
        {
            if (v == fa[u])
                continue;
            fa[v] = u;
            self(self, v);
            sz[u] += sz[v];
            if (sz[v] > sz[son[u]])
                son[u] = v;
        }
    };
    // Second pass: assign dfs ids visiting the heavy child first, so every heavy
    // chain and every subtree occupies a contiguous id range; top[u] is the
    // head of u's chain.
    i64 cnt = 0;
    auto dfs1 = [&](auto &&self, i64 u, i64 t) -> void {
        top[u] = t;
        id[u] = ++cnt;
        if (son[u])
            self(self, son[u], t);
        for (i64 v : adj[u])
        {
            if (v == fa[u] || v == son[u])
                continue;
            self(self, v, v);
        }
    };

    // Split the path x..y into contiguous id ranges and call f(l, r) on each.
    auto for_each_path_range = [&](i64 x, i64 y, auto &&f) {
        while (top[x] != top[y])
        {
            // Always lift the endpoint whose chain head is deeper.
            if (dep[top[x]] < dep[top[y]])
                std::swap(x, y);
            f(id[top[x]], id[x]);
            x = fa[top[x]];
        }
        if (dep[x] < dep[y])
            std::swap(x, y);
        f(id[y], id[x]);
    };
    auto sum_x_y = [&](i64 x, i64 y) -> i64 {
        i64 res = 0;
        for_each_path_range(x, y, [&](i64 l, i64 r) {
            res = (res + query(query, 1, 1, n, l, r)) % mod;
        });
        return res;
    };
    auto update_x_y = [&](i64 x, i64 y, i64 z) {
        for_each_path_range(x, y, [&](i64 l, i64 r) {
            update(update, 1, 1, n, l, r, z);
        });
    };
    // The subtree of u occupies ids [id[u], id[u] + sz[u] - 1].
    auto subtree_sum = [&](i64 u) -> i64 {
        return query(query, 1, 1, n, id[u], id[u] + sz[u] - 1);
    };
    auto upsubtree = [&](i64 u, i64 k) {
        update(update, 1, 1, n, id[u], id[u] + sz[u] - 1, k);
    };

    dfs(dfs, root);
    dfs1(dfs1, root, root);
    // The segment tree is indexed by dfs id, not by node number.
    for (i64 u = 1; u <= n; ++u)
        data[id[u]] = row_data[u];
    build(build, 1, 1, n);

    for (i64 i = 0; i < q; ++i)
    {
        i64 op;
        std::cin >> op;
        if (op == 1)
        {
            i64 x, y, z;
            std::cin >> x >> y >> z;
            update_x_y(x, y, z);
        }
        else if (op == 2)
        {
            i64 x, y;
            std::cin >> x >> y;
            std::cout << sum_x_y(x, y) << '\n';
        }
        else if (op == 3)
        {
            i64 x, z;
            std::cin >> x >> z;
            upsubtree(x, z);
        }
        else if (op == 4)
        {
            i64 x;
            std::cin >> x;
            std::cout << subtree_sum(x) << '\n';
        }
    }
}
// Entry point: switch iostreams to fast, unsynchronized mode and run solve().
// Declared `signed main` because `int` is macro-expanded to `long long` at the top of the file.
signed main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    // Exactly one test case; to handle several, read the case count into `cases` first.
    for (int cases = 1; cases > 0; --cases)
        solve();

    return 0;
}