最近沉迷码农题无法自拔
首先有一个暴力的想法:对于每个重链维护一个splay,需要翻转的连起来,翻转,接回去
然后发现这样没问题。。。
一条链只能跨log个重链,也就只有log个splay的子树参与重排,所以一次翻转只要log^2的时间
需要维护的东西有点多
头一次在splay上维护这么多乱七八糟的东西,写错了好几遍
感觉学到不少debug技巧,也纠正了之前splay写法的一个小漏洞
话说极限数据好像和miaom拍不过啊。。。但是我们都A了啊
好烦啊不知道谁出锅了
#include <bits/stdc++.h>
#define num(x) (dep[x]-dep[top[x]]+1)
#define ans tr[split(num(x),num(y))]
#define rt root[RT]
#define RT1 root[n+1]
#define NO 800000
#define N 800000
#define INF 2000000000
#define ll long long
using namespace std;
struct trnode
{
int size,min,max;ll sum;int plus,num;
bool rev;
trnode()
{
size=;min=max=sum=plus=rev=num=;
}
} tr[NO];
int n,m,R,p,q,RT,x,y,z,RET,E,NODE,sttop,inv_x,inv_y;ll ret;
int c[NO][],fa[NO];bool die[N];
int root[N],Fa[N],size[N],fir[N],top[N],dep[N],st[N],sz[N],lin[N];
int nex[*N],to[*N];
char cmd[];
void up(int x)
{
if(!x)
{
tr[]=trnode();tr[].size=;
return;
}
tr[x].size=tr[c[x][]].size+tr[c[x][]].size+;
tr[x].sum=tr[c[x][]].sum+tr[c[x][]].sum+(ll)tr[x].plus*tr[x].size+tr[x].num;
tr[x].max=max(tr[c[x][]].max,tr[c[x][]].max)+tr[x].plus;
tr[x].max=max(tr[x].max,tr[x].num+tr[x].plus);
tr[x].min=tr[x].num;
if(c[x][]!=)
tr[x].min=min(tr[x].min,tr[c[x][]].min);
if(c[x][]!=)
tr[x].min=min(tr[x].min,tr[c[x][]].min);
tr[x].min+=tr[x].plus;
}
void down(int x)
{
if(tr[x].rev)
swap(c[x][],c[x][]),tr[c[x][]].rev^=,tr[c[x][]].rev^=,tr[x].rev=;
if(tr[x].plus)
tr[c[x][]].plus+=tr[x].plus,tr[c[x][]].plus+=tr[x].plus,
tr[x].num+=tr[x].plus,tr[x].plus=,up(c[x][]),up(c[x][]);
}
void rot(int x)
{
down(fa[x]);down(x);
int y=fa[x],k=c[y][]==x;
if(fa[y]) c[fa[y]][c[fa[y]][]==y]=x;
else rt=x;
fa[x]=fa[y];fa[y]=x;
c[y][k]=c[x][!k];
if(c[x][!k])
fa[c[x][!k]]=y;
c[x][!k]=y;
up(y);up(x);
}
void splay(int x,int z)
{
for(int y;(y=fa[x])!=z;rot(x))
if(fa[y]!=z)
rot(c[fa[y]][]==y^c[y][]==x?x:y);
}
int Find(int x)
{
if(x==) return -;
if(x>tr[rt].size) return -;
int now=rt;
while(down(now),tr[c[now][]].size+!=x)
if(x<=tr[c[now][]].size)
now=c[now][];
else
x-=tr[c[now][]].size+,now=c[now][];
return now;
}
int split(int x,int y)
{
if(x== && y==tr[rt].size)
return rt;
if(x==)
{ splay(Find(y+),);return c[rt][];}
if(y==tr[rt].size)
{ splay(Find(x-),);return c[rt][];}
splay(Find(x-),);
splay(Find(y+),rt);
return c[c[rt][]][];
}
void upd(int x)
{
int len=;
while(x)
lin[++len]=x,x=fa[x];
for(int i=len;i;i--)
down(lin[i]);
for(int i=;i<=len;i++)
up(lin[i]);
}
int Split(int x,int y)
{
int p=split(x,y),q=fa[p];
upd(p);
fa[p]=;c[q][c[q][]==p]=;
upd(q);
splay(q,);
return p;
}
void work(int x,int y)
{
RT=top[x];int tem;
if(cmd[]=='c')
tr[tem=split(num(x),num(y))].plus+=z,upd(tem);//?
if(cmd[]=='S')
ret+=ans.sum;
if(cmd[]=='a') ret=max(ret,(ll)ans.max);
if(cmd[]=='i') RET=min(RET,ans.min);
if(cmd[]=='v')
{
int tem=Split(num(x),num(y));
if(tem==rt) die[RT]=;
RT=n+;
if(!rt) rt=tem;
else
{
splay(Find(),);
c[rt][]=tem;fa[tem]=rt;
splay(tem,);
}
}
}
void add(int p,int q)
{
to[++E]=q;nex[E]=fir[p];fir[p]=E;
}
int build(int now,int fat)
{
size[now]=;Fa[now]=fat;dep[now]=dep[fat]+;
for(int i=fir[now];i;i=nex[i])
if(to[i]!=fat)
size[now]+=build(to[i],now);
return size[now];
}
void ins()
{
int now=rt;
while(c[now][]) now=c[now][];
c[now][]=++NODE;fa[NODE]=now;
tr[NODE]=trnode();
splay(NODE,);
}
void pou(int now,int Top)
{
top[now]=Top;
if(Top==now)
tr[++NODE]=trnode(),root[now]=NODE;
else
RT=Top,ins();
int id=;
for(int i=fir[now];i;i=nex[i])
if(to[i]!=Fa[now])
if(size[to[i]]>size[id])
id=to[i];
if(id) pou(id,Top);
for(int i=fir[now];i;i=nex[i])
if(to[i]!=Fa[now] && to[i]!=id)
pou(to[i],to[i]);
}
int main()
{
scanf("%d%d%d",&n,&m,&R);
for(int i=;i<n;i++)
scanf("%d%d",&p,&q),
add(p,q),add(q,p);
tr[].size=;
build(R,);
pou(R,R);
for(int i=;i<=m;i++)
{
if(i==)
int e=;
scanf("%s%d%d",cmd,&x,&y);
if(cmd[]=='c') scanf("%d",&z);
ret=;RET=INF;sttop=;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
work(top[x],x);
st[++sttop]=top[x];sz[sttop]=num(x);
x=Fa[top[x]];
}
if(dep[x]>dep[y])
swap(x,y);
if(cmd[]=='v')
{
RT=top[x];
inv_x=Find(num(x)-);inv_y=Find(num(y)+);
}
work(x,y);
if(cmd[]=='S' || cmd[]=='a') printf("%lld\n",ret);
if(cmd[]=='i') printf("%d\n",RET);
if(cmd[]=='v')
{
RT=n+;
tr[rt].rev^=;
int tem=Split(,num(y)-num(x)+);
RT=top[x];
if(inv_x==- && inv_y==-)
rt=tem,die[RT]=;
else
if(inv_x==-)
{
splay(inv_y,);
c[rt][]=tem;fa[tem]=rt;
splay(tem,);
}
else
if(inv_y==-)
splay(inv_x,),c[rt][]=tem,fa[tem]=rt,splay(tem,);
else
splay(inv_x,),splay(inv_y,rt),c[c[rt][]][]=tem,fa[tem]=c[rt][],splay(tem,);
for(int i=sttop;i;i--)
{
RT=n+;
int tem=Split(,sz[i]);
RT=st[i];
if(die[RT]) die[RT]=,rt=tem;
else
{
splay(Find(),);
c[rt][]=tem;fa[tem]=rt;splay(tem,);
}
}
}
root[n+]=;die[n+]=;
}
return ;
}