题解:
树链剖分
和普通的树链剖分不一样,这里的线段树不只是要记录x-y的和
而是要记录x左到y左,x左到y右,x右到y左,x右到y右
然后就可以了
代码:
#include<bits/stdc++.h>
const int N=,M=,inf=1e9;
using namespace std;
int n,m,cnt,place,sz,last[N],x,y,deep[N],u,v,fa[N][],son[N],belong[N],pl[N];
char mp[N][];
struct data{int l1,l2,r1,r2,d1,d2,d3,d4;};
struct edge{int to,next;}e[*N];
struct seg{int l,r;data d;}t[*N];
void insert(int u,int v)
{
e[++cnt].to=v;
e[cnt].next=last[u];
last[u]=cnt;
}
data merge(data a,data b)
{
data tmp;
tmp.d1=max(a.d1+b.d1,a.d2+b.d3);
if (tmp.d1<)tmp.d1=-inf;
tmp.d2=max(a.d1+b.d2,a.d2+b.d4);
if (tmp.d2<)tmp.d2=-inf;
tmp.d3=max(a.d3+b.d1,a.d4+b.d3);
if (tmp.d3<)tmp.d3=-inf;
tmp.d4=max(a.d3+b.d2,a.d4+b.d4);
if (tmp.d4<)tmp.d4=-inf;
tmp.l1=max(a.d1+b.l1,a.d2+b.l2);
tmp.l1=max(tmp.l1,a.l1);
if (tmp.l1<)tmp.l1=-inf;
tmp.l2=max(a.d3+b.l1,a.d4+b.l2);
tmp.l2=max(tmp.l2,a.l2);
if (tmp.l2<)tmp.l2=-inf;
tmp.r1=max(b.d1+a.r1,b.d3+a.r2);
tmp.r1=max(tmp.r1,b.r1);
if (tmp.r1<)tmp.r1=-inf;
tmp.r2=max(b.d2+a.r1,b.d4+a.r2);
tmp.r2=max(tmp.r2,b.r2);
if (tmp.r2<)tmp.r2=-inf;
return tmp;
}
void build(int k,int l,int r)
{
t[k].l=l;t[k].r=r;
if (l==r)return;
int mid=(l+r)>>;
build(k<<,l,mid);
build(k<<|,mid+,r);
}
void update(int k,int x,char mp[])
{
int l=t[k].l,r=t[k].r;
if(l==r)
{
t[k].d.d1=t[k].d.d2=t[k].d.d3=t[k].d.d4=-inf;
t[k].d.l1=t[k].d.l2=t[k].d.r1=t[k].d.r2=-inf;
if (mp[]=='.')t[k].d.d1=t[k].d.l1=t[k].d.r1=;
if (mp[]=='.')t[k].d.d4=t[k].d.l2=t[k].d.r2=;
if (mp[]=='.'&&mp[]=='.')
t[k].d.d2=t[k].d.d3=t[k].d.l1=t[k].d.l2=t[k].d.r1=t[k].d.r2=;
return;
}
int mid=(l+r)>>;
if (x<=mid)update(k<<,x,mp);
else update(k<<|,x,mp);
t[k].d=merge(t[k<<].d,t[k<<|].d);
}
data query(int k,int x,int y)
{
int l=t[k].l,r=t[k].r;
if (x==l&&y==r)return t[k].d;
int mid=(l+r)>>;
if (mid>=y)return query(k<<,x,y);
else if (mid<x)return query(k<<|,x,y);
else return merge(query(k<<,x,mid),query(k<<|,mid+,y));
}
void dfs1(int x)
{
son[x]=;
for (int i=;i<=;i++)
if (deep[x]>=(<<i))fa[x][i]=fa[fa[x][i-]][i-];
for (int i=last[x];i;i=e[i].next)
{
if(e[i].to==fa[x][])continue;
fa[e[i].to][]=x;
deep[e[i].to]=deep[x]+;
dfs1(e[i].to);
son[x]+=son[e[i].to];
}
}
void dfs2(int x,int chain)
{
pl[x]=++place;belong[x]=chain;
update(,pl[x],mp[x]);
int k=;
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa[x][])continue;
if (son[e[i].to]>son[k])k=e[i].to;
}
if (k)dfs2(k,chain);
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==k||e[i].to==fa[x][])continue;
dfs2(e[i].to,e[i].to);
}
}
int lca(int x,int y)
{
if(deep[x]<deep[y])swap(x,y);
int t=deep[x]-deep[y];
for (int i=;i<=;i++)
if ((<<i)&t)x=fa[x][i];
for (int i=;i>=;i--)
if (fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
if (x==y)return x;
return fa[x][];
}
data solveque(int x,int f,bool flag)
{
data ans;
ans.l1=ans.l2=ans.r1=ans.r2=;
ans.d1=ans.d4=ans.d2=ans.d3=;
while (belong[x]!=belong[f])
{
ans=merge(query(,pl[belong[x]],pl[x]),ans);
x=fa[belong[x]][];
}
if (flag==&&pl[f]+<=pl[x])ans=merge(query(,pl[f]+,pl[x]),ans);
if (!flag)ans=merge(query(,pl[f],pl[x]),ans);
return ans;
}
void que(int x,int y)
{
if (mp[x][]=='#'&&mp[x][]=='#'){puts("");return;}
int f=lca(x,y);
data a=solveque(x,f,),b=solveque(y,f,);
swap(a.d2,a.d3);
swap(a.l1,a.r1);
swap(a.l2,a.r2);
data ans=merge(a,b);
printf("%d\n",max(ans.l1,ans.l2));
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=;i<n;i++)scanf("%d%d",&u,&v),insert(u,v),insert(v,u);
for (int i=;i<=n;i++)scanf("%s",mp[i]);
build(,,n);
dfs1();dfs2(,);
for (int i=;i<=m;i++)
{
char ch[];
scanf("%s",ch);
if (ch[]=='Q')scanf("%d%d",&x,&y),que(x,y);
else
{
scanf("%d",&x);
scanf("%s",mp[x]);
update(,pl[x],mp[x]);
}
}
return ;
}