很久之前曾经学过这个解决树上问题的经典算法,奈何时间久远忘干净了,所以重新拿出来再学一遍加深印象.说是树分治,其实这篇文章的主要内容是点分治.故下面的所有树分治均指点分治.边分治算法之后再额外开文章.
引入
例题[国家集训队2011]聪聪可可
给出一棵树,树上每条边有个边权.路径间的距离定义为这条路径上所有边的边权和.求随机选择两个点,使得两点间的距离为3的倍数的概率.
很明显这道题有一个DP做法,时间复杂度为O(n).代码如下所示:
#include<cstring>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
#define N 20010
#define l(x) (x<<1)
#define r(x) ((x<<1)+1)
#define LL long long
#define INF 0x3f3f3f3f
using namespace std;
int n,i,x,y,z,lsum=1;
int head[N],v[N];
LL Ans,p;
LL dp[N][4];
struct Flow{
int t,next,l;
}e[N*8];
inline int Abs(int x){return (x<0)?-x:x;}
inline void Swap(int &a,int &b){a^=b^=a^=b;}
inline int Min(int a,int b){return (a<b)?a:b;}
inline int Max(int a,int b){return (a>b)?a:b;}
inline LL Gcd(LL a,LL b){return (!b)?a:Gcd(b,a%b);}
inline int read(){
int p=0; char c=getchar();
while (c<48||c>57) c=getchar();
while (c>=48&&c<=57) p=(p<<1)+(p<<3)+c-48,c=getchar();
return p;
}
inline void Add(int s,int t,int l){
e[lsum].t=t; e[lsum].l=l; e[lsum].next=head[s]; head[s]=lsum++;
}
inline void Work(int x){
int i=0;
for (i=head[x];i;i=e[i].next){
if (v[e[i].t]) continue;
v[e[i].t]=x; Work(e[i].t);
dp[x][e[i].l%3]+=dp[e[i].t][0];
dp[x][(1+e[i].l)%3]+=dp[e[i].t][1];
dp[x][(2+e[i].l)%3]+=dp[e[i].t][2];
}
Ans+=dp[x][0]*(dp[x][0]-1)/2;
Ans+=dp[x][1]*dp[x][2];
for (i=head[x];i;i=e[i].next){
if (v[e[i].t]!=x) continue;
Ans-=(dp[e[i].t][(3-(e[i].l)%3)%3])*(dp[e[i].t][(3-(e[i].l)%3)%3]-1)/2;
Ans-=dp[e[i].t][(4-e[i].l%3)%3]*dp[e[i].t][(5-e[i].l%3)%3];
}
Ans+=dp[x][0]; dp[x][0]++;
}
int main(){
n=read();
for (i=1;i<n;i++){
x=read(); y=read(); z=read();
Add(x,y,z); Add(y,x,z);
}
v[1]=-1; Work(1);
Ans=Ans*2+n;
p=Gcd(Ans,n*n);
printf("%lld/%lld\n",Ans/p,n*n/p);
return 0;
}
树分治做法
首先考虑如何计算一个有根树中经过根节点的路径对答案的贡献.
当给定了一个有根树时,求出每个点到根节点的距离,然后直接计算贡献.这时不难证明,所有经过根节点的路径都已经被计算过了.所以,可以将根节点删去,计算剩下的那些森林对答案的贡献.
而上面过程中,求出所有节点到根节点的距离的时间复杂度是O(n)的.如果选择根节点不当的话,例如从一条链的一段不断选择节点作为根节点,那么算法的时间复杂度会退化到平方级.所以,为了避免这一情况,定义了树的重心的概念,并且每次选择重心作为新的根节点.
重心的定义是:对于有n个节点的树,重心作为根时,最大子树的大小不超过节点数的一半.
存在性证明:随机选择一个节点作为根节点,如果符合上面的定义,则证毕;如果不符合,找到最大的子树,其大小必然大于 \frac{n}{2} .将根节点转移到这个子树的根上.不断重复上述过程.显然,这个转移过程是不会走回头路的(否则会导出节点总数大于n,矛盾),而且不会一直进行下去,存在性证毕.
由于重心的良好性质,每次删除节点并继续计算的时间复杂度得到了很好的保证.因为每一次定下来一个根就需要遍历一次子树,重心又保证了接下来被遍历的子树大小不会超过之前的一半,故时间复杂度为
O(NlogN)
每次记录下需要寻找根节点的子树的大小,遍历一次,找到根,然后再遍历一次计算题目所需要的内容,再进行递归.这就是树(点)分治的核心过程.
找根的代码实现如下,size[x]表示以x为根的树的大小;msize[x]表示x的最大的子树大小.tot为需要找根的树的节点总数.
inline void FindRoot(int x,int fa){
int i=0;
size[x]=1; msize[x]=0;
for (i=head[x];i;i=e[i].next)
if (!v[e[i].t]&&e[i].t!=fa) {
FindRoot(e[i].t,x); msize[x]=Max(msize[x],size[e[i].t]);
size[x]+=size[e[i].t];
}
msize[x]=Max(msize[x],tot-msize[x]);
if (msize[x]<msize[root]) root=x;
}
找到之后,按要求进行处理,并递归计算.
inline void Solve(int x){
int i=0;
v[x]=1; Ans+=Calc(x,0);
for (i=head[x];i;i=e[i].next)
if (!v[e[i].t]){
Ans-=Calc(e[i].t,e[i].l);
root=0; tot=size[e[i].t];
FindRoot(e[i].t,0); Solve(root);
}
}
Calc函数根据题目而变化.这道题的Calc函数见下面的完整代码.因为计算时会有重复(例如两个到根距离为0的点不一定可以算作是一个合法点对,因为它们可能在同一条链上),所以在递归之前需要减去一部分.整个算法的时间复杂度为
O(NlogN)
虽然比上面的树上DP慢了一个Log,代码长度也长了一点,但是适用性比较广泛,具体可见下方的例子.
完整代码:
#include<cstring>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
#define N 20010
#define l(x) (x<<1)
#define r(x) ((x<<1)+1)
#define LL long long
#define INF 0x3f3f3f3f
using namespace std;
int n,i,x,y,z,lsum=1,root,tot;
int head[N],v[N],size[N],msize[N],cnt[5],dis[N];
LL Ans,p;
struct Flow{
int t,next,l;
}e[N*8];
inline int Max(int a,int b){return (a>b)?a:b;}
inline LL Gcd(LL a,LL b){return (!b)?a:Gcd(b,a%b);}
inline int read(){
int p=0; char c=getchar();
while (c<48||c>57) c=getchar();
while (c>=48&&c<=57) p=(p<<1)+(p<<3)+c-48,c=getchar();
return p;
}
inline void Add(int s,int t,int l){
e[lsum].t=t; e[lsum].l=l; e[lsum].next=head[s]; head[s]=lsum++;
}
inline void FindRoot(int x,int fa){
int i=0;
size[x]=1; msize[x]=0;
for (i=head[x];i;i=e[i].next)
if (!v[e[i].t]&&e[i].t!=fa) {
FindRoot(e[i].t,x); msize[x]=Max(msize[x],size[e[i].t]);
size[x]+=size[e[i].t];
}
msize[x]=Max(msize[x],tot-msize[x]);
if (msize[x]<msize[root]) root=x;
}
inline void Dfs(int x,int fa){
int i=0;
cnt[dis[x]]++;
for (i=head[x];i;i=e[i].next){
if (!v[e[i].t]&&e[i].t!=fa) {
dis[e[i].t]=(dis[x]+e[i].l)%3;
Dfs(e[i].t,x);
}
}
}
inline int Calc(int x,int p){
int i=0;
dis[x]=p%3; cnt[0]=cnt[1]=cnt[2]=0;
Dfs(x,0);
return cnt[0]*cnt[0]+2*cnt[1]*cnt[2];
}
inline void Solve(int x){
int i=0;
v[x]=1; Ans+=Calc(x,0);
for (i=head[x];i;i=e[i].next)
if (!v[e[i].t]){
Ans-=Calc(e[i].t,e[i].l);
root=0; tot=size[e[i].t];
FindRoot(e[i].t,0); Solve(root);
}
}
int main(){
n=read();
for (i=1;i<n;i++){
x=read(); y=read(); z=read();
Add(x,y,z); Add(y,x,z);
}
tot=n; msize[root=0]=INF;
FindRoot(1,0); Solve(1);
p=Gcd(Ans,n*n);
printf("%lld/%lld\n",Ans/p,n*n/p);
return 0;
}
拓展
例题[IOI2011]Race
给一棵树,每条边有非负权.求一条简单路径,权值和等于K,且边的数量最小.
显然,这时不能再使用类似之前dp[i][j]的形式进行转移了,空间和时间均不允许.
考虑使用点分治,求以某个点为根的子树中,距离为p的最小边的数量是多少,用sum[p]表示.但是这道题有个麻烦的地方是,因为题目要求的是最小值,故答案不再具有加减性.不合法的方案不能之间减去了.所以,计算答案的时候需要进行两次递归.
每一次递归均从根节点的子节点开始.第一次递归只考虑能否更新答案,并不更新sum数组.第二次递归只更新sum数组,不计算答案.这样保证了计算答案的时候,利用的sum数组的信息均来自根节点的其他子树,即一定合法.
但是,还有一个问题是sum数组的信息可能来自其他的分治块.如果在每次分治的时候清空,时间复杂度会是平方级的.因此再引入dfn数组作为时间戳,保证了sum信息的准确性.中间有些细节见下方代码.总的时间复杂度为
O(NlogN)
完整代码:
#include<cstring>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
#define N 200010
#define l(x) (x<<1)
#define r(x) ((x<<1)+1)
#define LL long long
#define INF 0x3f3f3f3f
using namespace std;
int n,k,i,x,y,z,lsum=1,root,tot,Ans,dp;
int head[N],v[N],size[N],msize[N],sum[N*5],dfn[N*5];
struct Edge{
int t,next,l;
}e[N*8];
inline int Abs(int x){return (x<0)?-x:x;}
inline void Swap(int &a,int &b){a^=b^=a^=b;}
inline int Min(int a,int b){return (a<b)?a:b;}
inline int Max(int a,int b){return (a>b)?a:b;}
inline int read(){
int p=0; char c=getchar();
while (c<48||c>57) c=getchar();
while (c>=48&&c<=57) p=(p<<1)+(p<<3)+c-48,c=getchar();
return p;
}
inline void Add(int s,int t,int l){
e[lsum].t=t; e[lsum].l=l; e[lsum].next=head[s]; head[s]=lsum++;
}
inline void FindRoot(int x,int fa){
int i=0;
size[x]=1; msize[x]=0;
for (i=head[x];i;i=e[i].next)
if (!v[e[i].t]&&e[i].t!=fa){
FindRoot(e[i].t,x); size[x]+=size[e[i].t];
msize[x]=Max(msize[x],size[e[i].t]);
}
msize[x]=Max(tot-msize[x],msize[x]);
if (msize[x]<msize[root]) root=x;
}
inline void Dfs(int x,int fa,int val,int p,int op){
int i=0;
if (val>k) return;
if (op==1) {
if (dfn[k-val]==dp) Ans=Min(Ans,p+sum[k-val]);
} else {
if (dfn[val]!=dp) dfn[val]=dp,sum[val]=p;
else sum[val]=Min(sum[val],p);
}
for (i=head[x];i;i=e[i].next)
if (!v[e[i].t]&&e[i].t!=fa)
Dfs(e[i].t,x,val+e[i].l,p+1,op);
}
inline void Solve(int x){
int i=0;
v[x]=1; dp++;
dfn[0]=dp; sum[0]=0;
for (i=head[x];i;i=e[i].next){
if (v[e[i].t]) continue;
Dfs(e[i].t,x,e[i].l,1,1);
Dfs(e[i].t,x,e[i].l,1,0);
}
for (i=head[x];i;i=e[i].next){
if (v[e[i].t]) continue;
tot=size[e[i].t]; root=0;
FindRoot(e[i].t,0); Solve(root);
}
}
int main(){
n=read(); k=read();
for (i=1;i<n;i++){
x=read()+1; y=read()+1; z=read();
Add(x,y,z); Add(y,x,z);
}
memset(sum,INF,sizeof(sum));
msize[root=0]=INF; tot=n; Ans=INF;
FindRoot(1,0); Solve(root);
printf("%d\n",(Ans==INF)?-1:Ans);
return 0;
}
例题[POJ1741]树上的点对
给一棵有n个节点的树,求有多少个点对间的距离小于k.
和第一道例题相似的思路,每次分治完毕后,先整体算一次答案,再减去非法的情况.计算时,需要统计所有链的长度,然后排序,头尾两个指针向中间跳的同时计算方案数.时间复杂度为
O(NlogNlogN)