题目大意:
给一个n个节点的树,然后将其分成k+1个联通块,再在每个联通块取一条路径,将其连接起来,求连接起来的路径最大权值。
题解:
考场只会20分,还都打挂了……
60分的做法其实并不难,nk DP即可,设$f(i,j,0/1/2)$表示i子树选取了j个联通块,i这个节点连了0/1/2条边时的最优解。
100分的做法就是60分做法的拓展。
很容易想到一件事,就是以联通块数为x轴,最优解为y轴,那么这个图像应该是一个单峰上凸函数。同时该离散函数每相邻两点间的斜率是递减的:因为考虑当前联通块数为a,则当联通块数为a+1时,必然是在a时最优解上再连接一段新切出的可空路径并割去一部分可空路径,当补上一条新路径时,我们很容易知道这次补上的路径-割去路径一定小于以前做的同样操作(最优性)。
这样我们发现斜率是具有单调性的(即单调减)。那么我们二分这个斜率,并将原图像减去这个斜率对应的正比例函数,会发现,新图像将会在这个斜率对应点的位置最高,同时也是一个斜率递减函数。
那么我们考虑如何求出此时的答案:新图像上的最高点权值+新图像上最高点联通块数*斜率。
我们考虑这个东西怎么求。
设二元组$f(i,0/1/2)$表示i节点连了0/1/2条边时的最优解和其联通块数(尽量小)。特别的没有连边的i,算为一个联通块,2为0/1/2这三个状态的最优解。
考虑这个东西怎么转移:
假设已经得到子节点v的答案。
对于$f(x,2)$,我们有三种选择,1.保持原来不变,把$f(v,2)$加上;2.由$f(x,1)$和$f(v,1)$合并;3.由$f(x,1)和f(v,0)$合并。
$f(x,1)$,我们同样有三种选择,大体同上者。
$f(x,0)$,我们只有一种选择,即和$f(v,2)$结合。
我们以$f(x,2)$为例:对于第一种情况,联通块数不变直接合并即可,对于第二种情况联通块数减少1,第三种情况同样减少了1个联通块。
得到结果以后比较k+1与最优解对应的联通块数,大于则说明斜率过小,否则说明斜率还可能更大。
代码:
#include "bits/stdc++.h" using namespace std; inline int read(){
int s=,k=;char ch=getchar();
while (ch<''|ch>'') ch=='-'?k=-:,ch=getchar();
while (ch>&ch<='') s=s*+(ch^),ch=getchar();
return s*k;
} typedef long long ll; const int N=3e5+; struct edges{
int v,w;edges *last;
}edge[N<<],*head[N];int cnt; inline void push(int u,int v,int w) {
edge[++cnt]=(edges){v,w,head[u]},head[u]=edge+cnt;
} int n,k;
ll slope;
const ll inf=1e15; struct node {
ll val,num;
node(){val=num=;}
node(ll v,ll nm):val(v),num(nm){}
inline ll &operator [](int x){
return x?num:val;
}
inline void max(node a){
if(a[]>val||(a[]==val&&a[]<num))
(*this)=a;
}
inline void add(node a,node b){
if (a[]==-inf||a[]==-inf) return ;
a[]+=b[],a[]+=b[];
max(a);
}
inline void add(node a,node b,int w,int opt){
if(a[]==-inf||b[]==-inf) return ;
a[]+=b[]-opt,a[]+=b[]+w+slope*opt;
if(a[]<=) return ;
max(a);
}
inline node fa(){
return node(val-slope,num+);
}
}f[N][]; inline void dp(int x,int fa){
f[x][]=f[x][]=f[x][]=node();
f[x][][]=f[x][][]=-inf;
for (edges *i=head[x];i;i=i->last) if(i->v!=fa) {
dp(i->v,x);
f[x][].add(f[x][],f[i->v][]);
f[x][].add(f[x][],f[i->v][],i->w,);
f[x][].add(f[x][],f[i->v][],i->w,);
f[x][].add(f[x][],f[i->v][]);
f[x][].add(f[x][],f[i->v][],i->w,);
f[x][].add(f[x][],f[i->v][],i->w,-);
f[x][].add(f[x][],f[i->v][]);
}
f[x][].max(f[x][]);
f[x][].max(f[x][]);
f[x][].max(f[x][].fa());
} int main(){
n=read(),k=read()+;
for (int i=;i<n;++i) {
int a=read(),b=read(),w=read();
push(a,b,w),push(b,a,w);
}
ll l=-1e12,r=1e12;
node now;
ll ans=;
while (l<=r) {
slope=l+r>>;
dp(,);
now=f[][];
if(now[]<=k)
ans=now[]+slope*k,r=slope-;
else l=slope+;
}
printf("%lld\n",ans);
}