BZOJ 1208 宠物收养所
我犯过的错误：删除一个节点后，没有对新的根节点调用 upt 重新计算 size，导致 size 出错！
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
// Output shorthands: emit a single space / a single newline character.
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;
// 64-bit integer alias (used for the accumulated answer).
typedef long long ll;
// Fast integer reader from stdin.
// Skips every character that is not a decimal digit; a '-' seen anywhere in that
// skipped run makes the result negative. Then it accumulates consecutive digits.
// If input ends before any digit is found, x is set to 0 and the function returns.
// (Previously the EOF case spun forever: getchar() was stored in a char and never
// compared against EOF.)
template <class T>
void read(T &x){
    int c;              // int, not char: getchar() returns EOF (a negative int) at end of input
    bool op = 0;
    x = 0;
    while((c = getchar()) != EOF && (c < '0' || c > '9'))
        if(c == '-') op = 1;
    if(c == EOF) return;  // no digits left: leave x == 0 instead of looping forever
    x = c - '0';
    while((c = getchar()) != EOF && c >= '0' && c <= '9')
        x = x * 10 + c - '0';
    if(op) x = -x;
}
// Print integer x in decimal to stdout (no trailing separator).
// Negative values are printed with a leading '-'. This also works for the minimum
// value of a signed type, where the old `x = -x` was signed overflow (UB).
template <class T>
void write(T x){
    if(x < 0){
        putchar('-');
        // Since C++11, x / 10 truncates toward zero and x % 10 lies in (-10, 0].
        // So -(x / 10) is always representable, even when -x itself is not.
        T rest = x / 10, last = x % 10;
        if(rest != 0) write(-rest);
        putchar('0' - last);
        return;
    }
    if(x >= 10) write(x / 10);
    putchar('0' + x % 10);
}
// N: capacity of the node pool. Node ids 1..idx are allocated; id 0 acts as the null node.
// NOTE(review): presumably sized for the problem's maximum number of operations (80000) — confirm.
// P: modulus applied to the accumulated answer.
const int N = 80005, P = 1000000;
// Splay tree stored in parallel arrays indexed by node id:
//   fa = parent, ls / rs = left / right child, sze = subtree size, val = key.
// root: current root (0 = empty tree); idx: most recently allocated node id.
// n: number of operations; flag: op type of the entries currently stored in the tree.
int n, flag, root, idx, fa[N], ls[N], rs[N], sze[N], val[N];
// Accumulated answer, kept reduced modulo P.
ll ans;
// which(u): true iff u is the left child of its parent.
#define which(u) (ls[fa[(u)]] == (u))
void debug(){
printf("root = %d\n", val[root]);
for(int i = 1; i <= idx; i++)
printf("val = %d, fa = %d, ls = %d, rs = %d, sze = %d\n", val[i], val[fa[i]], val[ls[i]], val[rs[i]], sze[i]);
}
void upt(int u){
sze[u] = sze[ls[u]] + sze[rs[u]] + 1;
}
// Single rotation: lift node u one level above its parent v, preserving the
// in-order sequence of keys. Assumes u has a parent.
// Recomputes the sizes of v and then u. The grandparent w keeps the same set of
// descendants, so its size needs no update.
void rotate(int u){
// b is u's "inner" child, which gets re-attached under v.
int v = fa[u], w = fa[v], b = which(u) ? rs[u] : ls[u];
// Put u into v's old slot under w. Skip this when v was the root, so node 0's links stay clean.
if(w) which(v) ? ls[w] = u : rs[w] = u;
// fa[u] is still v here, so which(u) still tells which side of v u hangs on.
which(u) ? (ls[v] = b, rs[u] = v) : (rs[v] = b, ls[u] = v);
fa[u] = w, fa[v] = u;
if(b) fa[b] = v;
// v is now u's child, so v's size must be recomputed before u's.
upt(v), upt(u);
}
// Splay node u upward until its parent is tar (tar == 0 means make u the root).
// While u's grandparent is not tar, do a double step first: rotate the parent when
// u and its parent hang on the same side, otherwise rotate u. Then rotate u once more.
// When tar == 0, the global root is set to u.
void splay(int u, int tar){
    for(int p = fa[u]; p != tar; p = fa[u]){
        int g = fa[p];
        if(g != tar)
            rotate(which(u) == which(p) ? p : u);
        rotate(u);
    }
    if(tar == 0) root = u;
}
// Insert key x as a fresh node and splay it to the root.
// The descent sends smaller keys left and equal or larger keys right.
// Node ids come from the sequential counter idx and are never reused.
void insert(int x){
    int parent = 0;
    for(int cur = root; cur; cur = x < val[cur] ? ls[cur] : rs[cur])
        parent = cur;
    int node = ++idx;
    val[node] = x;
    sze[node] = 1;
    fa[node] = parent;
    if(parent){
        int &slot = x < val[parent] ? ls[parent] : rs[parent];
        slot = node;
    }
    splay(node, 0);
}
int find(int x){
int u = root, v = 0;
while(u && val[u] != x){
v = u;
if(x < val[u]) u = ls[u];
else u = rs[u];
}
return u ? u : v;
}
// Leftmost (minimum-key) node of the subtree rooted at u. Returns 0 when u is 0.
int getmin(int u){
    for(; ls[u]; u = ls[u]) {}
    return u;
}
// Rightmost (maximum-key) node of the subtree rooted at u. Returns 0 when u is 0.
int getmax(int u){
    for(; rs[u]; u = rs[u]) {}
    return u;
}
// Node with the largest key <= x (0 if there is none).
// The search endpoint is splayed to the root first. The answer is then either the
// root itself or the maximum of its left subtree.
int getpre(int x){
    int top = find(x);
    splay(top, 0);
    return val[top] > x ? getmax(ls[top]) : top;
}
// Node with the smallest key >= x (0 if there is none).
// The search endpoint is splayed to the root first. The answer is then either the
// root itself or the minimum of its right subtree.
int getnxt(int x){
    int top = find(x);
    splay(top, 0);
    return val[top] < x ? getmin(rs[top]) : top;
}
void erase(int u){
splay(u, 0);
if(sze[u] == 1) root = 0;
else if(!ls[u] || !rs[u]) root = ls[u] + rs[u], fa[root] = 0;
else{
fa[ls[u]] = 0;
int v = getmax(ls[u]);
splay(v, 0);
rs[v] = rs[u], fa[rs[u]] = v, upt(root);
}
}
int main(){
read(n);
while(n--){
int op, x;
read(op), read(x);
if(!sze[root]){
insert(x);
flag = op;
}
else if(flag ^ op){
int u = getpre(x), v = getnxt(x);
if(u && (!v || x - val[u] <= val[v] - x))
ans = (ans + (x - val[u]) % P) % P, erase(u);
else
ans = (ans + (val[v] - x) % P) % P, erase(v);
}
else insert(x);
}
write(ans), enter;
return 0;
}