这道题差不多可以说是树链剖分的模板题了，用树状数组做差分，直接维护每个结点被路径覆盖的次数即可。

#include <bits/stdc++.h>

using namespace std;

// ---- Loop helper macros ----
// REP: i runs over [0, n).  rep: i runs over [a, b] ascending.
// dec: i runs from a down to b (inclusive).
#define REP(i,n)                for(int i(0); i <  (n); ++i)
#define rep(i,a,b) for(int i(a); i <= (b); ++i)
#define dec(i,a,b) for(int i(a); i >= (b); --i)
// Walks the adjacency list of node x: starts at edge index H[x] and follows
// the X[] links until the 0 terminator.
// Each #define must sit on its own line; a directive cannot share a line
// with another directive or with ordinary declarations.
#define for_edge(i,x) for(int i = H[x]; i; i = X[i])

// ---- Type / utility shorthands ----
#define LL long long
#define ULL unsigned long long
#define MP make_pair
#define PB push_back
#define FI first
#define SE second
// Parenthesized so the shift binds as one unit inside larger expressions.
#define INF (1 << 30)

// ---- Size limits ----
const int N = 300000 + 10;
const int M = 10000 + 10;
const int Q = 1000 + 10;
const int A = 30 + 1;

// ---- Adjacency list (edge indices start at 1; 0 terminates a list) ----
// E[i]: target node of edge i.  H[v]: most recently added edge of node v.
// X[i]: edge that follows edge i in the same node's list.
int E[N << 1], H[N << 1], X[N << 1];

// ---- Per-node / per-position state ----
int c[N];      // Fenwick (binary indexed) tree storage
int top[N];    // top node of the chain containing each node
int fa[N];     // parent of each node
int deep[N];   // depth of each node
int num[N];    // subtree size of each node
int son[N];    // child chosen to continue the node's chain (-1 when none)
int fp[N];     // inverse of p: node stored at a given position
int p[N];      // position assigned to each node by the chain ordering
int et, pos;   // number of edges stored so far / last position handed out
int a[N];      // input sequence of nodes
// n: value read first from the input; it also bounds Fenwick-tree updates.
// x, y: scratch variables used by main for pairs of nodes.
int n, x, y;

// Isolates the lowest set bit of x (0 when x == 0).
inline int lowbit(int x) {
    return x & -x;
}

// Fenwick-tree prefix query: sums c[] over the index chain obtained by
// repeatedly stripping the lowest set bit of x until it reaches 0.
inline int query(int x) {
    int total = 0;
    while (x != 0) {
        total += c[x];
        x -= lowbit(x);
    }
    return total;
}
inline void add(int x, int val){ for (; x <= n; x += lowbit(x)) c[x] += val;} inline void addedge(int a, int b){
E[++et] = b, X[et] = H[a], H[a] = et;
E[++et] = a, X[et] = H[b], H[b] = et;
} void dfs(int x, int pre){
deep[x] = deep[pre] + 1;
fa[x] = pre;
num[x] = 1;
for_edge(i, x){
int v = E[i];
if (v != pre){
dfs(v, x);
num[x] += num[v];
if (son[x] != -1 || num[v] > num[son[x]])
son[x] = v;
}
}
} void getpos(int x, int sp){
top[x] = sp;
p[x] = ++pos;
fp[p[x]] = x;
if (son[x] == -1) return;
getpos(son[x], sp);
for_edge(i, x){
int v = E[i];
if (v != son[x] && v != fa[x])
getpos(v, v);
}
} void cover(int u, int v, int val){
int f1 = top[u], f2 = top[v];
int tmp = 0;
while (f1 != f2){
if (deep[f1] < deep[f2]){
swap(f1, f2);
swap(u, v);
}
add(p[f1], val);
add(p[u] + 1, -val);
u = fa[f1];
f1 = top[u];
} if (deep[u] > deep[v]) swap(u, v);
add(p[u], val);
add(p[v] + 1, -val);
} int main(){
#ifndef ONLINE_JUDGE
freopen("test.txt", "r", stdin);
freopen("test.out", "w", stdout);
#endif scanf("%d", &n);
rep(i, 1, n) scanf("%d", a + i);
rep(i, 1, n - 1){
scanf("%d%d", &x, &y);
addedge(x, y);
} memset(son, -1, sizeof son);
dfs(1, 0);
getpos(1, 1);
rep(i, 1, n - 1){
x = a[i], y = a[i + 1];
cover(x, y, 1);
} rep(i, 1, n) if (i == a[1]) printf("%d\n", query(p[i]));
else printf("%d\n", query(p[i]) - 1); return 0; }
05-21 18:57