锟题x2
以下用$a\rightarrow b$表示端点为$a,b$的链
把式子写成$(h_1(x)+h_1(y)-h_1(lca))-h_2(lca')$,第一部分就是$x\rightarrow rt$和$y\rightarrow rt$的并的总长
考虑对第一棵树边分治,假设分治到$(u,v)$,我们想要统计所有跨过$(u,v)$的$x\rightarrow y$
设在树$1$上$fa_v=u$,对于$u$这边的点$x$,令$f_x=-\infty,g_x=dis(x,u\rightarrow rt)$,对$v$这边的点$y$,令$f_y=h_1(y),g_y=-\infty$,那么$h_1(x)+h_1(y)-h_1(lca)=g_x+f_y$(将另外的$f,g$设为$-\infty$是为了防止统计到不跨过$(u,v)$的情况)
所以我们可以将当前分治范围内的点拿出来在树$2$上建虚树,在虚树上统计答案即可
最后不要忘了统计$x=y$的答案...
#include<stdio.h> #include<algorithm> #include<vector> using namespace std; typedef long long ll; const int inf=2147483647; const ll linf=922337203685477580ll; int n; struct pr{ int to,v; pr(int a=0,int b=0){to=a;v=b;} }; struct tree1{ int h[733340],nex[1466670],to[1466670],v[1466670],M; vector<pr>g[366670]; void ins(int a,int b,int c){ M++; to[M]=b; v[M]=c; nex[M]=h[a]; h[a]=M; } void add(int a,int b,int c){ ins(a,b,c); ins(b,a,c); } int N; void dfs(int fa,int x){ vector<pr>::iterator it; int p=0; for(it=g[x].begin();it!=g[x].end();it++){ if(it->to!=fa){ if(p){ N++; add(p,N,0); add(N,it->to,it->v); p=N; }else{ add(x,it->to,it->v); p=x; } } } for(it=g[x].begin();it!=g[x].end();it++){ if(it->to!=fa)dfs(x,it->to); } } ll dis[733340]; int fa[733340],dep[733340]; void dfs(int x){ dep[x]=dep[fa[x]]+1; for(int i=h[x];i;i=nex[i]){ if(to[i]!=fa[x]){ fa[to[i]]=x; dis[to[i]]=dis[x]+v[i]; dfs(to[i]); } } } void gao(){ int i,x,y,z; for(i=1;i<n;i++){ scanf("%d%d%d",&x,&y,&z); g[x].push_back(pr(y,z)); g[y].push_back(pr(x,z)); } M=1; N=n; dfs(0,1); dfs(1); } }t1; struct tree2{ int h[366670],nex[733340],to[733340],v[733340],M; void add(int a,int b,int c){ M++; to[M]=b; v[M]=c; nex[M]=h[a]; h[a]=M; } int dfn[366670],mn[733340][20],dep[366670],lg[733340]; ll dis[366670]; void dfs(int fa,int x){ dfn[x]=++M; mn[M][0]=x; dep[x]=dep[fa]+1; for(int i=h[x];i;i=nex[i]){ if(to[i]!=fa){ dis[to[i]]=dis[x]+v[i]; dfs(x,to[i]); mn[++M][0]=x; } } } int qmin(int x,int y){return dep[x]<dep[y]?x:y;} int query(int l,int r){ int k=lg[r-l+1]; return qmin(mn[l][k],mn[r-(1<<k)+1][k]); } int lca(int x,int y){ if(dfn[x]>dfn[y])swap(x,y); return query(dfn[x],dfn[y]); } void gao(){ int i,j,x,y,z; for(i=1;i<n;i++){ scanf("%d%d%d",&x,&y,&z); add(x,y,z); add(y,x,z); } M=0; dfs(0,1); for(j=1;j<20;j++){ for(i=1;i+(1<<j)-1<=M;i++)mn[i][j]=qmin(mn[i][j-1],mn[i+(1<<(j-1))][j-1]); } for(i=2;i<=M;i++)lg[i]=lg[i>>1]+1; } }t2; bool cmp(int x,int y){return t2.dfn[x]<t2.dfn[y];} bool vis[1466670]; int siz[733340],p[366670],M; void dfs1(int fa,int x){ if(x<=n)p[++M]=x; siz[x]=1; for(int i=t1.h[x];i;i=t1.nex[i]){ if(!vis[i]&&t1.to[i]!=fa){ dfs1(x,t1.to[i]); siz[x]+=siz[t1.to[i]]; } } } int al,mn,cn; void dfs2(int fa,int x){ for(int i=t1.h[x];i;i=t1.nex[i]){ if(!vis[i]&&t1.to[i]!=fa){ dfs2(x,t1.to[i]); if(abs(al-2*siz[t1.to[i]])<mn){ mn=abs(al-2*siz[t1.to[i]]); cn=i; } } } } ll f[733340],g[733340]; void dfs3(int fa,int x,ll d){ g[x]=d; f[x]=-linf; for(int i=t1.h[x];i;i=t1.nex[i]){ if(!vis[i]&&t1.to[i]!=fa)dfs3(x,t1.to[i],t1.to[i]==t1.fa[x]?0:d+t1.v[i]); } } void dfs4(int fa,int x){ f[x]=t1.dis[x]; g[x]=-linf; for(int i=t1.h[x];i;i=t1.nex[i]){ if(!vis[i]&&t1.to[i]!=fa)dfs4(x,t1.to[i]); } } ll ans; struct vtree{ int h[366670],nex[366670],to[366670],M; void add(int a,int b){ M++; to[M]=b; nex[M]=h[a]; h[a]=M; } void dfs(int x){ for(int i=h[x];i;i=nex[i]){ dfs(to[i]); ans=max(ans,max(f[x]+g[to[i]],g[x]+f[to[i]])-t2.dis[x]); f[x]=max(f[x],f[to[i]]); g[x]=max(g[x],g[to[i]]); } h[x]=0; } void clear(int x){ f[x]=g[x]=-linf; for(int i=h[x];i;i=nex[i])clear(to[i]); } }vt; int st[366670],tp; void insert(int x){ if(!tp){ st[++tp]=x; return; } int l=t2.lca(x,st[tp]); while(tp>1&&t2.dep[st[tp-1]]>t2.dep[l]){ vt.add(st[tp-1],st[tp]); tp--; } if(t2.dep[st[tp]]>t2.dep[l]){ vt.add(l,st[tp]); tp--; } if(t2.dep[st[tp]]<t2.dep[l])st[++tp]=l; st[++tp]=x; } void build(){ int i; sort(p+1,p+M+1,cmp); tp=0; vt.M=0; for(i=1;i<=M;i++)insert(p[i]); for(i=1;i<tp;i++)vt.add(st[i],st[i+1]); } void solve(int x){ int y; M=0; dfs1(0,x); al=siz[x]; mn=inf; cn=0; dfs2(0,x); if(cn==0)return; vis[cn]=vis[cn^1]=1; x=t1.to[cn]; y=t1.to[cn^1]; if(t1.dep[x]>t1.dep[y])swap(x,y); build(); vt.clear(st[1]); dfs3(0,x,0); dfs4(0,y); vt.dfs(st[1]); solve(x); solve(y); } int main(){ scanf("%d",&n); t1.gao(); t2.gao(); ans=-linf; solve(1); for(int i=1;i<=n;i++)ans=max(ans,t1.dis[i]-t2.dis[i]); printf("%lld",ans); }