Luogu P1600 天天爱跑步
### 树上差分
题目链接
树上问题
被观察到的条件有两个
lca前一半(包括lca) \(d[S_i]-d[x]=w[x]\)
\(d[i]\)表示节点深度
lca后一半 \(d[S_i]+d[x]-2*d[lca(S_i,T_i)]=w[x]\)
但是具体怎么实现这个公式??
实现 \(d[S_i]=w[x]+d[x]\)
可以转化为线段树合并模型
在\(S_i\)到\(lca\)的路径上增加\(d[S_i]\)的价值
最后求出各点的\(w[x]+d[x]\)的价值个数
但是好像有点开不下。
可以在每一个节点上建一个vector记录投放
然后一个全局数组c
在树上dfs的时候记录\(c[w[x]+d[x]]\)递归回来的时候和原来的做差就是答案。
代码如下:
#include<bits/stdc++.h>
#define mk make_pair
using namespace std;
const int maxn=300000;
int n,m,head[maxn],tot,w[maxn],lc[maxn],ans[maxn];
int id[maxn],d[maxn],fa[maxn],f[maxn],c1[maxn*2],c2[maxn*2];
struct node{
int nxt,to;
#define nxt(x) e[x].nxt
#define to(x) e[x].to
}e[maxn<<1];
inline void add(int from,int to){
to(++tot)=to;nxt(tot)=head[from];head[from]=tot;
}
inline int find(int x){return fa[x]==x ? fa[x] : fa[x]=find(fa[x]);}
vector<pair<int,int> > pt[maxn];
vector<int> a1[maxn],a2[maxn],b1[maxn],b2[maxn];
pair<int,int> di[maxn];
void tarjan(int x){
id[x]=1;
for(int i=head[x];i;i=nxt(i)){
if(id[to(i)]) continue;
d[to(i)]=d[x]+1;
tarjan(to(i));
fa[to(i)]=x;f[to(i)]=x;
}
for(unsigned int i=0;i<pt[x].size();i++){
int to=pt[x][i].first,vl=pt[x][i].second;
if(id[to]==2) lc[vl]=find(to);
}
id[x]=2;
}
void dfs(int x){
int val1=c1[d[x]+w[x]],val2=c2[w[x]-d[x]+n];
id[x]=1;
for(int i=head[x];i;i=nxt(i)){
if(id[to(i)]) continue;
dfs(to(i));
}
for(int i=0;i<a1[x].size();i++)
c1[a1[x][i]]++;
for(int i=0;i<b1[x].size();i++)
c1[b1[x][i]]--;
for(int i=0;i<a2[x].size();i++)
c2[n+a2[x][i]]++;
for(int i=0;i<b2[x].size();i++)
c2[n+b2[x][i]]--;
ans[x]=c1[d[x]+w[x]]+c2[w[x]-d[x]+n]-val1-val2;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++){
int x,y;scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
for(int i=1;i<=n;i++) fa[i]=i,scanf("%d",&w[i]);
for(int i=1;i<=m;i++){
int x,y;scanf("%d%d",&x,&y);
if(x==y) lc[i]=x;
pt[x].push_back(mk(y,i));pt[y].push_back(mk(x,i));
di[i].first=x;di[i].second=y;
}
d[1]=1;tarjan(1);
for(int i=1;i<=m;i++){
int x=di[i].first,y=di[i].second;
a1[x].push_back(d[x]);a2[y].push_back(d[x]-2*d[lc[i]]);
b1[f[lc[i]]].push_back(d[x]);b2[lc[i]].push_back(d[x]-2*d[lc[i]]);
}
memset(id,0,sizeof(id));
dfs(1);
for(int i=1;i<=n;i++) printf("%d ",ans[i]);
return 0;
}