0%

树链剖分

前置知识

必备:存图方式(邻接表,邻接矩阵),dfs序。 维护:线段树、树状数组、BST。

引入

树链剖分,简单来说就是把树分割成链,然后维护每一条链。一般的维护算法有线段树,树状数组和BST。复杂度为\(O(\log n)\)

剖分方式

对于每一个节点,它的子节点中子树的节点数最大的为重儿子,连接到重儿子的边称为重边。例如:

加粗节点为重儿子。除重儿子和重边外的节点和边均为轻儿子或轻边。根不是重儿子也不是轻儿子,他根本就不是儿子!

以轻儿子或者根为起点的,由重边连接的一条连续的链称为重链。特别地,若一个叶子结点是轻儿子,那么便有一条以该叶子结点为起点的长度为1的重链。上图中,1-3-7-9是一条重链,6一个节点也是一条重链。

预处理

树链剖分的预处理本质上就是2个dfs。

dfs 1

一共完成四项任务:

  • 标记节点深度:dep[]
  • 标记节点的父节点:fa[]
  • 标记节点的子树大小:siz[]
  • 标记节点的重儿子编号:son[]

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
void dfs_1(int u,int f){
siz[u]=1;
int maxson=-1;
for(int i=head[u];i!=-1;i=edge[i].nxt){
int v=edge[i].v;
if(v==f) continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs_1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}

dfs 2

以先重后轻为优先级的dfs。完成三项任务:

  • 标记dfs序(先重后轻):idx[]
  • 标记dfs序对应的节点:rnk[]
  • 标记节点所在重链的顶端:top[]

代码:

1
2
3
4
5
6
7
8
9
10
11
12
void dfs_2(int u,int tp){
idx[u]=++dfn;
rnk[dfn]=a[u];
top[u]=tp;
if(!son[u])return;
dfs_2(son[u],tp);
for(int i=head[u];i!=-1;i=edge[i].nxt){
if(!idx[edge[i].v]){
dfs_2(edge[i].v,edge[i].v);
}
}
}

维护

这里以洛谷的P3384为例,要求维护四种操作:

  • 操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
  • 操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
  • 操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
  • 操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
  1. 要处理两点间的路径时:

    首先找到两个节点中较深的那个点,然后将对该点所在重链进行处理,并将该点移动至所在链顶端节点的父节点(相当于一段一段地处理)。因为是按照轻重边为优先级做的dfs,所以一条链上的编号一定是连续的。

    循环执行直到两个点在一条链上,这时再处理两个点之间的区间即可。

  2. 要处理一个点的子树时:

    子树的dfs序一定是连续的区间,直接处理该区间即可。

这样就把树上问题转化成了区间问题。

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include "bits/stdc++.h"
using namespace std;
const int N=2000100;
#define lson id<<1
#define rson id<<1|1
inline int read(){
int c=1,q=0;char ch=' ';
while(ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
if(ch=='-')c=-1,ch=getchar();
while(ch>='0'&&ch<='9')q=q*10+ch-'0',ch=getchar();
return c*q;
}
struct node{
int u,v,nxt;
}e[N];
int head[N];
int tot=1;
struct Tree{
int l,r,c,siz,f;
}t[N];
int n,m,root,p,dfn=0,rnk[N],a[N];
inline void add(int x,int y){
e[tot].u=x;
e[tot].v=y;
e[tot].nxt=head[x];
head[x]=tot++;
}
int dep[N],fa[N],son[N],siz[N],top[N],idx[N];
void dfs_1(int u,int f){
siz[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==f) continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs_1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void update(int id){
t[id].c=(t[lson].c+t[rson].c+p)%p;
}
void Build(int id,int L,int R){
t[id].l=L;t[id].r=R;t[id].siz=R-L+1;
if(L==R){
t[id].c=rnk[L];
return;
}
int mid=(L+R)>>1;
Build(lson,L,mid);
Build(rson,mid+1,R);
update(id);
}
void dfs_2(int u,int tp){
idx[u]=++dfn;
rnk[dfn]=a[u];
top[u]=tp;
if(!son[u])return;
dfs_2(son[u],tp);
for(int i=head[u];i;i=e[i].nxt){
if(!idx[e[i].v]){
dfs_2(e[i].v,e[i].v);
}
}
}
void pushdown(int id){
if(!t[id].f) return ;
t[lson].c=(t[lson].c+t[lson].siz*t[id].f)%p;
t[rson].c=(t[rson].c+t[rson].siz*t[id].f)%p;
t[lson].f=(t[lson].f+t[id].f)%p;
t[rson].f=(t[rson].f+t[id].f)%p;
t[id].f=0;
}
void addroute(int id,int L,int R,int c){
if(L<=t[id].l&&t[id].r<=R){
t[id].c+=t[id].siz*c;
t[id].f+=c;
return;
}
pushdown(id);
int mid=(t[id].l+t[id].r)>>1;
if(L<=mid)addroute(lson,L,R,c);
if(R>mid)addroute(rson,L,R,c);
update(id);
}
void addtree(int x,int y,int c){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
addroute(1,idx[top[x]],idx[x],c);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
addroute(1,idx[x],idx[y],c);
}
int sumroute(int id,int L,int R){
int ans=0;
if(L<=t[id].l&&t[id].r<=R)
return t[id].c;
pushdown(id);
int mid=(t[id].l+t[id].r)>>1;
if(L<=mid) ans=(ans+sumroute(lson,L,R))%p;
if(R>mid) ans=(ans+sumroute(rson,L,R))%p;
return ans;
}
void treesum(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+sumroute(1,idx[top[x]],idx[x]))%p;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=(ans+sumroute(1,idx[x],idx[y]))%p;
printf("%d\n",ans);
}
int main(){
n=read();m=read();root=read();p=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<=n-1;i++){
int x=read(),y=read();
add(x,y);add(y,x);
}
dfs_1(root,root);
dfs_2(root,root);
Build(1,1,n);
while(m--){
int op=read(),x,y,z;
if(op==1){
x=read();y=read();z=read();z=z%p;
addtree(x,y,z);
}
else if(op==2){
x=read();y=read();
treesum(x,y);
}
else if(op==3){
x=read(),z=read();
addroute(1,idx[x],idx[x]+siz[x]-1,z%p);
}
else if(op==4){
x=read();
printf("%d\n",sumroute(1,idx[x],idx[x]+siz[x]-1));
}
}
return 0;
}