给出一颗n个结点的树,点上有权

求点对(x,y)的数量

其中 x!=y,x到y的路径上最大值与最小值的差<=D

按最小值排序,用最大值二分最小值比他小的所有点,容斥一下,最后答案*2即可

#include<bits/stdc++.h>
#define ll long long
#define rep(ii,a,b) for(int ii=a;ii<=b;++ii)
#define per(ii,a,b) for(int ii=b;ii>=a;--ii)
#define IO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
#define fi first
#define se second
#define mp make_pair
#define pii pair<int,int>
using namespace std;//head
const int maxn=1e5+10,maxm=2e6+10;
const ll INF=0x3f3f3f3f,mod=1e9+7;
int casn,n,m,k;
int val[maxn];
namespace graph{
  struct node{int to,next;}e[maxn<<1];
  int head[maxn],nume,all,root,maxt,sz[maxn];
  bool vis[maxn];
  void add(int a,int b){
    e[++nume]={b,head[a]};head[a]=nume;
  }
  void init(int n){
    rep(i,0,n) vis[i]=head[i]=0;root=nume=1;
  }
  void getroot(int now,int fa){
    sz[now]=1;
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(to==fa||vis[to]) continue;
      getroot(to,now);
      sz[now]+=sz[to];
    }
    int tmp=max(sz[now]-1,all-sz[now]);
    if(maxt>tmp) maxt=tmp,root=now;
  }//@基础部分@
  pii dis[maxn];
  int dfn;
  ll ans;
  void dfs(int now,int fa,int mx,int mn){
    mx=max(val[now],mx),mn=min(val[now],mn);
    if(mx-mn<=k)  dis[++dfn]=mp(mn,mx);
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(to!=fa&&!vis[to]) dfs(to,now,mx,mn);
    }
  }
  ll cal(int now,int mx,int mn){
    dfn=0;
    dfs(now,now,mx,mn);
    ll ret=0;
    sort(dis+1,dis+dfn+1);
    rep(i,1,dfn){
      int d=dis[i].se-k;
      int pos=lower_bound(dis+1,dis+i+1,make_pair(d,0))-dis;
      ret+=(i-pos);
    }
    return ret;
  }
  void getans(int now){
    vis[now]=1;
    ans+=cal(now,val[now],val[now]);
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(!vis[to]) ans-=cal(to,val[now],val[now]);
    }
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(vis[to]) continue;
      all=sz[to],maxt=n+1;
      getroot(to,now);getans(root);
    }
  }
  void solve(int n){
    maxt=all=n;
    ans=0;
    getroot(1,1);
    getans(root);
  }
}
using namespace graph;
int main() {IO;
  cin>>casn;
  while(casn--){
    cin>>n>>k;
    init(n);
    rep(i,1,n) cin>>val[i];
    rep(i,2,n){
      int a,b;
      cin>>a>>b;
      add(a,b);add(b,a);
    }
    solve(n);
    cout<<ans*2<<'\n';
  }
  return 0;
}
01-23 23:06