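和之前那道【LNOI】LCA 几乎是同一道题，只需用动态树（LCT）来维护差分即可：给每个点 $v$ 赋值 $\mathrm{dep}(v)^K-(\mathrm{dep}(v)-1)^K$，它沿根到 $v$ 的路径求和恰好是 $\mathrm{dep}(v)^K$；离线按 $x$ 从小到大，每加入一个点 $i$ 就给根到 $i$ 的路径整体“加一次”，再查询根到 $y$ 的路径和即为答案。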

code:

#include <bits/stdc++.h>
using namespace std;
// Array capacity for vertices and query ids (assumes n and Q are both below N).
#define N 50006
// Modulus used for all arithmetic.
#define mod 998244353
#define ll long long
// Splay children of node x. These expand in terms of a variable literally
// named x, so they only work inside functions whose node variable is x.
#define lson t[x].ch[0]
#define rson t[x].ch[1]
// 1 if x is the right child of its father, 0 if it is the left child.
#define get(x) (t[t[x].f].ch[1]==x)
// True when x is the root of its splay tree: its father pointer is 0 or a
// path-parent link (the father does not list x among its children).
#define isrt(x) (!(t[t[x].f].ch[0]==x||t[t[x].f].ch[1]==x))
// Local-testing helper: redirects stdin to the file "<s>.in".
#define setIO(s) freopen(s".in","r",stdin)
// sta: scratch stack used by splay; hd/to/nex: adjacency lists of the input
// tree; answer: result per query id; dep: vertex depth (root has depth 1);
// n: vertex count, Q: query count, K: exponent; edges: edge counter for add().
int sta[N],hd[N],to[N],nex[N],answer[N],dep[N],n,Q,K,edges;
inline void add(int u,int v)
{
    // Prepends a directed edge u -> v to u's adjacency list
    // (in main, u is the parent and v the child).
    ++edges;
    to[edges]=v;
    nex[edges]=hd[u];
    hd[u]=edges;
}
inline int qpow(int x,int y)
{
    // Binary exponentiation: returns x^y reduced modulo mod.
    int result=1;
    int base=x;
    while(y!=0)
    {
        if(y&1) result=(ll)result*base%mod;
        base=(ll)base*base%mod;
        y>>=1;
    }
    return result;
}
struct sol
{
    // One offline query: y is its second vertex, id its position in the
    // input order (the slot in answer[] that receives its result).
    int y;
    int id;
    sol(int qy=0,int qid=0)
        : y(qy), id(qid)
    {
    }
};
// a[x]: queries whose first argument is x, bucketed so main can answer them
// right after vertices 1..x have been inserted.
vector<sol>a[N];
struct node
{
    // f: splay father or path-parent link (0 if none).
    int f;
    // Reversal flag; not read or written by the code shown here.
    int rev;
    // Splay children (0 means empty).
    int ch[2];
    // Pending lazy count that pushdown still has to hand to the children.
    int add;
    // Sums of val1 / val2 over this node's splay subtree (mod mod).
    ll sum1,sum2;
    // val1: per-vertex difference dep^K - (dep-1)^K;
    // val2: accumulates d*val1 for every mark of d covering this vertex.
    ll val1,val2;
}t[N];
inline void pushup(int x)
{
t[x].sum1=(t[lson].sum1+t[rson].sum1+t[x].val1)%mod;
t[x].sum2=(t[lson].sum2+t[rson].sum2+t[x].val2)%mod;
}
// Rotates x one level up inside its splay tree, so that its former father
// `old` becomes a child of x; sums of old and x are recomputed afterwards.
// Does not push down tags itself: splay() pushes them before rotating.
inline void rotate(int x)
{
// old: x's father; fold: old's father (a splay father or a path-parent);
// which: 1 if x is old's right child, 0 if it is the left child.
int old=t[x].f,fold=t[old].f,which=get(x);
// Only rewrite fold's child pointer when old is not a splay root. Otherwise
// fold is a path-parent whose children must not change; x just inherits the
// path-parent link through t[x].f=fold below.
if(!isrt(old)) t[fold].ch[t[fold].ch[1]==old]=x;
// x's inner subtree moves under old.
// NOTE(review): when that subtree is empty this also writes t[0].f; it appears
// harmless since pushup only reads node 0's sum fields.
t[old].ch[which]=t[x].ch[which^1],t[t[old].ch[which]].f=old;
t[x].ch[which^1]=old,t[old].f=x,t[x].f=fold;
// Bottom-up: old is now below x, so recompute it first.
pushup(old),pushup(x);
}
inline void mark(int x,int d)
{
(t[x].val2+=1ll*d*t[x].val1%mod)%=mod;
(t[x].sum2+=1ll*d*t[x].sum1%mod)%=mod;
t[x].add+=d;
}
inline void pushdown(int x)
{
    // Hands x's pending lazy count to its non-empty splay children and clears
    // it. Does nothing for the null node 0 or when nothing is pending.
    if(!x||!t[x].add) return;
    const int tag=t[x].add;
    for(int k=0;k<2;++k)
    {
        if(t[x].ch[k]) mark(t[x].ch[k],tag);
    }
    t[x].add=0;
}
// Splays x to the root of the splay tree that contains it. Pending lazy tags
// on the path from that root down to x are pushed first, so the rotations
// work with up-to-date values.
void splay(int x)
{
int v=0,u=x,fa;
// Collect x and all its splay ancestors on sta (x first, splay root last).
for(sta[++v]=u;!isrt(u);u=t[u].f) sta[++v]=t[u].f;
// Push tags down from the splay root toward x.
for(;v;--v) pushdown(sta[v]);
// u becomes the splay root's father (a path-parent or 0). Rotate x until its
// father is u, i.e. until x is the splay root.
for(u=t[u].f;(fa=t[x].f)!=u;rotate(x))
{
// Two levels remain: zig-zig (same side) rotates fa first, zig-zag rotates
// x first; the loop increment then rotates x once more.
if(t[fa].f!=u)
{
rotate(get(fa)==get(x)?fa:x);
}
}
}
void Access(int x)
{
    // Turns the tree path from the root to x into a single splay tree,
    // detaching whatever previously hung below x on its preferred path.
    int below=0;
    while(x)
    {
        splay(x);
        t[x].ch[1]=below;   // the part built so far becomes the deeper side
        pushup(x);
        below=x;
        x=t[x].f;           // follow the path-parent link upward
    }
}
void dfs(int u)
{
dep[u]=dep[t[u].f]+1;
t[u].val1=(qpow(dep[u],K)-qpow(dep[u]-1,K)+mod)%mod;
for(int i=hd[u];i;i=nex[i]) dfs(to[i]);
pushup(u);
}
int main()
{
// setIO("input");
int i,j;
scanf("%d%d%d",&n,&Q,&K);
for(i=2;i<=n;++i)
{
scanf("%d",&t[i].f),add(t[i].f,i);
}
dep[1]=1,dfs(1);
for(i=1;i<=Q;++i)
{
int x,y;
scanf("%d%d",&x,&y);
a[x].push_back(sol(y,i));
}
for(i=1;i<=n;++i)
{
Access(i),splay(i),mark(i,1);
for(j=0;j<a[i].size();++j)
{
Access(a[i][j].y),splay(a[i][j].y);
answer[a[i][j].id]=t[a[i][j].y].sum2%mod;
}
}
for(i=1;i<=Q;++i) printf("%d\n",answer[i]);
return 0;
}

  

05-11 20:42