洛谷 CF429D Tricky Function

题目:

题解:

  • 人话题意:
  • 给定序列 a[1], a[2], …, a[n]
  • 函数 f(i, j) = (i - j) ^ 2 + (a[i+1] + a[i+2] + … + a[j]) ^ 2
  • 求 f 的最小值。n ≤ 100000

  • 设sum[]为a[]的前缀和数组。
  • 那么f(i, j)可以改写为f(i, j) = (i - j) ^ 2 + (sum[j] - sum[i]) ^ 2
  • 这个,不就是距离公式吗?
  • 即点(i, sum[i])与点(j, sum[j])的距离的平方。
  • 那么此题就转换成了平面上的 n 个点,求最近点对。

  • 值得注意的是,分治求最近点对会T。也就是
for(int i = l; i <= r; i++)
    {
        int v = a[i].x - a[mid].x;
        if(v < d)
            t[++cnt] = i;
    }
sort(t + 1, t + 1 + cnt, cmp2);
for(int i = 1; i < cnt; i++)
    for(int j = i + 1; j <= cnt; j++)
    {
        int v = a[t[j]].y - a[t[i]].y;
        if(v >= d) break;
        else d = min(d, cal(t[i], t[j]));
    }
  • 这一段代码,必须改成如下,才能A掉。
for(int i = l; i <= r; i++)
    {
        int v = a[i].x - a[mid].x;
        if(v * v < d) //改动了这里
            t[++cnt] = i;
    }
sort(t + 1, t + 1 + cnt, cmp2);
for(int i = 1; i < cnt; i++)
    for(int j = i + 1; j <= cnt; j++)
    {
        int v = a[t[j]].y - a[t[i]].y;
        if(v * v >= d) break; //改动了这里
        else d = min(d, cal(t[i], t[j]));
    }
  • 个人认为,不改动在原理才是说得通的。
  • 改动后的正确性太蒻不懂如何证明。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#define N 100005
#define inf 0x7fffffff
#define int long long
using namespace std;

struct A {int x, y;} a[N];
int sum[N], t[N];
int n;

int read()
{
    int x = 0, f = 1; char c = getchar();
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
    return x *= f;
}

bool cmp1(A u, A v) {return u.x < v.x;}
bool cmp2(int u, int v) {return a[u].y < a[v].y;}

int cal(int u, int v)
{
    int v1 = (a[u].x - a[v].x) * (a[u].x - a[v].x);
    int v2 = (a[u].y - a[v].y) * (a[u].y - a[v].y);
    return v1 + v2;
}

int fun(int l, int r)
{
    if(l == r) return inf;
    if(l + 1 == r) return cal(l, r);
    int mid = (l + r) >> 1, cnt = 0;
    int d = min(fun(l, mid), fun(mid + 1, r));
    for(int i = l; i <= r; i++)
    {
        int v = a[i].x - a[mid].x;
        if(v * v < d)
            t[++cnt] = i;
    }
    sort(t + 1, t + 1 + cnt, cmp2);
    for(int i = 1; i < cnt; i++)
        for(int j = i + 1; j <= cnt; j++)
        {
            int v = a[t[j]].y - a[t[i]].y;
            if(v * v >= d) break;
            else d = min(d, cal(t[i], t[j]));
        }
    return d;
}

signed main()
{
    cin >> n;
    for(int i = 1; i <= n; i++)
    {
        int val = read();
        sum[i] = sum[i - 1] + val;
        a[i].x = i, a[i].y = sum[i];
    }
    sort(a + 1, a + 1 + n, cmp1);
    cout << fun(1, n);
    return 0;
}
01-20 16:51