## 非常神仙的 wqs 二分优化dp,又学了一招。
首先我们需要先想到一个人类智慧版的前缀和优化。
# part 1:violence
然鹅在前缀和优化之前我们先考虑暴力做法:
我们可以枚举 i 、 j 令其表示前 i 个村庄设立 j 个邮局的最小贡献。
然后枚举 k 表示前 k 个村庄已经设立邮局,现在处理 k+1~i 的村庄。
接着再枚举当前邮局设立在哪里,然后 O(n) 累加每个村庄的贡献。
这样的复杂度是 O(n^5) 的,也许达不到这个上限,但是 O(n^4) 的时间总是要的。
于是这样...已经炸掉了。
# part 2:optimization(human wisdom)
我们考虑在一段区间内建立一个邮局,那么这个邮局会使得附近村庄的贡献降低。
那么如何使得这个降低的贡献最大呢?我们可以由 **~~人类智慧~~ 推论** 得出:
当我们将邮局设立在一个要产生贡献的区间的中点时,降低的贡献最大。
那么这时我们不妨设区间中点坐标为 k ,左端点坐标 i ,右端点坐标 j 。
此时这段区间对答案的贡献为:# $$(S[j]-S[k])-a[k] \times (j-k) + a[k] \times (k-i)-(S[k]-S[i])$$ #
那么这样的复杂度是 O(n^3) 的,已经有了较大进步,起码30分是到手了。
# part 3:optimization(Quadrilateral inequality)
于是我们考虑进一步优化,看到 满数据 是 3e3 的数据范围,那么应该是要用 O(n^2) 的算法。
那这里就要用 四边形不等式优化了(我不会)。同学们可以自行研究,大概就是根据 f[i][j] 的一个性质:
f[i][j]+f[i-1][j+1]>f[i-1][j]+f[i][j+1] => f[i][j] 的决策点在 f[i-1][j] 和 f[i][j+1] 之间之类的。
(怎么证我就母鸡了)
于是 O(n^2) 满分。
# part4:optimization(wqs binary cut)
然鹅我们还可以考虑进一步升华算法。
我们可以考虑二分将算法复杂度优化成 O(nlogn)
而且是 wqs 二分。
如何二分? 我们考虑给区间的每次分割增加一个贡献。
那么我们可以看出:增加贡献越大,将会分割的次数就越少。
容易想到,当分割出的段数恰好为 m 时,该状态下的 f[n] 减去增加贡献就是答案。
这样一个 log 去了。 那么怎么 O(n) 转移方程?
我们用单调队列优化转移,单调队列内每个点记录上次转移位置以及其控制的后方最优解范围。
也就是说,最后一种算法是用了二分优化 part3 ,将一个 n 变成了 log 。
# part 5:coding(s)
$$ O(n^3) $$
//by Judge
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
const int M=;
#ifndef Judge
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#endif
char buf[<<],*p1=buf,*p2=buf;
inline int read(){
int x=,f=; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=-;
for(;isdigit(c);c=getchar()) x=x*+c-''; return x*f;
} int n,m,a[M],s[M],f[M][M];
inline void cmin(int& a,int b){ if(a>b) a=b; }
int main(){
n=read(),m=read(),memset(f,0x3f,sizeof(f)),f[][]=;
for(int i=;i<=n;++i) a[i]=read(); std::sort(a+,a++n);
for(int i=;i<=n;++i) s[i]=s[i-]+a[i];
for(int i=;i<=n;++i) for(int j=;j<=m;++j) for(int k=,t;k<i;++k)
t=i+k+>>,cmin(f[i][j],f[k][j-]+(s[i]-s[t])-a[t]*(i-t)+a[t]*(t-k)-(s[t]-s[k]));
return printf("%d\n",f[n][m]),;
}
n^3
$$ O(n^2) $$
//by Judge
#include<algorithm>
#include<iostream>
#include<cstdio>
#define ll long long
const int M=;
#ifndef Judge
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#endif
char buf[<<],*p1=buf,*p2=buf;
inline int read(){
int x=,f=; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=-;
for(;isdigit(c);c=getchar()) x=x*+c-''; return x*f;
} int n,m,a[M],mk[M][M]; ll f[M][M],w[M][M];
int main(){
n=read(),m=read();
for(int i=;i<=n;++i) a[i]=read(); std::sort(a+,a++n);
for(int i=;i<=n;++i) for(int j=i+;j<=n;++j)
w[i][j]=w[i][j-]+a[j]-a[i+j>>];
for(int i=;i<=n;++i) f[][i]=w[][i],mk[][i]=;
for(int i=;i<=m;++i){ mk[i][n+]=n;
for(int j=n;j>i;--j){
f[i][j]=1ll<<;
for(int k=mk[i-][j];k<=mk[i][j+];++k)
if(f[i][j]>f[i-][k]+w[k+][j])
f[i][j]=f[i-][k]+w[k+][j],mk[i][j]=k;
}
} return printf("%lld\n",f[m][n]),;
}
n^2
$$ O(n logn) $$
//by Judge
#include<algorithm>
#include<iostream>
#include<cstdio>
#define mid (l+r>>1)
#define ll long long
using namespace std;
const int M=1e5+;
const ll inf=1e18+;
#ifndef Judge
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#endif
char buf[<<],*p1=buf,*p2=buf;
inline int read(){
int x=,f=; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=-;
for(;isdigit(c);c=getchar()) x=x*+c-''; return x*f;
} int n,m,las[M]; ll a[M],S[M],f[M];
struct node{ int pos,l,r; // l~r 为 pos 点所控制的最优区间
node(int pos,int l,int r):pos(pos),l(l),r(r){} node(){}
}Q[M];
inline ll calc(int l,int r,int tim){
if(l>=r) return inf; int t=l+r+>>; //人类智慧+前缀和优化
return f[l]+(S[r]-S[t])-(r-t)*a[t]+(t-l)*a[t]-(S[t]-S[l])+tim;
} inline bool check(int tim){
int siz=,ans=; Q[]=node(,,n);
for(int i=;i<=n;++i){ int l=,r=siz,pos;
while(l<=r) Q[mid].l<=i?l=(pos=mid)+:r=mid-;
f[i]=calc(Q[pos].pos,i,tim),las[i]=Q[pos].pos,pos=n+;
while(siz&&calc(Q[siz].pos,Q[siz].l,tim)>=calc(i,Q[siz].l,tim)) pos=Q[siz--].l;
if(siz && calc(Q[siz].pos,Q[siz].r,tim)>=calc(i,Q[siz].r,tim)){ l=Q[siz].l,r=Q[siz].r;
while(l<=r) calc(Q[siz].pos,mid,tim)>=calc(i,mid,tim)?r=(pos=mid)-:l=mid+;
Q[siz].r=pos-;
} if(pos!=n+) Q[++siz]=node(i,pos,n);
} for(int i=n;i;i=las[i]) ++ans; return ans<m;
}
int main(){
n=read(),m=read();
for(int i=;i<=n;++i) a[i]=read();
sort(a+,a++n);
for(int i=;i<=n;++i) S[i]=S[i-]+a[i];
int l=,r=5e6; ll ans=;
while(l<=r) check(mid)?r=mid-:(ans=f[n]-m*mid,l=mid+);
return printf("%lld\n",ans),;
}
n log n