NOI模拟题4 Problem A: 生成树(mst)-LMLPHP

Solution

我们考虑答案的表达式:

\[ans = \sqrt{\frac{\sum_{i = 1}^{n - 1} (w_i - \overline{w})^2}{n - 1}}
\]

其中\(w[i]\)表示选择的每一条边的权值.

考虑一种朴素的做法:

我们枚举每一个\(\overline{w}\), 把所有边按照其对答案的贡献\((w_i - \overline{w})^2\)排序, 然后用普通的kruskal解决即可.

这种做法的本质在于枚举每一个可能的\((w_i - \overline{w})\)序列. 我们注意到, \((w_i - \overline w)\)是一个二次函数, 因此我们当\(\overline w\)上升时, 该值先下降再上升.

考虑什么时候两条边对应的贡献在排好序的序列中会交换位置: \((w_i - \overline w)^2 = (w_j - \overline w)^2\), 即\(\overline w = \frac {w_i + w_j} 2\)

两条边的贡献在序列中最多只会交换一次位置, 因此我们只需要枚举每个临界值, 就相当于枚举所有合法的序列. 跑kruskal即可.

#include <cstdio>
#include <cmath>
#include <algorithm>
#include <vector> using namespace std;
const double INF = 1e50, EPS = 1e-9;
const int N = 20;
int n;
int x[N], y[N];
double w[N * N], dis[N][N];
struct edge
{
int u, v;
double w;
inline edge() {}
inline edge(int _u, int _v, double _w) { u = _u; v = _v; w = _w; }
inline int operator <(const edge &a) const { return w < a.w; }
}edg[N * N];
inline double sqr(double a) { return a * a; }
inline double getDistance(int u, int v) { return sqrt(sqr(x[u] - x[v]) + sqr(y[u] - y[v])); }
struct disjointSet
{
int pre[N];
inline void clear() { for (int i = 0; i < n; ++ i) pre[i] = i; }
inline int access(int u)
{
if (pre[u] == u) return u;
return pre[u] = access(pre[u]);
}
}st;
inline double work(double avr)
{
int tot = 0;
for (int i = 0; i < n; ++ i) for (int j = i + 1; j < n; ++ j) edg[tot ++] = edge(i, j, sqr(dis[i][j] - avr));
sort(edg, edg + tot);
st.clear();
vector<int> bck; bck.clear(); int cnt = 0;
for (int i = 0; cnt < n - 1; ++ i)
{
int u = edg[i].u, v = edg[i].v, rootOfU = st.access(u), rootOfV = st.access(v);
if (rootOfU == rootOfV) continue;
++ cnt;
st.pre[rootOfU] = rootOfV; bck.push_back(i);
}
double sum = 0, res = 0;
for (int i = 0; i < n - 1; ++ i) sum += dis[edg[bck[i]].u][edg[bck[i]].v];
for (int i = 0; i < n - 1; ++ i) res += sqr(dis[edg[bck[i]].u][edg[bck[i]].v] - sum / (n - 1));
return sqrt(res / (n - 1));
}
int main()
{ #ifndef ONLINE_JUDGE freopen("mst.in", "r", stdin);
freopen("mst.out", "w", stdout); #endif int T; scanf("%d", &T);
for (int cs = 0; cs < T; ++ cs)
{
scanf("%d", &n);
for (int i = 0; i < n; ++ i) scanf("%d", x + i);
for (int i = 0; i < n; ++ i) scanf("%d", y + i);
int tot = 0;
for (int i = 0; i < n; ++ i) for (int j = i + 1; j < n; ++ j) dis[i][j] = w[tot ++] = getDistance(i, j);
sort(w, w + tot);
double ans = INF;
for (int i = 0; i < tot; ++ i) for(int j = i + 1; j < tot; ++ j)
ans = min(ans, work((w[i] + w[j]) / 2 + EPS)), ans = min(ans, work((w[i] + w[j]) / 2 - EPS));
printf("%.3lf\n", ans);
}
}
05-11 18:19