Description
给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设$dep[i]$表示点i的深度,$lca(i,j)$表示i与j的最近公共祖先。
有q次询问,每次询问给出l,r,z,求$\sum\limits_{i=l}^{r}dep[lca(i,z)]$。
(即求在$[l,r]$区间内的每个节点i与z的最近公共祖先的深度之和)
$n,q<=50000$
Input
第一行2个整数n,q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l,r,z。
Output
输出q行,每行表示一个询问的答案。每个答案对201314取模输出
Sample Input
5 2
0
0
1
1
1 4 3
1 4 2
Sample Output
8
5
这题不算很难,但是看到正解的方法很有趣,就记录一下~
考虑暴力,每次询问把z到根的所有节点打上标记,枚举i的时候直接往根找第一个有标记的点,然后统计深度即可。
然而复杂度明显是大于$O(n^2q)$的,显然不能接受,容易发现如果把深度看成点权,那么统计深度就相当于把z到根的每个节点权值都加一,然后枚举时统计根节点到i的权值和。使用树链剖分可以降到$O(nqlog^2n)$,但是还是会超时。
考虑进一步优化,发现询问可以差分,拆成$[1,l-1]$和$[1,r]$两个询问,进一步可以发现对答案有贡献的点只会在$lca(i,z)$以上,因此把每个i到根路径上的结点权值加一,再统计z到根节点的权值和,得出的答案是相同的。再结合差分,每次直接将$[1,l-1]$或$[1,r]$中所有节点到根节点路径上的点权值加一,然后统计z到根节点的权值和,按照dfs序区间修改+区间查询,用树链剖分加线段树可以做到$O(qlog^2n)$,用LCT可以做到$O(qlogn)$
注意long long
代码:
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define mod 201314
using namespace std;
typedef long long ll;
struct edge{
int v,next;
}a[];
struct task{
int r,z,id,ok;
}q[];
int n,qq,u,v,z,tmp=,qqq=,tot=,tim=,head[],son[],siz[],fa[],dep[],dfn[],top[];
ll t[],laz[],tv[],ans[],anss;
bool cmp(task a,task b){
return a.r<b.r;
}
void add(int u,int v){
a[++tot].v=v;
a[tot].next=head[u];
head[u]=tot;
}
void dfs1(int u,int f,int dpt){
dep[u]=dpt;
fa[u]=f;
siz[u]=;
for(int tmp=head[u];tmp!=-;tmp=a[tmp].next){
int v=a[tmp].v;
if(!dep[v]){
dfs1(v,u,dpt+);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]]||son[u]==-)son[u]=v;
}
}
}
void dfs2(int u,int tp){
dfn[u]=++tim;
top[u]=tp;
if(son[u]!=-)dfs2(son[u],tp);
for(int tmp=head[u];tmp!=-;tmp=a[tmp].next){
int v=a[tmp].v;
if(v!=son[u])dfs2(v,v);
}
}
void pushup(int u){
t[u]=t[u*]+t[u*+];
}
void pushdown(int u){
if(laz[u]){
laz[u*]+=laz[u];
laz[u*+]+=laz[u];
t[u*]+=tv[u*]*laz[u];
t[u*+]+=tv[u*+]*laz[u];
laz[u]=;
}
}
void build(int l,int r,int u){
if(l==r){
tv[u]=;
return;
}
int mid=(l+r)/;
build(l,mid,u*);
build(mid+,r,u*+);
tv[u]=tv[u*]+tv[u*+];
}
void updata(int l,int r,int u,int L,int R,int v){
if(L<=l&&r<=R){
t[u]+=(ll)tv[u]*v;
laz[u]+=v;
return;
}
int mid=(l+r)/;
pushdown(u);
if(L<=mid)updata(l,mid,u*,L,R,v);
if(mid<R)updata(mid+,r,u*+,L,R,v);
pushup(u);
}
ll query(int l,int r,int u,int L,int R){
if(L<=l&&r<=R){
return t[u];
}
int mid=(l+r)/,ans=;
pushdown(u);
if(L<=mid)ans+=query(l,mid,u*,L,R);
if(R>mid)ans+=query(mid+,r,u*+,L,R);
return ans;
}
void work(int u){
while(u){
int v=top[u];
updata(,n,,dfn[v],dfn[u],);
u=fa[v];
}
}
ll _work(int u){
ll ans=;
while(u){
int v=top[u];
ans+=query(,n,,dfn[v],dfn[u]);
u=fa[v];
}
return ans;
}
int main(){
memset(son,-,sizeof(son));
memset(head,-,sizeof(head));
scanf("%d%d",&n,&qq);
for(int i=;i<n;i++){
scanf("%d",&u);
add(u+,i+);
}
dfs1(,,);
dfs2(,);
build(,n,);
for(int i=;i<=qq;i++){
scanf("%d%d%d",&u,&v,&z);
u++,v++,z++;
q[++qqq]=(task){u-,z,i,-};
q[++qqq]=(task){v,z,i,};
}
sort(q+,q+qqq+,cmp);
for(int i=;i<=qqq;i++){
while(tmp<q[i].r){
work(++tmp);
}
ans[q[i].id]+=(ll)q[i].ok*_work(q[i].z);
}
for(int i=;i<=qq;i++){
printf("%lld\n",ans[i]%mod);
}
return ;
}