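思路1 (Approach 1):
树上启发式合并 (DSU on tree / small-to-large merging: keep the heavy child's colour counts, discard each light child's counts, then re-add the light subtrees and the vertex itself)
代码 (Code):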
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
// Kept disabled: this solution uses `mp` as the name of its colour-count map.
//#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pii pair<int, int>
#define piii pair<int,pii>
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Redirects stdin/stdout to local files for offline testing (the original wrote to the
// nonexistent `stout`). NOTE: this macro shadows the C library function fopen for the
// rest of the file.
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
//head
// Capacity of all per-vertex arrays (vertices are 1-based); assumes n <= 1e5.
const int N = 1e5 + 5;
// color[u]: colour of vertex u; sz[u]: size of u's subtree; bs[u]: heavy child of u
// (0 when u has no child); mx: highest count currently in the counter.
int color[N], sz[N], bs[N], mx = 0;
// ans[u]: answer for u's subtree; t: sum of all colours whose count equals mx.
LL ans[N], t = 0;
// Adjacency lists of the undirected tree.
vector<int>g[N];
// mp[c]: number of vertices of colour c currently counted.
map<int, int>mp;
// First pass: computes subtree sizes sz[] and each vertex's heavy child bs[u]
// (the child with the largest subtree). o is the parent of u.
void dfs(int o, int u) {
    sz[u] = 1;
    for (int v : g[u]) {
        if(v != o) {
            dfs(u, v);
            sz[u] += sz[v];
            // bs[u] starts at 0 and sz[0] is never written (0), so any real child beats it.
            if(sz[v] > sz[bs[u]]) bs[u] = v;
        }
    }
}
void ADD(int o, int u) {
mp[color[u]] ++;
if(mp[color[u]] > mx) t = color[u], mx = mp[color[u]];
else if(mp[color[u]] == mx) t += color[u];
for (int v : g[u]) {
if(v != o) {
ADD(u, v);
}
}
}
// Second pass (DSU on tree). Expects an empty counter on entry. On return, the counter
// holds exactly u's subtree and ans[u] is set.
// - Light children are solved first and their counts discarded.
// - The heavy child is solved last and its counts kept.
// - The light subtrees and u itself are then re-added on top.
void DFS(int o, int u) {
    for (int v : g[u]) {
        if(v != o && v != bs[u]) {
            DFS(u, v);
            // Reset the counter so the next child also starts from empty.
            mp.clear();
            mx = 0;
            t = 0;
        }
    }
    // The heavy child's counts are reused for u instead of being rebuilt.
    if(bs[u]) DFS(u, bs[u]);
    for (int v : g[u]) {
        if(v != o && v != bs[u]) {
            ADD(u, v);
        }
    }
    // Finally count u itself.
    const int freq = ++mp[color[u]];
    if(freq > mx) t = color[u], mx = freq;
    else if(freq == mx) t += color[u];
    ans[u] = t;
}
// Reads n, the colours of vertices 1..n and n-1 undirected edges. Prints, for every
// vertex of the tree rooted at 1, the sum of the colours occurring most often in its subtree.
int main() {
    int n, u, v;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &color[i]);
    for (int i = 1; i < n; i++) {
        scanf("%d %d", &u, &v);
        g[u].pb(v);
        g[v].pb(u);
    }
    // Vertex 0 serves as the root's "no parent" sentinel.
    dfs(0, 1);
    DFS(0, 1);
    for (int i = 1; i <= n; i++) printf("%lld%c", ans[i], " \n"[i==n]);
    return 0;
}
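思路2 (Approach 2):
dfs序+分块 求区间众数和 (Euler-tour order + sqrt decomposition: each subtree becomes the contiguous range [in[u], out[u]], so each answer is a range query for the sum of all modes)
代码 (Code):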
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Redirects stdin/stdout to local files for offline testing (the original wrote to the
// nonexistent `stout`). NOTE: this macro shadows the C library function fopen for the
// rest of the file.
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
//head
// Capacity of per-vertex / per-position / per-colour arrays (1-based); assumes n <= 1e5.
const int N = 1e5 + 5;
// Bound on the number of sqrt-decomposition blocks. With blo = floor(sqrt(n)),
// bl[n] <= floor(sqrt(n)) + 2, i.e. at most 318 when n <= 1e5.
const int M = 400;
// Adjacency lists of the undirected tree.
vector<int> g[N];
// pos[c]: dfs-order positions holding colour c, in increasing order (filled in main).
vector<int> pos[N];
// a[u]: colour of vertex u; t[i]: colour at dfs-order position i;
// [in[u], out[u]]: positions of u's subtree; cnt: scratch counters for init();
// bl[i]: block of position i; blo: block length; tot: dfs-order counter.
int a[N], t[N], in[N], out[N], cnt[N], bl[N], blo, tot = 0, n;
// f[x][y]: a mode of blocks x..y; fres[x][y]: sum of all modes of blocks x..y.
int f[M][M];
LL fres[M][M];
// Euler tour: gives u the next dfs-order index in[u] and stores u's colour there.
// out[u] is set to the last index used inside u's subtree, so that subtree occupies
// exactly the positions [in[u], out[u]].
void dfs(int o, int u) {
    tot += 1;
    in[u] = tot;
    t[tot] = a[u];
    for (int child : g[u]) {
        if (child == o) continue;
        dfs(u, child);
    }
    out[u] = tot;
}
void init(int x) {
for (int i = ; i <= n; i++) cnt[i] = ;
int ans = , c = ;
LL res = ;
for (int i = (x-)*blo + ; i <= n; i++) {
cnt[t[i]]++;
if(cnt[t[i]] > c) {
res = ans = t[i];
c = cnt[t[i]];
}
else if(cnt[t[i]] == c) {
ans = t[i];
res += t[i];
}
f[x][bl[i]] = ans;
fres[x][bl[i]] = res;
}
}
// Number of occurrences of colour x among dfs-order positions [l, r], found by binary
// search on the increasing position list pos[x]. An empty range (l > r) yields 0.
int cal(int l, int r, int x) {
    if(l > r) return 0;
    return upper_bound(pos[x].begin(), pos[x].end(), r) - lower_bound(pos[x].begin(), pos[x].end(), l);
}
LL query(int l, int r) {
LL res = ;
int ans = , c = ;
if(bl[l] == bl[r]) {
vector<int> vc;
for (int i = l; i <= r; i++) vc.pb(t[i]);
sort(vc.begin(), vc.end());
vc.erase(unique(vc.begin(), vc.end()), vc.end());
for (int i = ; i < vc.size(); i++) {
int tot = cal(l, r, vc[i]);
if(tot > c) {
res = ans = vc[i];
c = tot;
}
else if(tot == c) {
ans = vc[i];
res += vc[i];
}
}
return res;
}
ans = f[bl[l]+][bl[r]-];
res = fres[bl[l]+][bl[r]-];
int L = bl[l]*blo+, R = (bl[r]-)*blo;
c = cal(L, R, ans);
vector<int> vc;
for (int i = l; i <= bl[l]*blo; i++) vc.pb(t[i]);
for (int i = (bl[r]-)*blo + ; i <= r; i++) vc.pb(t[i]);
sort(vc.begin(), vc.end());
vc.erase(unique(vc.begin(), vc.end()), vc.end());
for (int i = ; i < vc.size(); i++) {
int tot = cal(l, r, vc[i]);
if(tot > c) {
res = ans = vc[i];
c = tot;
}
else if(tot == c) {
ans = vc[i];
res += vc[i];
}
}
return res;
}
// Reads the tree, flattens it into dfs order, and builds the sqrt-decomposition tables.
// Then prints, for every vertex, the sum of the most frequent colours in its subtree.
int main() {
    int u, v;
    scanf("%d", &n);
    blo = sqrt(n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        scanf("%d %d", &u, &v);
        g[u].pb(v);
        g[v].pb(u);
    }
    // Vertex 0 serves as the root's "no parent" sentinel.
    dfs(0, 1);
    for (int i = 1; i <= n; i++) bl[i] = (i-1)/blo + 1;
    for (int i = 1; i <= bl[n]; i++) init(i);
    // Positions are appended in increasing order, so each pos[c] is sorted as cal() requires.
    for (int i = 1; i <= n; i++) pos[t[i]].pb(i);
    for (int i = 1; i <= n; i++) printf("%lld%c", query(in[i], out[i]), " \n"[i==n]);
    return 0;
}