每次求出最长链更新答案后要将最长链上的边权改为-1
写的贼长 还可以优化...
/*Huyyt*/
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
#define pb push_back
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const int MAXN = 1e5 + , MAXM = 2e5 + ;
int to[MAXM << ], nxt[MAXM << ], Head[MAXN], ed = ;
int value[MAXM << ];
inline void addedge(int u, int v, int val)
{
to[++ed] = v;
nxt[ed] = Head[u];
value[ed] = val;
Head[u] = ed;
}
int d[MAXN];
void dfs(int x, int pre)
{
for (int v, i = Head[x]; i; i = nxt[i])
{
v = to[i];
if (v == pre)
{
continue;
}
d[v] = d[x] + value[i];
dfs(v, x);
}
}
void change(int x)
{
for (int v, i = Head[x]; i; i = nxt[i])
{
v = to[i];
if (d[v] == d[x] - )
{
value[i] = value[i ^ ] = -;
change(v);
}
}
}
int s, t, dmx = -;
int ans2 = , vis[MAXN], dpd[MAXN];
void dp(int x)
{
vis[x] = ;
for (int v, i = Head[x]; i; i = nxt[i])
{
v = to[i];
if (vis[v])
{
continue;
}
dp(v);
ans2 = max(ans2, dpd[x] + dpd[v] + value[i]);
dpd[x] = max(dpd[x], dpd[v] + value[i]);
}
}
int main()
{
int anser;
int n, k;
int u, v;
scanf("%d %d", &n, &k);
for (int i = ; i < n; i++)
{
scanf("%d %d", &u, &v);
addedge(u, v, ), addedge(v, u, );
}
anser = * (n - );
d[] = ;
dfs(, );
for (int i = ; i <= n; i++)
{
if (d[i] > dmx)
{
dmx = d[i];
s = i;
}
}
d[s] = ;
dfs(s, );
dmx = -;
for (int i = ; i <= n; i++)
{
if (d[i] > dmx)
{
dmx = d[i];
t = i;
}
}
anser -= d[t] - ;
if (k == )
{
printf("%d\n", anser);
return ;
}
change(t);
dp();
anser -= ans2 - ;
printf("%d\n", anser);
return ;
}
//BZOJ1912
求树直径dp
void dp(int x)
{
vis[x] = ;
for (int v, i = Head[x]; i; i = nxt[i])
{
v = to[i];
if (vis[v])
{
continue;
}
dp(v);
ans2 = max(ans2, dpd[x] + dpd[v] + value[i]);
dpd[x] = max(dpd[x], dpd[v] + value[i]);
}
}
其实这个dp的作用是先把无根树转化为有根树 再求每个点子树中的最长链和次长链(如果有次长链的话)
则树的直径有两种情况
1.是一个节点的最长链
2.是一个节点的次长链+最长链
我们首先记录直径取最长是在哪个节点 然后在每个节点我们都要记录 次长链是那条边拓展出去和最长链是那条边拓展出去
因为一个节点的最长链和次长链必定是一个边加下一个节点的最长链
这样就可以一个dfs搞定
/*Huyyt*/
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
#define pb push_back
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const int MAXN = 1e5 + , MAXM = 1e5 + ;
int to[MAXM << ], nxt[MAXM << ], Head[MAXN], ed = ;
int value[MAXM << ];
inline void addedge(int u, int v, int val)
{
to[++ed] = v;
nxt[ed] = Head[u];
value[ed] = val;
Head[u] = ed;
}
int mxlen[MAXN], mxlen2[MAXN];
int ansdis = ; //直径大小
int s, t;
int dfs(int x, int pre)
{
int mx1 = , mx2 = ; //当前节点的最长链和次长链长度
int now;
for (int v, i = Head[x]; i; i = nxt[i])
{
v = to[i];
if (v == pre)
{
continue;
}
now = dfs(v, x) + value[i];
if (now > mx1)
{
mx2 = mx1;
mxlen2[x] = mxlen[x];
mx1 = now;
mxlen[x] = i; //更新最长链 原最长链变为次长链
}
else if (now > mx2)
{
mx2 = now;
mxlen2[x] = i; //更新次长链
}
}
if (mx1 + mx2 > ansdis)
{
ansdis = mx1 + mx2;
s = x;
}
return mx1;//返回每个节点的最长链大小
}
int main()
{
int anser;
int n, k;
int u, v;
scanf("%d %d", &n, &k);
for (int i = ; i < n; i++)
{
scanf("%d %d", &u, &v);
addedge(u, v, ), addedge(v, u, );
}
anser = * (n - );
dfs(, );
anser -= ansdis - ;
if (k == )
{
printf("%d\n", anser);
return ;
}
ansdis = ;
for (int i = mxlen[s]; i; i = mxlen[to[i]]) //最长链上的边重置为-1
{
value[i] = value[i ^ ] = -;
}
for (int i = mxlen2[s]; i; i = mxlen[to[i]]) //次长链上的边重置为-1
{
value[i] = value[i ^ ] = -;
}
dfs(, );
anser -= ansdis - ;
printf("%d\n", anser);
return ;
}