树链剖分 学习笔记
树链剖分是个很强的 OIer
树链剖分是用来解决一系列树上问题的利器。
定义
- 链:树上不拐弯的一条路径
- 重儿子:子树大小最大的儿子
- 轻儿子:其他儿子
- 重链:由重儿子组成的链
结果和过程
结果
这是原树
剖分之后的树
这里,红色的边构成重链,蓝色的点为重儿子,绿色的点为轻儿子。
注意,这里的 7 号与 8 号节点子树大小相同,因此我们选择编号靠前的儿子为重儿子。
实现
一般来说,树链剖分通过两遍 dfs 来实现。
定义
siz[x] //子树x的大小
top[x] //x所在链的顶端
fat[x] //x的父亲
dep[x] //x的深度
son[x] //x的重儿子
第一遍 dfs ,我们先处理出每个节点的父亲,深度,子树大小,重儿子
void dfs1(int u,int f) {
siz[u] = 1;
dep[u] = dep[f] + 1;
son[u] = 0;
fat[u] = f;
for (int i = hd[u];i;i= nxt[i]) {
int v = to[i];
if (v != f) {
dfs1(v,u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) {
son[u] = v;
}
}
}
}
第二遍 dfs,我们处理出每条链链顶的节点,
void dfs2(int u,int f) {
if (son[f] == u) {
top[u] = top[f];
}
else {
top[u] = u;
}
if (son[u]) {
dfs2(son[u],u);
}
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f && v != son[u]) {
dfs2(v,u);
}
}
}
这样树链剖分的基本结构就写完辣!是不是很简单
应用
LCA
显然,我们可以发现,在一条重链上的两个节点的 LCA 显然就是深度更浅的那个节点。
所以我们可以先将两个节点跳到同一条链上,求出深度浅的那个节点即可,于是有了如下代码:
int lca(int a,int b) {
for (int ta = top[a],tb = top[b];ta != tb;) {
if (dep[ta] > dep[tb]) {
ta = top[a = fat[ta]];
}
else {
tb = top[b = fat[tb]];
}
}
return dep[a] < dep[b] ? a : b;
}
时间复杂度
例题:板子题
随手套个板子写掉。
代码:
#include <cstdio>
#include <cctype>
#include <cstring>
const int N = 500010;
const int M = N << 1;
inline int read() {
char v = getchar();int x = 0,f = 1;
while (!isdigit(v)) {if (v == '-')f = -1;v = getchar();}
while (isdigit(v)) {x = x * 10 + v - 48;v = getchar();}
return x * f;
}
int hd[N],edg[M],nxt[M],to[M],n,m,tot,cnt;
inline void add(int u,int v,int w) {
to[++tot] = v;edg[tot] = w;nxt[tot] = hd[u];hd[u] = tot;
}
inline void addedge(int u,int v,int w) {
add(u,v,w);add(v,u,w);
}
int dfn[N],top[N],fat[N],siz[N],son[N],end[N],dep[N],s;
void dfs1(int s,int f) {
son[s] = 0;
dep[s] = dep[f] + 1;
siz[s] = 1;
fat[s] = f;
for (int i = hd[s];i;i = nxt[i]) {
if (to[i] != f) {
dfs1(to[i],s);
siz[s] += siz[to[i]];
if (siz[to[i]] > siz[son[s]]) {
son[s] = to[i];
}
}
}
}
void dfs2(int u,int f) {
dfn[u] = ++cnt;
if (u == son[f]) {
top[u] = top[f];
}
else {
top[u] = u;
}
if (son[u]) {
dfs2(son[u],u);
}
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f && v != son[u]) {
dfs2(v,u);
}
}
end[u] = cnt;
}
int lca(int a,int b) {
for (int ta = top[a],tb = top[b];ta != tb;) {
if (dep[ta] > dep[tb]) {
ta = top[a = fat[ta]];
}
else {
tb = top[b = fat[tb]];
}
}
return dep[a] < dep[b] ? a : b;
}
int main() {
n = read(),m = read(),s = read();
for (int i = 1;i < n;++i) {
int u = read(),v = read(),w = 1;
addedge(u,v,w);
}
dfs1(s,0);
dfs2(s,0);
for (int i = 1;i <= m;++i) {
int u = read(),v = read();
printf("%d\n",lca(u,v));
}
return 0;
}
链上修改,子树修改
这里就必须引出一个新东西了:dfs 序
指的是第几次 dfs 遍历到的这个节点
定义数组dfn[x]
为节点 x 的 dfs 序。
我们可以在 dfs2
中顺手维护一下。
void dfs2(int u,int f) {
dfn[u] = ++cnt;
if (son[f] == u) {
top[u] = top[f];
}
else {
top[u] = u;
}
if (son)
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f) {
dfs2(v,u);
}
}
end[cnt] = u;
}
这里给出每个节点加上 dfn
的图
我们可以发现,在每棵子树,每条链上的dfn都是连续的!
这样看可能不太明朗,我们把它转化成区间来看(这里是子树的,链同理)
所以我们就可以利用数据结构维护一下节点了。
对子树修改就是 change(dfn[x],dfn[x]+siz[x]-1)
链上修改就是 change(dfn[x],dfn[y])
这里的核心思想是:将树上问题转化成序列问题来处理
用线段树维护最大值最小值。
代码:
6#include <cstdio>
#include <cctype>
#include <iostream>
#include <string>
const int N = 100010;
const int M = N << 1;
const int INF = 0x3f3f3f3f;
using namespace std;
inline int read() {
char v = getchar();int x = 0,f = 1;
while (!isdigit(v)) {if (v == '-')f = -1;v = getchar();}
while (isdigit(v)) {x = x * 10 + v - 48;v = getchar();}
return x * f;
}
int to[M],hd[N],nxt[M],tot;
string s;
inline void add(int u,int v) {
to[++tot] = v;nxt[tot] = hd[u];hd[u] = tot;
}
inline void addedge(int u,int v) {
add(u,v);add(v,u);
}
int dfn[N],top[N],fat[N],siz[N],dep[N],son[N],rk[N],cnt,n;
void dfs1(int u,int f) {
fat[u] = f;
dep[u] = dep[f] + 1;
son[u] = 0;
siz[u] = 1;
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f) {
dfs1(v,u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) {
son[u] = v;
}
}
}
}
void dfs2(int u,int f) {
dfn[u] = ++cnt;rk[cnt] = u;
if (son[f] == u) {
top[u] = top[f];
}
else {
top[u] = u;
}
if (son[u]) {
dfs2(son[u],u);
}
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f && v != son[u]) {
dfs2(v,u);
}
}
}
int num[N];
struct node {
int l,r,big,sum;
}tree[N<<2];
inline void build(int p,int l,int r) {
tree[p].l = l;tree[p].r = r;
if (l == r) {
tree[p].big = tree[p].sum = num[rk[l]];
return ;
}
int mid = (l + r) >> 1;
build(p << 1,l,mid);
build(p << 1|1,mid+1,r);
tree[p].sum = tree[p<<1].sum + tree[p<<1|1].sum;
tree[p].big = max(tree[p<<1].big,tree[p<<1|1].big);
return ;
}
inline void modify(int p,int x,int v) {
if (tree[p].l == tree[p].r) {
tree[p].sum = tree[p].big = v;
return ;
}
int mid = (tree[p].l + tree[p].r) >> 1;
if (x <= mid) modify(p<<1,x,v);
if (x > mid) modify(p<<1|1,x,v);
tree[p].sum = tree[p<<1].sum + tree[p<<1|1].sum;
tree[p].big = max(tree[p<<1].big,tree[p<<1|1].big);
return ;
}
inline int querym(int p,int x,int y) {
if (tree[p].l >= x && tree[p].r <= y) {
return tree[p].big;
}
int mid = (tree[p].l + tree[p].r) >> 1,ans = -INF;
if (x <= mid) {
ans = max(ans,querym(p<<1,x,y));
}
if (y > mid) {
ans = max(ans,querym(p<<1|1,x,y));
}
return ans;
}
inline int querys(int p,int x,int y) {
if (tree[p].l >= x && tree[p].r <= y) {
return tree[p].sum;
}
int mid = (tree[p].l + tree[p].r) >> 1,ans = 0;
if (x <= mid) {
ans += querys(p<<1,x,y);
}
if (y > mid) {
ans += querys(p<<1|1,x,y);
}
return ans;
}
inline void change(int u,int t) {
modify(1,dfn[u],t);
}
inline int qmax(int u,int v) {
int ans = -INF;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) {
swap(u,v);
}
ans = max(ans,querym(1,dfn[top[u]],dfn[u]));
u = fat[top[u]];
}
if (dep[u] > dep[v]) {
swap(u,v);
}
ans = max(ans,querym(1,dfn[u],dfn[v]));
return ans;
}
inline int qsum(int u,int v) {
int ans = 0;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) {
swap(u,v);
}
ans += querys(1,dfn[top[u]],dfn[u]);
u = fat[top[u]];
}
if (dep[u] > dep[v]) {
swap(u,v);
}
ans += querys(1,dfn[u],dfn[v]);
return ans;
}
signed main() {
n = read();
for (int i = 1;i < n;++i) {
addedge(read(),read());
}
for (int i = 1;i <= n;++i) {
num[i] = read();
}
dfs1(1,0);dfs2(1,0);build(1,1,n);
int m = read();
for (int i = 1;i <= m;++i) {
cin >> s;
int u = read(),v = read();
if (s == "CHANGE") {
change(u,v);
}
if (s == "QSUM") {
printf("%d\n",qsum(u,v));
}
if (s == "QMAX") {
printf("%d\n",qmax(u,v));
}
}
return 0;
}
练习
Qtree 把边权变成点权,巧妙的做法
[HAOI2015]树上操作 练手,区间修改,区间查询
[JLOI2014]松鼠的新家 巧妙的做法,也可以用树上差分来写
[NOI2015]程序包管理器 稍微转换一下问题
后记
感谢
本作品采用 知识共享署名-相同方式共享 4.0 国际许可协议 进行许可。