树链剖分

题目链接:http://www.luogu.org/problemnew/show/3384

简介:一道树剖模板题

关于树剖:

步骤一:第一次dfs

求出每个节点的重儿子(最重的儿子)和每个节点的质量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void dfs1(int st)
{
size[st]=1;
for(int i=head[st]; i; i=E[i].next)
{
int ed=E[i].ed;
if(fa[st]!=ed)
{
dep[ed]=dep[st]+1;
fa[ed]=st;
dfs1(ed);
if(hson[st]==0||size[hson[st]]<size[ed]) hson[st]=ed;
size[st]+=size[ed];
}
}
}

伪代码:

1
2
3
4
5
6
7
8
Func dfs1(st)
Size[st]=1
For(every node ed connected to the current node)
If(is not father of node now)
Depth[ed]=depth[st]+1
Father[ed]=st
Check whether ed is the heaviest
Size[st]+=size[ed]
步骤二:第二遍搜索:求出每一条重边的祖先,建立题目所给的节点序号和线段树存储序号的映射

在线段树中,存储的是重节点优先的先序遍历,好处有二:

1:以s为根的子树的节点是连续的

2:每一条重边都是连续的

C++代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void dfs2(int st,int anc)
{
top[st]=anc;//求重边祖先
sid[st]=++tot;//建立映射
real[tot]=st;//建立映射
if(hson[st]==0) return;
dfs2(hson[st],anc);//重儿子优先
for(int i=head[st]; i; i=E[i].next)
{
int ed=E[i].ed;
if(ed!=fa[st]&&ed!=hson[st])
dfs2(ed,ed);//其他儿子是轻边
}
}

伪代码:

1
2
3
4
5
6
7
8
Func dfs2(int st,int ancester)
Top[st]=ancestor
Build_connection(st)
If exist heavy_son
Dfs2(heavy_son,ancestor)
For(every node ed connected to the current node)
If(is not father of node now AND is not the heavy_son)
Dfs2(ed,ed)

步骤三,建立一颗加法线段树 这里的build_tree函数有一点不一样:
1
2
3
4
5
6
7
8
9
10
11
12
void build_tree(int L,int R,int id)
{
if(L==R)
{
T[id].sum=a[real[L]];//一定要考虑题目编号和线段树编号之间的映射!!!
return;
}
int mid=(L+R)/2;
build_tree(L,mid,id<<1);
build_tree(mid+1,R,id<<1|1);
pushup(id);
}

以上是预备工作,以下还有两种修改,两种查询:

子树修改:

由于是先序遍历,以x为根的子节点肯定在线段树区间][ sid[x],sid[i]+size[x]-1 ]内

只有一行代码:

modify(sid[x],sid[x]+size[x]-1,num_to_add);//这是伪代码

子树查询:

只是把修改换成了查询而已

1
Query(sid[x],sid[x]+size[x]-1)
重头戏:链修改 u to v(题目序号)

1:u,v在一条重边上

保持大小关系,之后modify(sid[u],sid[v],num_to_add)

2: u,v不在一条重边上:

那么u和v不停地向上跳,一定会跳到一条重边上

每次选取所在链的所在链头结点较低的那一条链进行统计,统计完了之后在向上加(跳)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void chain_add()
{
LL u,v,w,tu,tv;
scanf("%lld%lld%lld",&u,&v,&w);
tu=top[u],tv=top[v];
while(tu!=tv)
{
if(dep[tu]<dep[tv])
swap(tu,tv),swap(u,v);
modify(sid[tu],sid[u],w,1,N,1);
u=fa[tu];
tu=top[u];
}
if(dep[u]<dep[v]) swap(u,v);
modify(sid[v],sid[u],w,1,N,1);
}

伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
Func chain_add(u,v)
Tu=top[u] tv=top[v]
While(not in the same chain => tu!=tv)
Make u the deepest
modify(sid[tu],sid[u])
Jump => u=father[tu] tu=top[u]
Make u the deepest
Modify(sid[v],sid[u])
##### 查询:把线段树修改换成查询
​```c++
void chain_query()
{
LL u,v,tu,tv,ret=0;
scanf("%lld%lld",&u,&v);
tu=top[u],tv=top[v];
while(tu!=tv)
{
if(dep[tu]<dep[tv])
{
swap(tu,tv);
swap(u,v);
}
ret+=query(sid[tu],sid[u],1,N,1);
ret%=P;
u=fa[tu];
tu=top[u];
}
if(dep[u]<dep[v]) swap(u,v);
ret+=query(sid[v],sid[u],1,N,1);
printf("%lld\n",ret%P);
}

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <cstring>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <queue>
#include <stack>
using namespace std;
typedef long long LL;
struct edge
{
int ed,next;
} E[int(1e6)+10];
int head[100100],Ecnt;
int N,M,R,P;
LL a[100100];
struct node
{
LL sum,atag;
} T[100100<<3];
LL real[100010],sid[100010],hson[100010],fa[100010],dep[100010],size[100010],top[100010];
LL tot;
// code
void addEdge(int st,int ed)
{
E[++Ecnt].ed=ed;
E[Ecnt].next=head[st];
head[st]=Ecnt;
}
void dfs1(int st)
{
size[st]=1;
for(int i=head[st]; i; i=E[i].next)
{
int ed=E[i].ed;
if(fa[st]!=ed)
{
dep[ed]=dep[st]+1;
fa[ed]=st;
dfs1(ed);
if(hson[st]==0||size[hson[st]]<size[ed]) hson[st]=ed;
size[st]+=size[ed];
}
}
}
void dfs2(int st,int anc)
{
//printf("dfs2 st=%d anc=%d\n",st,anc);
top[st]=anc;
//printf("top[%d]=%d\n",st,top[st]);
sid[st]=++tot;
real[tot]=st;
if(hson[st]==0) return;
dfs2(hson[st],anc);
for(int i=head[st]; i; i=E[i].next)
{
int ed=E[i].ed;
if(ed!=fa[st]&&ed!=hson[st])
dfs2(ed,ed);
}
}
/// 线段树
void pushup(int id)
{
T[id].sum=T[id<<1].sum+T[id<<1|1].sum;
}
void pushdown(int L,int R,int id)
{
if(T[id].atag)
{
if(L!=R)
{
T[id<<1].atag+=T[id].atag;
T[id<<1|1].atag+=T[id].atag;
}
T[id].sum+=(R-L+1)*T[id].atag;
T[id].atag=0;
}
}
void build_tree(int L,int R,int id)
{
if(L==R)
{
T[id].sum=a[real[L]];
return;
}
int mid=(L+R)/2;
build_tree(L,mid,id<<1);
build_tree(mid+1,R,id<<1|1);
pushup(id);
}
void modify(int B,int E,int x,int L,int R,int id)
{
if(L>E||R<B) return;
pushdown(L,R,id);
if(L>=B&&R<=E)
{
T[id].atag+=x;
return;
}
int mid=(L+R)>>1;
modify(B,E,x,L,mid,id<<1);
modify(B,E,x,mid+1,R,id<<1|1);
pushdown(L,mid,id<<1);
pushdown(mid+1,R,id<<1|1);
pushup(id);
}
LL query(int B,int E,int L,int R,int id)
{
if(L>E||R<B) return 0ll;
pushdown(L,R,id);
if(L>=B&&R<=E)
return T[id].sum%P;
int mid=(L+R)>>1;
return query(B,E,L,mid,id<<1)+query(B,E,mid+1,R,id<<1|1);
}
///end of segment tree
void tree_add()
{
//cout<<__func__<<endl;
LL x,v;
scanf("%lld%lld",&x,&v);
modify(sid[x],sid[x]+size[x]-1,v,1,N,1);
}
void tree_query()
{
//cout<<__func__<<endl;
LL x;
scanf("%lld",&x);
printf("%lld\n",query(sid[x],sid[x]+size[x]-1,1,N,1)%P);
}
void chain_add()
{
//cout<<__func__<<endl;
LL u,v,w,tu,tv;
scanf("%lld%lld%lld",&u,&v,&w);
//cout<<"["<<u<<","<<v<<"]"<<endl;
tu=top[u],tv=top[v];
//printf("%lld %lld\n",tu,tv);
while(tu!=tv)
{
//printf("dep[tu]=%lld,dep[tv]=%lld\n",dep[tu],dep[tv]);
if(dep[tu]<dep[tv])
{
swap(tu,tv);
swap(u,v);
}
modify(sid[tu],sid[u],w,1,N,1);
//printf("modifyed %lld to %lld\n",tu,u);
u=fa[tu];
tu=top[u];
}
if(dep[u]<dep[v]) swap(u,v);
modify(sid[v],sid[u],w,1,N,1);
// printf("modified %lld to %lld\n",v,u);
}
void chain_query()
{
//printf("\t%lld\n",query(sid[3],sid[3],1,N,1));
//cout<<__func__<<endl;
LL u,v,tu,tv,ret=0;
scanf("%lld%lld",&u,&v);
tu=top[u],tv=top[v];
while(tu!=tv)
{
if(dep[tu]<dep[tv])
{
swap(tu,tv);
swap(u,v);
}
LL t=query(sid[tu],sid[u],1,N,1);
//printf("queryed %lld to %lld == %lld\n",tu,u,t);
ret+=query(sid[tu],sid[u],1,N,1);
ret%=P;
//cout<<ret<<endl;
u=fa[tu];
tu=top[u];
}
if(dep[u]<dep[v]) swap(u,v);
LL t=query(sid[v],sid[u],1,N,1);
//printf("queryed %lld to %lld == %lld\n",v,u,t);
ret+=t;
printf("%lld\n",ret%P);
}
///
void init();
int main()
{
init();
//printf("\t%lld\n",query(sid[3],sid[3],1,N,1));
for(int i=1;i<=M;++i)
{
int q;
scanf("%d",&q);
switch(q)
{
case 4:tree_query();break;
case 3:tree_add();break;
case 2:chain_query();break;
case 1:chain_add();break;
}
}
return 0;
}
void init()
{
scanf("%d%d%d%d",&N,&M,&R,&P);
for(int i=1; i<=N; ++i)
scanf("%lld",a+i);
for(int i=1; i<=N-1; ++i)
{
int x,y;
scanf("%d%d",&x,&y);
addEdge(x,y);
addEdge(y,x);
}
dep[R]=1;
dfs1(R);
dfs2(R,R);
build_tree(1,N,1);
}