[WC2014]紫荆花之恋 解题报告 (动态点分治)

[WC2014]紫荆花之恋

题意

一棵带权树, 每个点有一个属性 \(r_i\),

两个节点 \(i,j\) 成为朋友当且仅当 \(dis(i,j) \le r_i + r_j\).

树上原来没有节点, 每向树上加入一个节点后, 求树上总共有多少对朋友, 强制在线.

思路

考虑树的形态固定时应该怎么做.

考虑点分治.

对于每个重心 \(u\) , 设 \(i,j\) 在它的管理范围内, 且在两棵不同的子树内,

\(dis(i,j) \le r_i + r_j\) 可变为 \(dis(i,u) + dis(j,u) \le r_i + r_j\),

即求满足 \(dis(j,u)-r_j \le r_i-dis(i,u)\) 的点对个数.

按照点分治的套路, 设 \(ft[u]\) 为点 \(u\) 在点分树上的父亲, 对每个点维护两个 \(vector\)

\(v1[u]\) : \(u\) 的管理范围内的节点, 按照到 \(u\) 的距离 (从小到大) 排序.

\(v2[u]\) : \(u\) 的管理范围内的节点, 按照到 \(ft[u]\) 的距离 (从小到大) 排序.

然后直接点分治就行了.

现在, 树的形态是不固定的.

如果没有强制在线的话, 可以先把树建出来后建好点分树, 按照一般的动态点分治来写就好了.


动态点分治不能处理 树形态发生改变 的情况, 是因为树的形态改变会导致重心的改变, 那么点分树 树高为 \(\log n\) 的优美性质就不存在了, 从而会导致复杂度 (时间和空间) 的退化.


那为了防止复杂度的退化, 我们可以 动态维护重心, 也就是当原来的重心和现在的重心偏离到一定规模时, 重新做一遍点分治, 重构点分树.


那么, 我们考虑一下, 为了进行以上的操作, 需要维护什么东西.

首先, 由于树上的点是一个一个插入的, 而 \(vector\) 不支持插入操作, 所以 \(vector\) 必须被换掉, 换成一个支持排序, 插入的数据结构 (删除的话其实不是刚需, 因为重构点分树时直接清空就好了).

这里我们选择用 平衡树. ( \(set\) 本质上就是一个平衡树 (红黑树), 并且常数比较大, 所以这里我们不用它).

光有平衡树还不够, 因为当重构点分树时, 我们不可能每次都对点分树进行重构, 这样时间复杂度会达到 \(O(n^2 \log n)\), 所以我们每次只能重构点分树的一个子树, 那么我们就还需要记录 每个点在点分树中的子孙, 这个的话用个 \(vector\) 就行了.


现在还剩下一个问题, 如何判断什么时候需要重构点分树?

我们可以对每个点记录一个 \(sz\), 表示它的点分树的子树大小,

再定义一个 $\alpha \in (0.5,1) $, 当存在点 \(u\) 满足 \(sz[u] > \alpha \times sz[ft[u]]\) 时, 就重构 \(ft[u]\) 的子树.

那么每次我们找到最高的满足上述条件的点, 然后重构它的子树就好了, \(\alpha\) 的话取 \([0.75,0.85]\) 左右比较好.


至于平衡树, 一开始我是写的 \(treap\), 但是太慢了, \(T\) 4个点 (听说指针板的 \(treap\) 更快, 但我不会),

后来换成了替罪羊树就过了.

代码

#include<bits/stdc++.h>
#define ll long long
#define db double
#define pb push_back
#define sz size
#define rfs(x) sz[(x)]=sz[ls[(x)]]+sz[rs[(x)]]+num[(x)]
using namespace std;
const int _=1e5+7;
const int __=1e7+7;
const int L=17;
const int inf=0x3f3f3f3f;
const ll mod=1e9;
bool be;
int ty,n,f[_][L+7],dep[_],dis[_],ft[_],q[_],r[_];   // 点分治, lca
ll ans;
int rt1[_],rt2[_],ls[__],rs[__],val[__],sz[__],num[__],stk[__],cnt;  // 1: to self   2: to father   平衡树
int ddis[_],ssz[_],rt,minx;   // 点分治
int lst[_],nxt[2*_],to[2*_],len[2*_],tot;   // 建边
bool exi[_],vis[_];    // 点分治
vector<int> vec[_];
db ar=0.76,ra=0.76;
bool en;
int gi(){
  char c=getchar(); int x=0;
  while(c<'0'||c>'9') c=getchar();
  while(c>='0'&&c<='9'){ x=(x<<3)+(x<<1)+c-'0'; c=getchar(); }
  return x;
}
void pu(ll x){
  if(!x){ putchar('0'); return; }
  ll y=1;
  while(y<=x) y=(y<<3)+(y<<1);
  while(y>1){ y/=10; putchar(x/y+'0'); x%=y; }
}
void add(int x,int y,int w){ nxt[++tot]=lst[x]; to[tot]=y; len[tot]=w; lst[x]=tot; }
int Lca(int x,int y){
  if(dep[x]<dep[y]) swap(x,y);
  for(int i=L;i>=0;i--)
    if(dep[f[x][i]]>=dep[y])
      x=f[x][i];
  if(x==y) return x;
  for(int i=L;i>=0;i--)
    if(f[x][i]!=f[y][i]){
      x=f[x][i];
      y=f[y][i];
    }
  return f[x][0];
}
int dist(int x,int y){ return dis[x]+dis[y]-2*dis[Lca(x,y)]; }
int seq[__],p;
void get(int u){
  if(ls[u]) get(ls[u]);
  seq[++p]=u;
  if(rs[u]) get(rs[u]);
}
void arrange(int &u,int l,int r){
  if(l>r){ u=0; return; }
  int mid=(l+r)>>1;
  u=seq[mid];
  arrange(ls[u],l,mid-1);
  arrange(rs[u],mid+1,r);
  sz[u]=sz[ls[u]]+sz[rs[u]]+num[u];
}
void rebuild(int &u){
  p=0;
  get(u);
  arrange(u,1,p);
}
void insert(int &u,int w,bool fl){
  if(u){
    sz[u]++;
    bool flag=0;
    if(val[u]==w){ num[u]++; exi[u]=1; }
    else if(val[u]>w){
      if(!fl&&(db)(sz[ls[u]]+1)>sz[u]*ra) flag=1;
      insert(ls[u],w,fl|flag);
    }
    else{
      if(!fl&&(db)(sz[rs[u]]+1)>sz[u]*ra) flag=1;
      insert(rs[u],w,fl|flag);
    }
    if(flag) rebuild(u);
  }
  else{
    u=stk[cnt--];
    val[u]=w;
    sz[u]=num[u]=1;
    ls[u]=rs[u]=0;
  }
}
void clear(int &u){
  if(ls[u]) clear(ls[u]);
  if(rs[u]) clear(rs[u]);
  stk[++cnt]=u;
  u=0;
}
int query(int u,int w){
  int res=0;
  while(u){
    if(val[u]==w){ res+=sz[ls[u]]+num[u]; break; }
    else if(val[u]>w) u=ls[u];
    else{ res+=sz[ls[u]]+num[u]; u=rs[u]; }
  }
  return res;
}
int run(int x,int &t){
  int len,u=x,fa=ft[u],res=0;
  insert(rt1[x],-r[x],0);
  vec[x].pb(x);
  while(fa){
    len=dist(fa,x);
    res+=query(rt1[fa],r[x]-len)-query(rt2[u],r[x]-len);
    insert(rt1[fa],len-r[x],0);
    insert(rt2[u],len-r[x],0);
    vec[fa].pb(x);
    if((db)vec[u].sz()>vec[fa].sz()*ar) t=fa;
    u=ft[u]; fa=ft[u];
  }
  return res;
}
int que[_],t;
void g_rt(int u,int fa,int sum){
  //printf("   u: %d\n",u);
  int maxn=0; que[++t]=u;
  ssz[u]=1;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(exi[v]&&!vis[v]&&v!=fa){
      ddis[v]=ddis[u]+len[i];
      g_rt(v,u,sum);
      ssz[u]+=ssz[v];
      maxn=max(maxn,ssz[v]);
    }
  }
  maxn=max(maxn,sum-ssz[u]);
  if(maxn<minx){ minx=maxn; rt=u; }
}
void count(int u,int fa){
  que[++t]=u;
  for(int i=lst[u];i;i=nxt[i])
    if(!vis[to[i]]&&exi[to[i]]&&to[i]!=fa){
      ddis[to[i]]=ddis[u]+len[i];
      count(to[i],u);
    }
}
void calc(int u){
  t=0;
  for(int i=lst[u];i;i=nxt[i])
    if(!vis[to[i]]&&exi[to[i]]){
      ddis[to[i]]=len[i];
      count(to[i],0);
    }
  for(int i=1;i<=t;i++){
    insert(rt1[u],ddis[que[i]]-r[que[i]],0);    // 区分 que 和 q
    vec[u].pb(que[i]);
  }
  insert(rt1[u],-r[u],0);    // 记得把自己加上去
  vec[u].pb(u);
}
void pdiv(int u,int lrt,int sum){
  minx=inf; t=0;
  g_rt(u,0,ssz[u]<ssz[lrt] ?ssz[u] :sum-ssz[lrt]);
  if(lrt){
    if(ddis[u]>=0){
      for(int i=1;i<=t;i++){
    insert(rt2[rt],ddis[que[i]]-r[que[i]],0);
      }
    }
    else{
      for(int i=1;i<=t;i++){
    insert(rt2[rt],dist(que[i],lrt)-r[que[i]],0);
      }
    }
  }
  sum=ssz[u]; ft[rt]=lrt; vis[rt]=1; u=rt;
  for(int i=lst[u];i;i=nxt[i]){
    //printf("%d\n",u);
    if(!vis[to[i]]&&exi[to[i]]){
      ddis[to[i]]=len[i];
      pdiv(to[i],u,sum);
    }
  }
  calc(u);
  vis[u]=0;
}
void adj(int u,int x){
  if(!x) return;
  int top=0;
  for(int i=0;i<(int)vec[x].sz();i++) q[++top]=vec[x][i];
  for(int i=1;i<=top;i++){                           // 清空
    exi[q[i]]=1;
    if(rt1[q[i]]) clear(rt1[q[i]]);
    if(rt2[q[i]]) clear(rt2[q[i]]);
    vector<int>().swap(vec[q[i]]);

  }
  ddis[x]=-1;
  ssz[ft[x]]=0;
  pdiv(x,ft[x],top);
  //puts("!!!");
  for(int i=1;i<=top;i++) ssz[q[i]]=ddis[q[i]]=exi[q[i]]=0;
}
int main(){
  //freopen("x.in","r",stdin);
  //freopen("x.out","w",stdout);
  cin>>ty>>n; int c,x; ll a;
  clock_t st=clock();
  for(int i=1;i<=__-7;i++) stk[i]=i; cnt=__-7;
  for(int i=1;i<=n;i++){
    a=(ll)gi(); c=gi(); r[i]=gi();
    x=0;
    a^=ans%mod;
    f[i][0]=ft[i]=(int)a;
    for(int k=1;k<=L;k++)
      f[i][k]=f[f[i][k-1]][k-1];
    dep[i]=dep[ft[i]]+1;
    dis[i]=dis[ft[i]]+c;
    add(i,ft[i],c);
    add(ft[i],i,c);
    ans+=run(i,x);
    pu(ans); putchar('\n');
    adj(i,x);
  }
  //printf("\nused time: %.2lfs\n",(db)(clock()-st)/CLOCKS_PER_SEC);
  //printf("\nused space: %.2lfMB\n",(&en-&be)/(1<<20)*1.0);
  return 0;
}
12-14 03:43