本文版权归ljh2000和博客园共有,欢迎转载,但须保留此声明,并给出原文链接,谢谢合作。

本文作者:ljh2000
作者博客:http://www.cnblogs.com/ljh2000-jump/
转载请注明出处,侵权必究,保留最终解释权!

Description

BZOJ1912 [Apio2010]patrol 巡逻-LMLPHP

Input

第一行包含两个整数 n, K(1 ≤ K ≤ 2)。接下来 n – 1行,每行两个整数 a, b, 表示村庄a与b之间有一条道路(1 ≤ a, b ≤ n)。

Output

输出一个整数,表示新建了K 条道路后能达到的最小巡逻距离。

Sample Input

8 1
1 2
3 1
3 4
5 3
7 5
8 5
5 6

Sample Output

11

HINT

10%的数据中,n ≤ 1000, K = 1; 
30%的数据中,K = 1; 
80%的数据中,每个村庄相邻的村庄数不超过 25; 
90%的数据中,每个村庄相邻的村庄数不超过 150; 
100%的数据中,3 ≤ n ≤ 100,000, 1 ≤ K ≤ 2。

 
 
正解:树形DP
解题报告:
  这题很有意思,我记得以前做过一道叫做巡访的题目,正好是k=1的情况,结果这道题叫巡逻XD
  考虑k=1的时候,我们的答案很容易想到就是2*(n-1)-最长链+1,因为如果能加一条边的话,因为我希望减少的尽可能多,那么我只需要把最长链的首尾接起来,就不需要来回走,加一就是加了这一条新加入的边。
  但是k=2的时候呢?首先还是往最长链上面思考。然而做k=1的时候已经用掉了一段,k=2的时候怎么知道和k=1不重叠呢?
  很简单,我们在做k=1之后把最长链上的边权全部修改为-1,再跑一遍最长链就可以了。可能有人会疑问,那-1的边又被选了那不是相当于还是选进去两次了吗?但是考虑第一次算这条边的时候加了一,第二次的时候加的是-1,相当于是这条边没有产生任何贡献。可以画一画图就会发现,相当于是把两条交错的链变成了两条分开的链。这个做法很优秀。
  所以最后的总复杂度就是O(n)。 
 //It is made by ljh2000
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <ctime>
#include <vector>
#include <queue>
#include <map>
#include <set>
using namespace std;
typedef long long LL;
const int inf = (<<);
const int MAXN = ;
const int MAXM = ;
int n,k,ecnt,next[MAXM],to[MAXM],w[MAXM];
int f[MAXN][],first[MAXN],g[MAXN],p[MAXN];
int ans,root,Ans,Son,Son2; inline int getint()
{
int w=,q=; char c=getchar();
while((c<'' || c>'') && c!='-') c=getchar(); if(c=='-') q=,c=getchar();
while (c>='' && c<='') w=w*+c-'', c=getchar(); return q ? -w : w;
}
inline void link(int x,int y){ next[++ecnt]=first[x]; first[x]=ecnt; to[ecnt]=y; w[ecnt]=; }
inline void dfs(int x,int fa){
int now,son=,son2=;
for(int i=first[x];i;i=next[i]) {
int v=to[i]; if(v==fa) continue; dfs(v,x); now=f[v][]+w[i];
if(now>f[x][]) son=g[x],son2=p[x],f[x][]=f[x][],f[x][]=now,g[x]=v,p[x]=i; else if(now>f[x][]) f[x][]=now,son=v,son2=i;
}
if(f[x][]+f[x][]>ans) { ans=f[x][]+f[x][]; root=x; Son=son; Son2=son2; }
} inline void work(){
n=getint(); k=getint(); int x,y; for(int i=;i<n;i++) { x=getint(); y=getint(); link(x,y); link(y,x); }
dfs(,); Ans=*(n-)-ans+; if(k==) { printf("%d",Ans); return ; }
if(f[root][]>) { x=Son; w[Son2]=-; while(g[x]) { w[p[x]]=-; x=g[x]; } }
x=root; while(g[x]) w[p[x]]=-,x=g[x]; ans=; memset(f,,sizeof(f));
dfs(,); Ans-=ans-; printf("%d",Ans);
} int main()
{
work();
return ;
}
05-04 06:58