没错...我就是要讲点分治。这个东西原本学过的,当时学得不好...今天模拟赛又考这个东西结果写不出来。

于是博主专门又去学了学这个东西,这次绝对要搞懂了...【复赛倒计时:11天】

——正片开始——

点分是一个用来解决树上路径问题、距离问题的算法。说直接点其实就是分治思想在树上的体现和应用。

首先是分治的方法,既然是树上问题,自然就是对树进行分治。

我们对于一棵树,选定一个点作为它的根,将这个点与其子树拆开,然后对它的子树也进行相同的操作,然后利用子树统计答案。一般来说,点分更像一种思想,而不是一个算法,当然你也可以把它当算法来学。

关于怎么选根来分树,其他dalao的博客已经讲得非常清楚仔细了,每次选定一棵树的重心,然后分它,这样可以做到

O(nlogn)的优秀时间复杂度。

关于求重心,做法就是一个size统计。

这里还是介绍一下吧(很多博客都只贴一个代码...):
对于一个节点x,若我们把其删除,原来的树就会变成若干部分,我们设maxsubtree(x)表示删除x后剩下的最大子树的大小,若我们找到一个x,使得maxsubtree(x)最小,x就是这棵树的重心。

给出求重心的代码:

void getroot(int x,int fa){
    siz[x]=1;son[x]=0;
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(y==fa||vis[y])continue;
        getroot(y,x);
        siz[x]+=siz[y];
        if(siz[y]>son[x])son[x]=siz[y];
    }
    if(Siz-siz[x]>son[x])son[x]=Siz-siz[x];
    if(son[x]<maxx)maxx=son[x],root=x;
}

于是我们就知道怎么拆树了,后面的东西就不难了。

update: 19.11.5 对于上面的代码加一些解释,Siz表示当前在求重心的这一棵树的大小,root用来记录重心

我们来讲如何统计信息

首先我们要统计与每一次找到的重心有关的路径,我们用dis[x]表示目标点(重心)到x的距离。

给出getdis的代码:

void getdis(int x,int fa,int d){//d表示x到目标点的距离
    dis[++top]=d;//我们只需要知道每一个点到目标点的距离就行,不用知道这个点是哪个
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(y==fa||vis[y])continue;
        getdis(y,x,d+val(i));
    }
}

有了dis数组后,我们就可以很轻松的获得路径的长度了。

比如,我们已知x到重心的距离是m,y到重心的距离是n,那x到y的距离就是m+n,可能细心的读者已经发现锅了。

如果x到y的路径都在一棵子树内,我们就会有一段距离被重复计算了,这样我们得到的路径就是不对的。

给一张图理解一下:
[点分治系列] 静态点分-LMLPHP如图,蓝色的代表x到重心的路径,红色的代表y到重心的路径,我们可以得到dis[x]=3,dis[y]=2。

如果按照前面说的计算方式,x到y的路径长度应该是5了,但是并不是,我们的路径长度是3。

原因就是,绿色的那一段,我们根本不走,我们不经过它。

于是我们要解决这种不合法路径。知道所有路径,又知道不合法路径,利用容斥原理,我们可以得到:

合法路径=所有路径 - 不合法路径。

这样我们divide的代码就出来了:

void divide(int x){
    vis[x]=1;//保证我们每次divide的x都是当前这棵树的重心,所以标记x已经divide过了
    solve(x,0,1);//计算这棵树以x为重心的所有路径答案
    int totsiz=Siz;//这棵树的大小
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(vis[y])continue;
        solve(y,val(i),0);//求不合法路径
        maxx=inf,root=0;//初始化
        Siz=siz[y]>siz[x]?totsiz-siz[x]:siz[y];//更新siz
        getroot(y,0);//求出以y为根的子树
        divide(root);//以y为根的子树分治下去
    }
}

我们看一看去除不合法路径的代码:

solve(y,val(i),-1);

思考发现:

所有不合法路径都是在同一棵子树中的路径,我们要减去它。

先往下看完solve再回来看这里。

我们进入到solve中,首先是getdis,以x为目标点获取dis。但是我们要获取的是距离为val(i)的dis。

这就使得dis[y]=val(i),所以以y为根的子树中的所有dis值就等于dis[y]+它们到y的距离,然后因为dis[y]是x到y的距离,

所以我们就求出了以y为根的子树中所有点到x的距离。

实际一点的理解就是,y到目标点的距离是dis[y]=val(i),y的子树中的点到x的距离就是它们到y的距离+dis[y],所以跑一遍可以求出y的子树中所有点到x的距离。

后就是solve函数了:

我们这里的solve以模板题:【模板】点分治1 为例,不同题目我们solve的东西不同。

首先肯定可以想到一个O(n^2)的做法的,确实可以水过去这一题。

但是我毕竟是在写博客...所以,O(nlogn)做法奉上...

我们这么想,我们需要统计路径长为k的点对个数,那我们只要确定了一个点x,另一个点的dis[y]就应该是k-dis[x],我们只需要二分找k-dis[x]就行了。

void solve(int x,int d,int avl){//avl(available)表示这次计算得到的答案是否合法 
    top=0;//清空dis数组 
    getdis(x,0,d);//获取到当前这棵树到x的距离为d的所有dis 
    int cnt=0;
    sort(dis+1,dis+top+1);//排好序准备二分
    dis[0]=-1;//第一个dis设置为奇怪的数方便下面比较 
    for(int i=1;i<=top;i++){//把所有距离相同的点放进一个桶里面方便操作 
        if(dis[i]==dis[i-1])
            bucket[cnt].amount++;//原来桶的个数+1 
        else
            bucket[++cnt].dis=dis[i],bucket[cnt].amount=1;//新开一个桶 
    }
    for(int i=1;i<=m;i++){
        if(query[i]%2==0)//如果k是偶数的话,我们单独考虑一下距离为k/2那些点,它们可以互相配对形成长为k的路径 
            for(int j=1;j<=cnt;j++)
                if(bucket[j].dis==query[i]/2)//如果距离是k/2 
                    ans[i]+=(bucket[j].amount-1)*bucket[j].amount/2*avl;
                    //组合计数,假设我们有x个距离为k/2的点,就有(x-1)*x/2个点对距离为k,也就是我们可以配出这么多个不同点对
                    //其实就是C(x,2)->x!/((x-2)!*2)->(x-1)*x/2
        for(int j=1;j<=cnt&&bucket[j].dis<query[i]/2;j++){
        //接着枚举<k/2的距离,然后我们二分找>2的距离配对,避免重复(点对(u,v)和(v,u)是等价的),等于k/2的我们前面算过了,所以所有情况都考虑到了 int l=j+1,r=cnt; while(l<=r){ int mid=(l++r)>>1; if(bucket[j].dis+bucket[mid].dis==query[i]){ ans[i]+=bucket[j].amount*bucket[mid].amount*avl; //组合计数记录答案,假设我们有x个距离为m的点,y个距离为k-m的点,我们就有x*y个不同的点对(分类相乘) break;//这一轮二分完了,下一轮 } if(bucket[j].dis+bucket[mid].dis>query[i])r=mid-1;//大了,往小的二分 else l=mid+1;//小了,往大的二分 } } } }

这么详细都看不懂我就教不了了...

接下来就给出所有代码吧...(我知道你们只想看这个/doge)

#include<bits/stdc++.h>
#define N 100010
#define lint long long
#define inf 0x7fffffff
using namespace std;
int vis[N],son[N],Siz,maxx,siz[N];
int root,head[N],tot,n,m,dis[N],top,query[N],ans[N];
lint k;
struct Bucket{
    int dis,amount;
}bucket[N];
struct Edge{
    int nxt,to,val;
    #define nxt(x) e[x].nxt
    #define to(x) e[x].to
    #define val(x) e[x].val
}e[N<<1];
inline int read(){
    int data=0,w=1;char ch=0;
    while(ch!='-' && (ch<'0'||ch>'9'))ch=getchar();
    if(ch=='-')w=-1,ch=getchar();
    while(ch>='0' && ch<='9')data=data*10+ch-'0',ch=getchar();
    return data*w;
}
inline void addedge(int f,int t,int val){
    nxt(++tot)=head[f];to(tot)=t;val(tot)=val;head[f]=tot;
}
void getroot(int x,int fa){
    siz[x]=1;son[x]=0;
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(y==fa||vis[y])continue;
        getroot(y,x);
        siz[x]+=siz[y];
        if(siz[y]>son[x])son[x]=siz[y];
    }
    if(Siz-siz[x]>son[x])son[x]=Siz-siz[x];
    if(son[x]<maxx)maxx=son[x],root=x;
}
void getdis(int x,int fa,int d){
    dis[++top]=d;
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(y==fa||vis[y])continue;
        getdis(y,x,d+val(i));
    }
}
void solve(int rt,int d,int avl){//avl(available)表示这次计算得到的答案是否合法 
    top=0;//清空dis数组 
    getdis(rt,0,d);//获取到当前这棵树的rt的距离为d的所有dis 
    int cnt=0;
    sort(dis+1,dis+top+1);//排好序准备二分
    dis[0]=-1;//第一个dis设置为奇怪的数方便下面比较 
    for(int i=1;i<=top;i++){//把所有距离相同的点放进一个桶里面方便操作 
        if(dis[i]==dis[i-1])
            bucket[cnt].amount++;//原来桶的个数+1 
        else
            bucket[++cnt].dis=dis[i],bucket[cnt].amount=1;//新开一个桶 
    }
    for(int i=1;i<=m;i++){
        if(query[i]%2==0)//如果k是偶数的话,我们单独考虑一下距离为k/2那些点,它们可以互相配对形成长为k的路径 
            for(int j=1;j<=cnt;j++)
                if(bucket[j].dis==query[i]/2)//如果距离是k/2 
                    ans[i]+=(bucket[j].amount-1)*bucket[j].amount/2*avl;
                    //组合计数,假设我们有x个距离为k/2的点,就有(x-1)*x/2个点对距离为k,也就是我们可以配出这么多个不同点对
                    //其实就是C(x,2)->x!/((x-2)!*2)->(x-1)*x/2
        for(int j=1;j<=cnt&&bucket[j].dis<query[i]/2;j++){//接着枚举<k/2的距离,然后我们二分找>2的距离配对 
            int l=j+1,r=cnt;
            while(l<=r){
                int mid=(l+r)>>1;
                if(bucket[j].dis+bucket[mid].dis==query[i]){
                    ans[i]+=bucket[j].amount*bucket[mid].amount*avl;
                    //组合计数记录答案,假设我们有x个距离为m的点,y个距离为k-m的点,我们就有x*y个不同的点对(分类相乘)
                    break;//这一轮二分完了,下一轮 
                }
                if(bucket[j].dis+bucket[mid].dis>query[i])r=mid-1;//大了,往小的二分
                else l=mid+1;//小了,往大的二分
            }
        }
    }
}
void divide(int x){
    vis[x]=1;
    solve(x,0,1);//合法的算进去 
    int totsiz=Siz;
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(vis[y])continue;
        solve(y,val(i),-1);//不合法的算出来减掉 
        maxx=inf,root=0;
        Siz=siz[y]>siz[x]?totsiz-siz[x]:siz[y];
        getroot(y,0);
        divide(root);
    }
}
int main(){
    n=read();m=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read(),z=read();
        addedge(x,y,z);addedge(y,x,z);
    }
    for(int i=1;i<=m;i++)query[i]=read();
    maxx=inf;root=0;Siz=n;
    getroot(1,0);
    divide(root);
    for(int i=1;i<=m;i++){
        if(ans[i]>0)puts("AYE");
        else puts("NAY");
    }
    return 0;
}

这份代码不仅求出了是否有路径长为k的答案存在,还求出了路径长为k的点对个数。

不需要求个数的话你就改一改solve,这样常数会小一点,跑得更快,但是这一题我更想给你们讲讲思路,学会举一反三...最后还是那句话,我个人并不认为点分是一种算法,它更多体现的是分治的思想在树上的应用。

其实你会发现,很多高端的数据结构、算法之类的,都是由基础的算法思想衍生出来的泛用性更强的东西。

懂了基础的算法思想,你不但可以轻松的学会各种高阶算法,甚至可以自己造出解题的算法。

比如图论里面的dijkstra最短路算法,不就是贪心吗?再比如线段树,不就是分治吗?

一些看起来很高级,听起来很难的东西,只要你弄明白了其中的本质,你在感叹发明者的智慧同时,自己也就收获到了其中蕴含的知识和基础思想的应用方法。学习不是死学...只有弄明白了它的工作原理和方式,你才算是掌握了它,仅仅是可以用它来做题,那题目变变形,你就一脸懵了。

11-06 12:51