犯傻了,想到了如果是 0 的话就找最深的非 1 编号,然后是 1 的话就找最深的非 0 编号. 

但是没有想到可以直接拿 LCT + 下穿标记来维护这个东西. 

code: 

#include <bits/stdc++.h>
using namespace std;
#define N 1600000
#define lson t[x].ch[0]
#define rson t[x].ch[1]
#define get(x)  (t[t[x].f].ch[1]==x)
#define isrt(x) (!(t[t[x].f].ch[0]==x||t[t[x].f].ch[1]==x))
#define setIO(s) freopen(s".in","r",stdin)
int n,m,edges;
int hd[N],to[N],nex[N],sta[N];
struct node
{
    int ch[2],id[3],f,add,val,sum;
}t[N];
inline void mark(int x,int add)
{
    if(!x) return;
    t[x].sum+=add,t[x].val=t[x].sum>1;
    swap(t[x].id[1],t[x].id[2]);
    t[x].add+=add;
}
inline void pushup(int x)
{
    t[x].id[1]=t[rson].id[1];
    t[x].id[2]=t[rson].id[2];
    if(!t[x].id[1])
    {
        if(t[x].sum!=1)  t[x].id[1]=x;
        else t[x].id[1]=t[lson].id[1];
    }
    if(!t[x].id[2])
    {
        if(t[x].sum!=2)  t[x].id[2]=x;
        else t[x].id[2]=t[lson].id[2];
    }
}
inline void pushdown(int x)
{
    if(t[x].add)
    {
        if(lson)   mark(lson,t[x].add);
        if(rson)   mark(rson,t[x].add);
        t[x].add=0;
    }
}
inline void rotate(int x)
{
    int old=t[x].f,fold=t[old].f,which=get(x);
    if(!isrt(old))  t[fold].ch[t[fold].ch[1]==old]=x;
    t[old].ch[which]=t[x].ch[which^1],t[t[old].ch[which]].f=old;
    t[x].ch[which^1]=old,t[old].f=x,t[x].f=fold;
    pushup(old),pushup(x);
}
void splay(int x)
{
    int u=x,v=0,fa;
    for(sta[++v]=u;!isrt(u);u=t[u].f)   sta[++v]=t[u].f;
    for(;v;--v)   pushdown(sta[v]);
    for(u=t[u].f;(fa=t[x].f)!=u;rotate(x))
    {
        if(t[fa].f!=u)
        {
            rotate(get(fa)==get(x)?fa:x);
        }
    }
}
inline void Access(int x)
{
    for(int y=0;x;y=x,x=t[x].f)
    {
        splay(x);
        rson=y;
        pushup(x);
    }
}
void add(int u,int v)
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
void dfs(int u)
{
    for(int i=hd[u];i;i=nex[i])    dfs(to[i]),t[u].sum+=t[to[i]].val;
    if(u<=n) t[u].val=t[u].sum>1;
}
int main()
{
    // setIO("input");
    int i,j,Q;
    scanf("%d",&n);
    for(i=1;i<=n;++i)
    {
        for(j=1;j<=3;++j)
        {
            int x;
            scanf("%d",&x),add(i,x),t[x].f=i;
        }
    }
    for(i=n+1;i<=3*n+1;++i)    scanf("%d",&t[i].val);
    dfs(1);
    scanf("%d",&Q);
    int ans=t[1].val;
    while(Q--)
    {
        int x;
        scanf("%d",&x);
        int tag=(t[x].val?-1:1),y=t[x].f;
        Access(y),splay(y);
        int w=t[y].id[t[x].val?2:1];
        if(w)
        {
            splay(w);
            mark(t[w].ch[1],tag);
            pushup(t[w].ch[1]);
            t[w].sum+=tag;
            t[w].val=t[w].sum>1;
            pushup(w);
        }
        else
        {
            ans^=1;
            mark(y,tag);
            pushup(y);
        }
        t[x].val^=1;
        printf("%d\n",ans);
    }
    return 0;
}

  

12-14 05:48
查看更多