题目描述

您正在打galgame,然后突然家长进来了,于是您假装在写数据结构题:

给一个树,n 个点,有点权,初始根是 1。

m 个操作,每次操作:

1.将树根换为 x。

2.给出两个点 x,y,从 x 的子树中选每一个点,y 的子树中选每一个点,如果两个点点权相等,ans++,求 ans。

题解

  lxl的大毒瘤题名不虚传

  顺便先膜一下gxz大佬再说(毕竟像我这种菜鸡根本想不出这么巧的方法)->这里

  首先,如果没有换根的话,那么可以直接把子树当成dfs序上的一段区间来做,那么只要把询问给拆成好几个询问,然后直接用莫队就可以了

  然后考虑换根要怎么解决呢?可以参考【bzoj3083】遥远的国度,把子树变成原来的dfs序中的最多两段区间

  然后考虑块的大小,大佬说这题卡常,得把块的大小调成$\frac{n}{\sqrt{m}}*2$,否则会被卡,然而亲测不用乘2也能过……虽然跑得稍微慢了点

  顺带提一句,这里的莫队范围是$(1,x)$而不是$(l,r)$(因为是把询问写成了前缀和相减的形式),当初刚看到的时候没看懂还写错了……

 //minamoto
#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#define ll long long
using namespace std;
#define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[<<],*p1=buf,*p2=buf;
inline int read(){
#define num ch-'0'
char ch;bool flag=;int res;
while(!isdigit(ch=getc()))
(ch=='-')&&(flag=true);
for(res=num;isdigit(ch=getc());res=res*+num);
(flag)&&(res=-res);
#undef num
return res;
}
char sr[<<],z[];int C=-,Z;
inline void Ot(){fwrite(sr,,C+,stdout),C=-;}
inline void print(ll x){
if(C><<)Ot();if(x<)sr[++C]=,x=-x;
while(z[++Z]=x%+,x/=);
while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
const int N=;
int a[N],b[N],head[N],Next[N<<],ver[N<<],tot;
int dep[N],ls[N],rs[N],son[N],sz[N],top[N],fa[N],cnt;
int vx[],ox[],tx,vy[],oy[],ty,cx[N],cy[N],w[N],px,py;
ll ans[N*],now;
int n,m,rt,cq,num=;
struct node{
int l,r,rt,id,opt;
node(){}
node(int L,int R,int Id,int Opt){l=min(L,R),r=max(L,R),id=Id,opt=Opt;}
bool operator<(const node &a)const {return rt == a.rt ? r < a.r : l < a.l;}
//inline bool operator <(const node &b)const{return rt==b.rt?rt&1?r<b.r:r>b.r:rt<b.rt;}
}q[N*];
inline void add(int u,int v){
ver[++tot]=v,Next[tot]=head[u],head[u]=tot;
ver[++tot]=u,Next[tot]=head[v],head[v]=tot;
}
void dfs1(int u){
sz[u]=,dep[u]=dep[fa[u]]+;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(v!=fa[u]){
fa[v]=u,dfs1(v),sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
}
void dfs2(int u,int t){
top[u]=t,ls[u]=++cnt,w[cnt]=a[u];
if(son[u]){
dfs2(son[u],t);
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(v!=fa[u]&&v!=son[u]) dfs2(v,v);
}
}
rs[u]=cnt;
}
inline int find(int t,int u){
while(top[u]!=top[t]){
if(fa[top[u]]==t) return top[u];
u=fa[top[u]];
}
return son[t];
}
int main(){
//freopen("testdata.in","r",stdin);
n=read(),m=read();
for(int i=;i<=n;++i) a[i]=b[i]=read();
sort(b+,b++n);
for(int i=;i<=n;++i) a[i]=lower_bound(b+,b++n,a[i])-b;
for(int i=;i<n;++i){
int u=read(),v=read();add(u,v);
}
dfs1(),dfs2(,);
for(int i=;i<=m;++i){
int op=read(),x=read(),y,z;
if(op&) rt=x;
else{
y=read(),++cq;
tx=ty=;
if(x==rt) vx[++tx]=n,ox[tx]=;
else if(ls[rt]<ls[x]||ls[rt]>rs[x]) vx[++tx]=rs[x],ox[tx]=,vx[++tx]=ls[x]-,ox[tx]=-;
else z=find(x,rt),vx[++tx]=n,ox[tx]=,vx[++tx]=rs[z],ox[tx]=-,vx[++tx]=ls[z]-,ox[tx]=;
if(y==rt) vy[++ty]=n,oy[ty]=;
else if(ls[rt]<ls[y]||ls[rt]>rs[y]) vy[++ty]=rs[y],oy[ty]=,vy[++ty]=ls[y]-,oy[ty]=-;
else z=find(y,rt),vy[++ty]=n,oy[ty]=,vy[++ty]=rs[z],oy[ty]=-,vy[++ty]=ls[z]-,oy[ty]=;
for(int j=;j<=tx;++j)
for(int k=;k<=ty;++k)
if(vx[j]&&vy[k])
q[++num]=node(vx[j],vy[k],cq,ox[j]*oy[k]);
}
}
int s=(n/sqrt(num))*+;
for(int i=;i<=num;++i) q[i].rt=(q[i].l-)/s;
sort(q+,q++num);
for(int i=;i<=num;++i){
while(px<q[i].l) now+=cy[w[++px]],++cx[w[px]];
while(py<q[i].r) now+=cx[w[++py]],++cy[w[py]];
while(px>q[i].l) --cx[w[px]],now-=cy[w[px--]];
while(py>q[i].r) --cy[w[py]],now-=cx[w[py--]];
ans[q[i].id]+=now*q[i].opt;
}
for(int i=;i<=cq;++i) print(ans[i]);
Ot();
return ;
}
05-18 07:55