树分治———点分治

点分治还是很有意思的,不过自己看博客还是比较吃力的,看了好久。
基本思想:
找一颗树上的重心(所有以此点为根的子树中最大子树最小的那个根结点)
分治解决。(这是关键。)
详细解释一下:一开始结点数为n,朴素算法O($n^2$),若找重心分治处理就是O($log(n)$)层,每次合并时间是(这里是O(nlog(n)+n) 排序+找最远可行对)不定。
那最终就是O(nlog(n)log(n)).
最终的决定权基本在合并算法的手里,能合并,就能处理。

需要注意的是

       ans-=solve(v,w);

要减掉没有经过根的点

#include<iostream>
#include<stdio.h>
#include<string>
#include<cstring>
#include<vector>
#include<stack>
#include<algorithm>
#include<queue>
using namespace std;
#define clr(a, x) memset(a, x, sizeof(a))
#define mp(x, y) make_pair(x, y)
#define pb(x) push_back(x)
#define X first
#define Y second
#define fastin                    \
    ios_base::sync_with_stdio(0); \
    cin.tie(0);
typedef long long LL;
typedef pair<int, int> PII;
typedef vector<int> VI;
const int INF=0x3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-6;
const int N=10010;
const int M=40010;
int head[N];
int top;
struct Edge
{
    int val;
    int to;
    int next;
}edge[M];
void addedge(int a,int b,int c)
{
    edge[top]=Edge{c,b,head[a]};
    head[a]=top++;
}
void init()
{
    top=0;
    memset(head,-1,sizeof(head));
}
int k;
int root,sim[N],S,mxson[N];
int MX,vis[N];
void getroot(int u,int fa){
    sim[u]=1;mxson[u]=0;
    for(int i=head[u];~i;i=edge[i].next){
        int v=edge[i].to;
        int w=edge[i].val;
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        sim[u]+=sim[v];
        mxson[u]=max(mxson[u],sim[v]);
    }
    mxson[u] = max(mxson[u],S-sim[u]);
    if(mxson[u]<MX){
        MX=mxson[u];
        root=u;
    }
}
LL ans;
int id;
int dis[N]; //到重心的距离
void getdis(int u,int fa,int dist) //点u到root的距离
{
    dis[id++] = dist;
    for(int i=head[u];~i;i=edge[i].next){
        int v=edge[i].to;int w=edge[i].val;
        if(v==fa||vis[v]) continue;
        getdis(v,u,dist+w);
    }
}
int solve(int u,int len){ //排序找符合的点
    id=0;
    clr(dis,0);
    getdis(u,0,len);
    sort(dis,dis+id);
    int L=0,R=id-1;
    int tmp=0;
    while(L<R){
        if(dis[R]+dis[L]<=k){tmp+=R-L;L++;}
        else R--;
    }
    return tmp;
}
void Divide(int u){
    ans+=solve(u,0);
    vis[u]=true;
    for(int i=head[u];~i;i=edge[i].next){
        int v=edge[i].to;int w=edge[i].val;
        if(vis[v]) continue;
        ans-=solve(v,w);
        S=sim[v];root=0;
        MX=INF;getroot(v,0);
        Divide(root);
    }
}
int main()
{
    int n;
    while(scanf("%d%d",&n,&k),n+k){
        init();
        int a,b,c;
        for(int i=1;i<n;i++){
            scanf("%d%d%d",&a,&b,&c);
            addedge(a,b,c);
            addedge(b,a,c);
        }
        clr(vis,0);
        S=n,MX=INF;
        root=0;ans=0;
        getroot(1,0);
        Divide(root);
        printf("%lld\n",ans);
    }
    return 0;
}
添加新评论