题意
给出一棵N 个节点的树,树上的每个节点都有一个权值$a_i$。
有Q 次询问,每次在树上选中两个点u, v,考虑所有在简单路径u, v 上(包括u, v)的点构成的集合S。
求$\sum_{w∈S}{a_w or dist(u,w)}$
其中dist(u,w) 为简单路径u,w 上的边数,or 是按位或。
思考
设f[i][j]表示从第i个节点开始,总共跳了$2^j$个节点得到的答案,g[i][j]表示从第i个节点开始,向下跳了$2^j$个节点得到的答案。这些状态的转移只要考虑后半部分最高位上1的贡献。
接下来对于查询u,v,我们算出其lca。对于u-lca,可以用倍增求出答案。由于倍增时每次的长度都至少会是以前的一半,那么某些二进制位上在以后一定都会是1,而且之前求出的f和它是独立的。对于lca-v的答案,可以先向上跳一些深度,再用g类似地求出答案。
代码
1 // luogu-judger-enable-o2 2 // luogu-judger-enable-o2 3 // luogu-judger-enable-o2 4 #include<bits/stdc++.h> 5 using namespace std; 6 typedef long long int ll; 7 const int maxn=6E5+5; 8 const int layer=22; 9 int n,T; 10 int size,head[maxn*2]; 11 int root,dep[maxn]; 12 ll val[maxn],fa[maxn][layer],cnt[maxn][layer],up[maxn][layer],down[maxn][layer]; 13 inline int read() 14 { 15 char ch=getchar(); 16 while(!isdigit(ch)) 17 ch=getchar(); 18 int sum=ch-'0'; 19 ch=getchar(); 20 while(isdigit(ch)) 21 { 22 sum=sum*10+ch-'0'; 23 ch=getchar(); 24 } 25 return sum; 26 } 27 void write(ll x) 28 { 29 if(x>=10) 30 write(x/10); 31 putchar('0'+x%10); 32 } 33 inline void writen(ll x) 34 { 35 write(x); 36 putchar('\n'); 37 } 38 struct edge 39 { 40 int to,next; 41 }E[maxn*2]; 42 inline void add(int u,int v) 43 { 44 E[++size].to=v; 45 E[size].next=head[u]; 46 head[u]=size; 47 } 48 void dfs(int u,int F,int d) 49 { 50 fa[u][0]=F; 51 dep[u]=d; 52 for(int i=1;i<layer;++i) 53 fa[u][i]=fa[fa[u][i-1]][i-1]; 54 for(int i=0;i<layer;++i) 55 cnt[u][i]=cnt[F][i]+((val[u]&(1<<i))>0); 56 up[u][0]=down[u][0]=val[u]; 57 for(int i=1;i<layer;++i) 58 { 59 up[u][i]=up[u][i-1]+up[fa[u][i-1]][i-1]+(ll)(1<<(i-1))*((1<<(i-1))-cnt[fa[u][i-1]][i-1]+cnt[fa[u][i]][i-1]); 60 down[u][i]=down[u][i-1]+down[fa[u][i-1]][i-1]+(ll)(1<<(i-1))*((1<<(i-1))-cnt[u][i-1]+cnt[fa[u][i-1]][i-1]); 61 } 62 for(int i=head[u];i;i=E[i].next) 63 { 64 int v=E[i].to; 65 if(v==F) 66 continue; 67 dfs(v,u,d+1); 68 } 69 } 70 inline int jump(int x,int d) 71 { 72 for(int i=0;i<layer;++i) 73 if(d&(1<<i)) 74 x=fa[x][i]; 75 return x; 76 } 77 inline int lca(int x,int y) 78 { 79 if(dep[x]<dep[y]) 80 swap(x,y); 81 int d=dep[x]-dep[y]; 82 x=jump(x,d); 83 if(x==y) 84 return x; 85 for(int i=layer-1;i>=0;--i) 86 if(fa[x][i]!=fa[y][i]) 87 x=fa[x][i],y=fa[y][i]; 88 return fa[x][0]; 89 } 90 inline ll get1(int x,int to) 91 { 92 if(x==to) 93 return 0; 94 ll sum=0; 95 for(int i=layer-1;i>=0;--i) 96 if(dep[fa[x][i]]>=dep[to]) 97 { 98 sum+=up[x][i]; 99 x=fa[x][i]; 100 sum+=(dep[x]-dep[to]-(cnt[x][i]-cnt[to][i]))*(ll)(1<<i); 101 } 102 return sum; 103 } 104 int wait[233],c[233],tot; 105 inline ll get2(int x,int to) 106 { 107 if(x==to) 108 return val[x]; 109 if(dep[x]>dep[to]) 110 return 0; 111 ll sum=0; 112 tot=0; 113 int from=to; 114 int d=dep[to]-dep[x]+1; 115 for(int i=0;i<layer;++i) 116 if(d&(1<<i)) 117 { 118 wait[++tot]=to; 119 c[tot]=i; 120 to=fa[to][i]; 121 } 122 for(int i=tot;i>=1;--i) 123 { 124 int u=wait[i]; 125 sum+=down[u][c[i]]; 126 if(c[i]) 127 sum+=(ll)(1<<(c[i]))*(dep[from]-dep[u]-(cnt[from][c[i]]-cnt[u][c[i]])); 128 } 129 return sum; 130 } 131 int main() 132 { 133 // freopen("C1.in","r",stdin); 134 // freopen("C.out","w",stdout); 135 ios::sync_with_stdio(false); 136 n=read(),T=read(); 137 for(int i=1;i<=n;++i) 138 val[i]=read(); 139 for(int i=2;i<=n;++i) 140 { 141 int x=read(),y=read(); 142 add(x,y); 143 add(y,x); 144 } 145 root=n*2; 146 add(n+1,1); 147 for(int i=n+2;i<=root;++i) 148 add(i,i-1); 149 dfs(root,root,0); 150 while(T--) 151 { 152 int x,y,z,d; 153 x=read(),y=read(); 154 z=lca(x,y); 155 d=dep[x]-dep[z]; 156 int q=jump(z,d); 157 // cout<<x<<" "<<z<<" "<<q<<endl; 158 writen(get1(x,z)+get2(q,y)-get2(q,fa[z][0])); 159 } 160 return 0; 161 }