树链剖分是个很强的 OIer

树链剖分是用来解决一系列树上问题的利器。

定义

  • 链:树上不拐弯的一条路径
  • 重儿子:子树大小最大的儿子
  • 轻儿子:其他儿子
  • 重链:由重儿子组成的链

结果和过程

结果

这是原树

剖分之后的树

这里,红色的边构成重链,蓝色的点为重儿子,绿色的点为轻儿子。

注意,这里的 7 号与 8 号节点子树大小相同,因此我们选择编号靠前的儿子为重儿子。

实现

一般来说,树链剖分通过两遍 dfs 来实现。

定义

siz[x] //子树x的大小
top[x] //x所在链的顶端
fat[x] //x的父亲
dep[x] //x的深度
son[x] //x的重儿子

第一遍 dfs ,我们先处理出每个节点的父亲深度子树大小重儿子

void dfs1(int u,int f) {
    siz[u] = 1;
    dep[u] = dep[f] + 1;
    son[u] = 0;
    fat[u] = f;
    for (int i = hd[u];i;i= nxt[i]) {
        int v = to[i];
        if (v != f) {
            dfs1(v,u);
            siz[u] += siz[v];
            if (siz[v] > siz[son[u]]) {
                son[u] = v;
            }
        }
    }
} 

第二遍 dfs,我们处理出每条链链顶的节点

void dfs2(int u,int f) {
    if (son[f] == u) {
        top[u] = top[f];
    }
    else {
        top[u] = u;
    }
    if (son[u]) {
        dfs2(son[u],u);
    }
    for (int i = hd[u];i;i = nxt[i]) {
        int v = to[i];
        if (v != f && v != son[u]) {
            dfs2(v,u);
        }
    }
}

这样树链剖分的基本结构就写完辣!是不是很简单

应用

LCA

显然,我们可以发现,在一条重链上的两个节点的 LCA 显然就是深度更浅的那个节点。

所以我们可以先将两个节点跳到同一条链上,求出深度浅的那个节点即可,于是有了如下代码:

int lca(int a,int b) {
    for (int ta = top[a],tb = top[b];ta != tb;) {
        if (dep[ta] > dep[tb]) {
            ta = top[a = fat[ta]];
        }
        else {
            tb = top[b = fat[tb]];
        }
    }
    return dep[a] < dep[b] ? a : b;
}

时间复杂度O(\log{n})

例题:板子题

随手套个板子写掉。

代码:

#include <cstdio>
#include <cctype>
#include <cstring>

const int N = 500010;
const int M = N << 1;

inline int read() {
    char v = getchar();int x = 0,f = 1;
    while (!isdigit(v)) {if (v == '-')f = -1;v = getchar();}
    while (isdigit(v)) {x = x * 10 + v - 48;v = getchar();}
    return x * f;
}

int hd[N],edg[M],nxt[M],to[M],n,m,tot,cnt;

inline void add(int u,int v,int w) {
    to[++tot] = v;edg[tot] = w;nxt[tot] = hd[u];hd[u] = tot;
}

inline void addedge(int u,int v,int w) {
    add(u,v,w);add(v,u,w);
}

int dfn[N],top[N],fat[N],siz[N],son[N],end[N],dep[N],s;

void dfs1(int s,int f) {
    son[s] = 0;
    dep[s] = dep[f] + 1;
    siz[s] = 1;
    fat[s] = f;
    for (int i = hd[s];i;i = nxt[i]) {
        if (to[i] != f) {
            dfs1(to[i],s);
            siz[s] += siz[to[i]];
            if (siz[to[i]] > siz[son[s]]) {
                son[s] = to[i];
            }
        }
    }
}

void dfs2(int u,int f) {
    dfn[u] = ++cnt;
    if (u == son[f]) {
        top[u] = top[f];
    }
    else {
        top[u] = u;
    }
    if (son[u]) {
        dfs2(son[u],u);
    }
    for (int i = hd[u];i;i = nxt[i]) {
        int v = to[i];
        if (v != f && v != son[u]) {
            dfs2(v,u);
        }
    }
    end[u] = cnt;
}

int lca(int a,int b) {
    for (int ta = top[a],tb = top[b];ta != tb;) {
        if (dep[ta] > dep[tb]) {
            ta = top[a = fat[ta]];
        }
        else {
            tb = top[b = fat[tb]];
        }
    }
    return dep[a] < dep[b] ? a : b;
}

int main() {
    n = read(),m = read(),s = read();
    for (int i = 1;i < n;++i) {
        int u = read(),v = read(),w = 1;
        addedge(u,v,w);
    }
    dfs1(s,0);
    dfs2(s,0);
    for (int i = 1;i <= m;++i) {
        int u = read(),v = read();
        printf("%d\n",lca(u,v));
    }
    return 0;
}

链上修改,子树修改

这里就必须引出一个新东西了:dfs 序

指的是第几次 dfs 遍历到的这个节点

定义数组dfn[x]为节点 x 的 dfs 序。

我们可以在 dfs2 中顺手维护一下。

void dfs2(int u,int f) {
    dfn[u] = ++cnt;
    if (son[f] == u) {
        top[u] = top[f];
    }
    else {
        top[u] = u;
    }
    if (son)
    for (int i = hd[u];i;i = nxt[i]) {
        int v = to[i];
        if (v != f) {
            dfs2(v,u);
        }
    }
    end[cnt] = u;
}

这里给出每个节点加上 dfn 的图

我们可以发现,在每棵子树,每条上的dfn都是连续的!

这样看可能不太明朗,我们把它转化成区间来看(这里是子树的,链同理)

所以我们就可以利用数据结构维护一下节点了。

对子树修改就是 change(dfn[x],dfn[x]+siz[x]-1)

链上修改就是 change(dfn[x],dfn[y])

这里的核心思想是:将树上问题转化成序列问题来处理

例题:[ZJOI2008]树的统计

用线段树维护最大值最小值。

代码:

6#include <cstdio>
#include <cctype>
#include <iostream>
#include <string>
const int N = 100010;
const int M = N << 1;
const int INF = 0x3f3f3f3f;
using namespace std;
inline int read() {
    char v = getchar();int x = 0,f = 1;
    while (!isdigit(v)) {if (v == '-')f = -1;v = getchar();}
    while (isdigit(v)) {x = x * 10 + v - 48;v = getchar();}
    return x * f;
}

int to[M],hd[N],nxt[M],tot;
string s;

inline void add(int u,int v) {
    to[++tot] = v;nxt[tot] = hd[u];hd[u] = tot;
}

inline void addedge(int u,int v) {
    add(u,v);add(v,u);
}

int dfn[N],top[N],fat[N],siz[N],dep[N],son[N],rk[N],cnt,n;

void dfs1(int u,int f) {
    fat[u] = f;
    dep[u] = dep[f] + 1;
    son[u] = 0;
    siz[u] = 1;
    for (int i = hd[u];i;i = nxt[i]) {
        int v = to[i];
        if (v != f) {
            dfs1(v,u);
            siz[u] += siz[v];
            if (siz[v] > siz[son[u]]) {
                son[u] = v;
            }
        }
    }
}

void dfs2(int u,int f) {
    dfn[u] = ++cnt;rk[cnt] = u;
    if (son[f] == u) {
        top[u] = top[f];
    }
    else {
        top[u] = u;
    }
    if (son[u]) {
        dfs2(son[u],u);
    }
    for (int i = hd[u];i;i = nxt[i]) {
        int v = to[i];
        if (v != f && v != son[u]) {
            dfs2(v,u);
        }
    }
}

int num[N];

struct node {
    int l,r,big,sum;
}tree[N<<2];

inline void build(int p,int l,int r) {
    tree[p].l = l;tree[p].r = r;
    if (l == r) {
        tree[p].big = tree[p].sum = num[rk[l]];
        return ;
    }
    int mid = (l + r) >> 1;
    build(p << 1,l,mid);
    build(p << 1|1,mid+1,r);
    tree[p].sum = tree[p<<1].sum + tree[p<<1|1].sum;
    tree[p].big = max(tree[p<<1].big,tree[p<<1|1].big);
    return ;
}

inline void modify(int p,int x,int v) {
    if (tree[p].l == tree[p].r) {
        tree[p].sum = tree[p].big = v;
        return ;
    }
    int mid = (tree[p].l + tree[p].r) >> 1;
    if (x <= mid) modify(p<<1,x,v);
    if (x > mid) modify(p<<1|1,x,v);
    tree[p].sum = tree[p<<1].sum + tree[p<<1|1].sum;
    tree[p].big = max(tree[p<<1].big,tree[p<<1|1].big);
    return ;
} 

inline int querym(int p,int x,int y) {
    if (tree[p].l >= x && tree[p].r <= y) {
        return tree[p].big;
    }
    int mid = (tree[p].l + tree[p].r) >> 1,ans = -INF;
    if (x <= mid) {
        ans = max(ans,querym(p<<1,x,y));
    }
    if (y > mid) {
        ans = max(ans,querym(p<<1|1,x,y));
    }
    return ans;
}

inline int querys(int p,int x,int y) {
    if (tree[p].l >= x && tree[p].r <= y) {
        return tree[p].sum;
    }
    int mid = (tree[p].l + tree[p].r) >> 1,ans = 0;
    if (x <= mid) {
        ans += querys(p<<1,x,y);
    }
    if (y > mid) {
        ans += querys(p<<1|1,x,y);
    }
    return ans;
}

inline void change(int u,int t) {
    modify(1,dfn[u],t);
}

inline int qmax(int u,int v) {
    int ans = -INF;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) {
            swap(u,v);
        }
        ans = max(ans,querym(1,dfn[top[u]],dfn[u]));
        u = fat[top[u]];
    }
    if (dep[u] > dep[v]) {
        swap(u,v);
    }
    ans = max(ans,querym(1,dfn[u],dfn[v]));
    return ans;
}

inline int qsum(int u,int v) {
    int ans = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) {
            swap(u,v);
        }
        ans += querys(1,dfn[top[u]],dfn[u]);
        u = fat[top[u]];
    }
    if (dep[u] > dep[v]) {
        swap(u,v);
    }
    ans += querys(1,dfn[u],dfn[v]);
    return ans;
}

signed main() {
    n = read();
    for (int i = 1;i < n;++i) {
        addedge(read(),read());
    }
    for (int i = 1;i <= n;++i) {
        num[i] = read();
    }
    dfs1(1,0);dfs2(1,0);build(1,1,n);
    int m = read();
    for (int i = 1;i <= m;++i) {
        cin >> s;
        int u = read(),v = read();
        if (s == "CHANGE") {
            change(u,v);
        }
        if (s == "QSUM") {
            printf("%d\n",qsum(u,v));
        }
        if (s == "QMAX") {
            printf("%d\n",qmax(u,v));
        }
    }
    return 0;
}

练习

Qtree 把边权变成点权,巧妙的做法

[HAOI2015]树上操作 练手,区间修改,区间查询

[JLOI2014]松鼠的新家 巧妙的做法,也可以用树上差分来写

[NOI2015]程序包管理器 稍微转换一下问题

后记

感谢 \color{black}M\color{red}{inagami} 给我提供了极大的帮助,感谢 \color{black}Y\color{red}{ouSiKi} 的教导