题目描述
一些村庄建在一条笔直的高速公路边上,我们用一条坐标轴来描述这条公路,每个村庄的坐标都是整数,没有两个村庄的坐标相同。两个村庄的距离定义为坐标之差的绝对值。我们需要在某些村庄建立邮局。使每个村庄使用与它距离最近的邮局,建立邮局的原则是:所有村庄到各自使用的邮局的距离总和最小。数据规模:1<=村庄数<=1600, 1<=邮局数<=200, 1<=村庄坐标<=maxlongint
输入
2行第一行:n m {表示有n个村庄,建立m个邮局} 第二行:a1 a2 a3 .. an {表示n个村庄的坐标}
输出
1行第一行:l 个整数{表示最小距离总和}
样例输入
10 5
1 2 3 6 7 9 11 22 44 50
样例输出
9
这道题目是IOI2000的真题哦~
可以这样考虑:
给定一个区间,假设我们要建一个邮局,那么一定是在这个序列的中点,所以我们可以先预处理出序列区间[i,j]之间的距离
一个邮局的最短距离记录为sum[i][j],然后用f[i][j]表示到i个村庄建立j个邮局的最短距离和,那么就有状态转移方程:
f[i][j]=min(f[i][j],f[k][j-1]+sum[k+1][i]);
这样,代码就好写了。
但是——这个数据,用O(n) 的普通DP算法显然无法通过。
O(n)代码如下:
#include<bits/stdc++.h> using namespace std; int n,m; int a[]; long long sum[]]; long long f[][]; //f[i][j]表示前i个村庄设j个邮局 //sum[i][j]表示在第i个村庄到第j个村庄设一个邮局的路程 int main(void){ cin>>n>>m; for (int i=;i<=n;i++) cin>>a[i]; sort(a+,a+n+); for (int i=;i<=n;i++){ for (int j=i;j<=n;j++){ sum[i][j]=dis(i,j); } } memset(f,0x3f,sizeof(f)); for (int i=;i<=n;i++){ f[i][]=sum[][i]; } for (int i=;i<=n;i++){ for (int j=;j<=min(i,m);j++){ for (int k=j-;k<=i-;k++){ f[i][j]=min(f[i][j],f[k][j-]+sum[k+][i]); } } } cout<<f[n][m]<<endl; }
这东西肯定过不了啊~
那怎么办?"四边形不等式!"
f[a][c]+f[b][d]<=f[b][c]+f[a][d]
( a < b <= c< d )
(可以理解成:交叉小于包含)
满足这个条件的DP方程(或者说是别的什么数组啊之类的)就称为***为凸。
(以下一段文字来自https://blog.csdn.net/noiau/article/details/72514812)
给出两个定理:
1、如果上述的w函数同时满足区间包含单调性和四边形不等式性质,那么函数dp也满足四边形不等式性质
我们再定义s(i,j)表示 dp(i,j) 取得最优值时对应的下标(即 i≤k≤j 时,k 处的 dp 值最大,则 s(i,j)=k此时有如下定理
2、假如dp(i,j)满足四边形不等式,那么s(i,j)单调,即 s(i,j)≤s(i,j+1)≤s(i+1,j+1)
大家可以自己尝试推倒一下,为什么f[i][j]和sum[i][j]是满足这个式子的(因为我懒得再推了)
再就是要证明"决策单调"
(以下一段文字来自https://blog.csdn.net/noiau/article/details/72514812)
如果我们用s[i][j]表示dp[i][j]取得最优解的时候k的位置的话
那么我们要证明如下结论的成立性:
s[i][j-1]<=s[i][j]<=s[i+1][j]
对于s[i][j-1]<=s[i][j]来说,我们先令dp[i][j-1]取得最优解的时候的k值为y,然后令除了最优值以外的其他值可以为x,这里我们由于要讨论单调性,所以让x小于y,即x<=y<=j-1< j
这里的证明更为繁琐,在实际应用中,我们可以写出O(n)后,自己跑一边是否决策单调,不是就输出"false"就行了。
在这道题中,我们要注意三点:
- s数组(决策数组)的初始化
- 循环的次序
- 对邮局多于村庄的特判(血泪)
话不多说,代码上
#include<bits/stdc++.h> using namespace std; int n,m; int a[]; long long sum[][]; long long f[][]; int s[][]; //s是决策数组 int main(void){ cin>>n>>m; if(m>=n){ printf(""); return ; } for (int i=;i<=n;i++) cin>>a[i]; sort(a+,a+n+); for (int i=;i<=n;i++){ sum[i][i]=; for (int j=i+;j<=n;j++){ sum[i][j]=sum[i][j-]+a[j]-a[(i+j)/]; } } memset(f,0x3f,sizeof(f)); //注意这里f要初始化成最大值 memset(s,,sizeof(s)); for (int i=;i<=n;i++){ f[i][]=sum[][i]; s[i][]=; } for (int j=;j<=m;j++){ s[n+][j]=n; for (int i=n;i>=j;i--){ for (int k=s[i][j-];k<=s[i+][j];k++){ if (f[k][j-]+sum[k+][i]<f[i][j]){ f[i][j]=f[k][j-]+sum[k+][i]; s[i][j]=k; } } } } cout<<f[n][m]<<endl; }
这样的代码,经过四边形不等式的优化,就是O(n)的算法了!
(以下一段文字来自https://blog.csdn.net/noiau/article/details/72514812)
关于O(n^2)复杂度的证明
其实证明很简单,对于一个i,j来说,我们要for s[i][j-1]到s[i+1][j]个数,那么所有的i和j加起来一共会for多少次呢?
我们可以这样思考
(s[2][2]-[1][1])+(s[3][3]-s[2][2])+(s[4][4]-s[3][3])+…+(s[n][n]-s[n-1][n-1])=s[n][n]-s[1][1]很显然是小于n的嘛,所以本来是(n *n *n)的复杂度,就这样降成了O(n *n)啦
(关于四边形不等式强推https://blog.csdn.net/noiau/article/details/72514812这篇博客)
FFFeiya编辑于2018.7.30