很久之前曾经学过这个解决树上问题的经典算法,奈何时间久远忘干净了,所以重新拿出来再学一遍加深印象.说是树分治,其实这篇文章的主要内容是点分治.故下面的所有树分治均指点分治.边分治算法之后再额外开文章.

引入

例题[国家集训队2011]聪聪可可

给出一棵树,树上每条边有个边权.路径间的距离定义为这条路径上所有边的边权和.求随机选择两个点,使得两点间的距离为3的倍数的概率.

题目链接

很明显这道题有一个DP做法,时间复杂度为O(n).代码如下所示:

#include<cstring>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
#define N 20010
#define l(x) (x<<1)
#define r(x) ((x<<1)+1)
#define LL long long
#define INF 0x3f3f3f3f
using namespace std;

int n,i,x,y,z,lsum=1;
int head[N],v[N];
LL Ans,p;
LL dp[N][4];

struct Flow{
    int t,next,l;
}e[N*8];

inline int Abs(int x){return (x<0)?-x:x;}
inline void Swap(int &a,int &b){a^=b^=a^=b;}
inline int Min(int a,int b){return (a<b)?a:b;}
inline int Max(int a,int b){return (a>b)?a:b;}
inline LL Gcd(LL a,LL b){return (!b)?a:Gcd(b,a%b);}

inline int read(){
    int p=0;    char    c=getchar();
    while (c<48||c>57)  c=getchar();
    while (c>=48&&c<=57)    p=(p<<1)+(p<<3)+c-48,c=getchar();
    return p;
}

inline void Add(int s,int t,int l){
    e[lsum].t=t;    e[lsum].l=l;    e[lsum].next=head[s];   head[s]=lsum++;
}

inline void Work(int x){
    int i=0;
    for (i=head[x];i;i=e[i].next){
        if (v[e[i].t])  continue;
        v[e[i].t]=x;    Work(e[i].t);
        dp[x][e[i].l%3]+=dp[e[i].t][0];
        dp[x][(1+e[i].l)%3]+=dp[e[i].t][1];
        dp[x][(2+e[i].l)%3]+=dp[e[i].t][2];
    }
    Ans+=dp[x][0]*(dp[x][0]-1)/2;
    Ans+=dp[x][1]*dp[x][2];
    for (i=head[x];i;i=e[i].next){
        if (v[e[i].t]!=x)   continue;
        Ans-=(dp[e[i].t][(3-(e[i].l)%3)%3])*(dp[e[i].t][(3-(e[i].l)%3)%3]-1)/2;
        Ans-=dp[e[i].t][(4-e[i].l%3)%3]*dp[e[i].t][(5-e[i].l%3)%3];
    }
    Ans+=dp[x][0];  dp[x][0]++;
}

int main(){
    n=read();
    for (i=1;i<n;i++){
        x=read();   y=read();   z=read();
        Add(x,y,z); Add(y,x,z);
    }
    v[1]=-1;    Work(1);
    Ans=Ans*2+n;
    p=Gcd(Ans,n*n);
    printf("%lld/%lld\n",Ans/p,n*n/p);
    return 0;
}

树分治做法

首先考虑如何计算一个有根树中经过根节点的路径对答案的贡献.

当给定了一个有根树时,求出每个点到根节点的距离,然后直接计算贡献.这时不难证明,所有经过根节点的路径都已经被计算过了.所以,可以将根节点删去,计算剩下的那些森林对答案的贡献.

而上面过程中,求出所有节点到根节点的距离的时间复杂度是O(n)的.如果选择根节点不当的话,例如从一条链的一段不断选择节点作为根节点,那么算法的时间复杂度会退化到平方级.所以,为了避免这一情况,定义了树的重心的概念,并且每次选择重心作为新的根节点.

重心的定义是:对于有n个节点的树,重心作为根时,最大子树的大小不超过节点数的一半.

存在性证明:随机选择一个节点作为根节点,如果符合上面的定义,则证毕;如果不符合,找到最大的子树,其大小必然大于 \frac{n}{2} .将根节点转移到这个子树的根上.不断重复上述过程.显然,这个转移过程是不会走回头路的(否则会导出节点总数大于n,矛盾),而且不会一直进行下去,存在性证毕.

由于重心的良好性质,每次删除节点并继续计算的时间复杂度得到了很好的保证.因为每一次定下来一个根就需要遍历一次子树,重心又保证了接下来被遍历的子树大小不会超过之前的一半,故时间复杂度为

O(NlogN)

每次记录下需要寻找根节点的子树的大小,遍历一次,找到根,然后再遍历一次计算题目所需要的内容,再进行递归.这就是树(点)分治的核心过程.

找根的代码实现如下,size[x]表示以x为根的树的大小;msize[x]表示x的最大的子树大小.tot为需要找根的树的节点总数.

inline void FindRoot(int x,int fa){
    int i=0;
    size[x]=1;  msize[x]=0;
    for (i=head[x];i;i=e[i].next)
        if (!v[e[i].t]&&e[i].t!=fa) {
            FindRoot(e[i].t,x); msize[x]=Max(msize[x],size[e[i].t]);
            size[x]+=size[e[i].t];
        }
    msize[x]=Max(msize[x],tot-msize[x]);
    if (msize[x]<msize[root])   root=x;
}

找到之后,按要求进行处理,并递归计算.

inline void Solve(int x){
    int i=0;
    v[x]=1; Ans+=Calc(x,0);
    for (i=head[x];i;i=e[i].next)
        if (!v[e[i].t]){
            Ans-=Calc(e[i].t,e[i].l);
            root=0; tot=size[e[i].t];
            FindRoot(e[i].t,0); Solve(root);
        }
}

Calc函数根据题目而变化.这道题的Calc函数见下面的完整代码.因为计算时会有重复(例如两个到根距离为0的点不一定可以算作是一个合法点对,因为它们可能在同一条链上),所以在递归之前需要减去一部分.整个算法的时间复杂度为

O(NlogN)

虽然比上面的树上DP慢了一个Log,代码长度也长了一点,但是适用性比较广泛,具体可见下方的例子.

完整代码:

#include<cstring>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
#define N 20010
#define l(x) (x<<1)
#define r(x) ((x<<1)+1)
#define LL long long
#define INF 0x3f3f3f3f
using namespace std;

int n,i,x,y,z,lsum=1,root,tot;
int head[N],v[N],size[N],msize[N],cnt[5],dis[N];
LL Ans,p;

struct Flow{
    int t,next,l;
}e[N*8];

inline int Max(int a,int b){return (a>b)?a:b;}
inline LL Gcd(LL a,LL b){return (!b)?a:Gcd(b,a%b);}

inline int read(){
    int p=0;    char    c=getchar();
    while (c<48||c>57)  c=getchar();
    while (c>=48&&c<=57)    p=(p<<1)+(p<<3)+c-48,c=getchar();
    return p;
}

inline void Add(int s,int t,int l){
    e[lsum].t=t;    e[lsum].l=l;    e[lsum].next=head[s];   head[s]=lsum++;
}

inline void FindRoot(int x,int fa){
    int i=0;
    size[x]=1;  msize[x]=0;
    for (i=head[x];i;i=e[i].next)
        if (!v[e[i].t]&&e[i].t!=fa) {
            FindRoot(e[i].t,x); msize[x]=Max(msize[x],size[e[i].t]);
            size[x]+=size[e[i].t];
        }
    msize[x]=Max(msize[x],tot-msize[x]);
    if (msize[x]<msize[root])   root=x;
}

inline void Dfs(int x,int fa){
    int i=0;
    cnt[dis[x]]++;
    for (i=head[x];i;i=e[i].next){
        if (!v[e[i].t]&&e[i].t!=fa) {
            dis[e[i].t]=(dis[x]+e[i].l)%3;
            Dfs(e[i].t,x);
        }
    }
}

inline int Calc(int x,int p){
    int i=0;
    dis[x]=p%3; cnt[0]=cnt[1]=cnt[2]=0;
    Dfs(x,0);
    return cnt[0]*cnt[0]+2*cnt[1]*cnt[2];
}

inline void Solve(int x){
    int i=0;
    v[x]=1; Ans+=Calc(x,0);
    for (i=head[x];i;i=e[i].next)
        if (!v[e[i].t]){
            Ans-=Calc(e[i].t,e[i].l);
            root=0; tot=size[e[i].t];
            FindRoot(e[i].t,0); Solve(root);
        }
}

int main(){
    n=read();
    for (i=1;i<n;i++){
        x=read();   y=read();   z=read();
        Add(x,y,z); Add(y,x,z);
    }
    tot=n;  msize[root=0]=INF;
    FindRoot(1,0);      Solve(1);
    p=Gcd(Ans,n*n);
    printf("%lld/%lld\n",Ans/p,n*n/p);
    return 0;
}

拓展

例题[IOI2011]Race

给一棵树,每条边有非负权.求一条简单路径,权值和等于K,且边的数量最小.

题目链接

显然,这时不能再使用类似之前dp[i][j]的形式进行转移了,空间和时间均不允许.

考虑使用点分治,求以某个点为根的子树中,距离为p的最小边的数量是多少,用sum[p]表示.但是这道题有个麻烦的地方是,因为题目要求的是最小值,故答案不再具有加减性.不合法的方案不能之间减去了.所以,计算答案的时候需要进行两次递归.

每一次递归均从根节点的子节点开始.第一次递归只考虑能否更新答案,并不更新sum数组.第二次递归只更新sum数组,不计算答案.这样保证了计算答案的时候,利用的sum数组的信息均来自根节点的其他子树,即一定合法.

但是,还有一个问题是sum数组的信息可能来自其他的分治块.如果在每次分治的时候清空,时间复杂度会是平方级的.因此再引入dfn数组作为时间戳,保证了sum信息的准确性.中间有些细节见下方代码.总的时间复杂度为

O(NlogN)

完整代码:

#include<cstring>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
#define N 200010
#define l(x) (x<<1)
#define r(x) ((x<<1)+1)
#define LL long long
#define INF 0x3f3f3f3f
using namespace std;

int n,k,i,x,y,z,lsum=1,root,tot,Ans,dp;
int head[N],v[N],size[N],msize[N],sum[N*5],dfn[N*5];

struct Edge{
    int t,next,l;
}e[N*8];

inline int Abs(int x){return (x<0)?-x:x;}
inline void Swap(int &a,int &b){a^=b^=a^=b;}
inline int Min(int a,int b){return (a<b)?a:b;}
inline int Max(int a,int b){return (a>b)?a:b;}

inline int read(){
    int p=0;    char    c=getchar();
    while (c<48||c>57)  c=getchar();
    while (c>=48&&c<=57)    p=(p<<1)+(p<<3)+c-48,c=getchar();
    return p;
}

inline void Add(int s,int t,int l){
    e[lsum].t=t;    e[lsum].l=l;    e[lsum].next=head[s];   head[s]=lsum++;
}

inline void FindRoot(int x,int fa){
    int i=0;
    size[x]=1;  msize[x]=0;
    for (i=head[x];i;i=e[i].next)
        if (!v[e[i].t]&&e[i].t!=fa){
            FindRoot(e[i].t,x); size[x]+=size[e[i].t];
            msize[x]=Max(msize[x],size[e[i].t]);
        }
    msize[x]=Max(tot-msize[x],msize[x]);
    if (msize[x]<msize[root])   root=x;
}

inline void Dfs(int x,int fa,int val,int p,int op){
    int i=0;
    if (val>k)  return;
    if (op==1) {
        if (dfn[k-val]==dp) Ans=Min(Ans,p+sum[k-val]);
    }   else {
        if (dfn[val]!=dp)   dfn[val]=dp,sum[val]=p;
        else sum[val]=Min(sum[val],p);
    }
    for (i=head[x];i;i=e[i].next)
        if (!v[e[i].t]&&e[i].t!=fa)
            Dfs(e[i].t,x,val+e[i].l,p+1,op);
}

inline void Solve(int x){
    int i=0;
    v[x]=1; dp++;
    dfn[0]=dp;  sum[0]=0;
    for (i=head[x];i;i=e[i].next){
        if (v[e[i].t])  continue;
        Dfs(e[i].t,x,e[i].l,1,1);
        Dfs(e[i].t,x,e[i].l,1,0);
    }
    for (i=head[x];i;i=e[i].next){
        if (v[e[i].t])  continue;
        tot=size[e[i].t];   root=0;
        FindRoot(e[i].t,0); Solve(root);
    }
}

int main(){
    n=read();   k=read();
    for (i=1;i<n;i++){
        x=read()+1; y=read()+1; z=read();
        Add(x,y,z); Add(y,x,z);
    }
    memset(sum,INF,sizeof(sum));
    msize[root=0]=INF;  tot=n;  Ans=INF;
    FindRoot(1,0);  Solve(root);
    printf("%d\n",(Ans==INF)?-1:Ans);
    return 0;
}

例题[POJ1741]树上的点对

给一棵有n个节点的树,求有多少个点对间的距离小于k.

题目链接

和第一道例题相似的思路,每次分治完毕后,先整体算一次答案,再减去非法的情况.计算时,需要统计所有链的长度,然后排序,头尾两个指针向中间跳的同时计算方案数.时间复杂度为

O(NlogNlogN)

说点什么
支持Markdown语法
好耶,沙发还空着ヾ(≧▽≦*)o
Loading...