清北学堂模拟赛d4t5 b-LMLPHP

分析:一眼树形dp题,就是不会写QAQ.树形dp嘛,定义状态肯定有一维是以i为根的子树,其实这道题只需要这一维就可以了.设f[i]为以i为根的子树中的权值和.先处理子树内部的情况,用一个数组son[i]表示以i为根的子树中,i能走到的节点个数,可以利用son数组和当前点的权值来更新f数组.

处理了每个子树内部的情况,接下来就要合并它们,将每一个根节点作为中间点,算一下中间点权值的贡献,利用乘法原理算出有多少对点对经过中间点,乘一下就ok了.

树形dp的基本状态定义要熟记,有些题目子树内部是互相独立的,可以在子树里面单独计算,最后再合并一下.

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm> using namespace std; const int maxn = ; int n, a[maxn], head[maxn], to[maxn * ], nextt[maxn * ], tot = , w[maxn * ];
long long ans, f[maxn], son[maxn]; void add(int x, int y, int z)
{
w[tot] = z;
to[tot] = y;
nextt[tot] = head[x];
head[x] = tot++;
} void dfs(int u, int fa, int col)
{
long long res = ;
f[u] = a[u];
son[u] = ;
bool flag = ;
for (int i = head[u]; i; i = nextt[i])
{
int v = to[i];
if (v == fa)
continue;
dfs(v, u, w[i]);
if (col != w[i])
{
flag = ;
son[u] += son[v];
f[u] += son[v] * a[u] + f[v];
}
res += son[v] * a[u] + f[v];
}
ans += res;
if (flag)
return;
for (int i = head[u]; i; i = nextt[i])
{
int v1 = to[i];
if (v1 != fa)
for (int j = i; j; j = nextt[j]) //防止重复统计,所以j=i而不是j=head[u]
{
int v2 = to[j];
if (v2 != fa && w[i] != w[j])
ans += son[v1] * f[v2] + son[v2] * f[v1] + a[u] * son[v1] * son[v2];
}
}
} int main()
{
scanf("%d", &n);
for (int i = ; i <= n; i++)
scanf("%d", &a[i]);
for (int i = ; i < n; i++)
{
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
add(x, y, z);
add(y, x, z);
}
dfs(, , );
printf("%lld\n", ans); return ;
}
05-16 06:58