https://www.luogu.org/problemnew/show/P4647
首先发现答案与顺序无关,令 $ x_i $ 表示高度为 $ i $ 的那一行帆的个数,第 $ i $ 行对答案的贡献为 $ \frac{x_i * (x_i - 1)}{2} $
先把旗杆按照高度从小到大排序,有一个显然的贪心是每次选择能放的地方帆最少的一行放一个帆,最少的一行放一个帆对答案的贡献一定最小,而后面的旗杆高度更高,能选择放帆的地方更多,这样可以保证答案最小(可以感性理解一下
那么我们的做法就出来了,对于旗杆 $ i $ 选择当前能放帆的最少的 $ k_i $ 行,放上帆,这个操作用平衡树实现,将原序列前 $ k $ 小拿出,加上 $ 1 $ 后放回去,这个时候可能会使平衡树的性质被打破,如 $ 0 $ $ 0 $ 将第一个数加 $ 1 $ 再放回去就变成了 $ 1 $ $ 0 $,所以对于分裂出来的前 $ k $ 小中的最大值还要再次分裂,如 $ 0 $ $ 1 $ $ 1 $ $ 1 $ $ 2 $ 前 $ 3 $ 小要加 $ 1 $,先分裂成 $ 0 $ $ 1 $ $ 1 $ 和 $ 1 $ $ 2 $,然后分裂成 $ 0 $ ,$ 1 $ $ 1 $ 和 $ 1 $ $ 2 $,区间加后放回原序列,先变成 $ 1 $ {这里放原来的 $ 1 $ $ 1 $ } $ 2 $,然后变成 {这里放原来的 $ 0 $} $ 1 $ {这里放原来的 $ 1 $ $ 1 $ } $ 2 $,即 $ 0 $ $ 1 $ $ 1 $ $ 1 $ $ 2 $,可以使用 splay 实现
#include <bits/stdc++.h>
#define Fast_cin ios::sync_with_stdio(false), cin.tie(0);
#define rep(i, a, b) for(register int i = a; i <= b; i++)
#define per(i, a, b) for(register int i = a; i >= b; i--)
#define DEBUG(x) cerr << "DEBUG" << x << " >>> " << endl;
using namespace std;
typedef unsigned long long ull;
typedef pair <int, int> pii;
typedef long long ll;
template <typename _T>
inline void read(_T &f) {
f = 0; _T fu = 1; char c = getchar();
while(c < '0' || c > '9') { if(c == '-') fu = -1; c = getchar(); }
while(c >= '0' && c <= '9') { f = (f << 3) + (f << 1) + (c & 15); c = getchar(); }
f *= fu;
}
template <typename T>
void print(T x) {
if(x < 0) putchar('-'), x = -x;
if(x < 10) putchar(x + 48);
else print(x / 10), putchar(x % 10 + 48);
}
template <typename T>
void print(T x, char t) {
print(x); putchar(t);
}
template <typename T>
struct hash_map_t {
vector <T> v, val, nxt;
vector <int> head;
int mod, tot, lastv;
T lastans;
bool have_ans;
hash_map_t (int md = 0) {
head.clear(); v.clear(); val.clear(); nxt.clear(); tot = 0; mod = md;
nxt.resize(1); v.resize(1); val.resize(1); head.resize(mod);
have_ans = 0;
}
bool count(int x) {
int u = x % mod;
for(register int i = head[u]; i; i = nxt[i]) {
if(v[i] == x) {
have_ans = 1;
lastv = x;
lastans = val[i];
return 1;
}
}
return 0;
}
void ins(int x, int y) {
int u = x % mod;
nxt.push_back(head[u]); head[u] = ++tot;
v.push_back(x); val.push_back(y);
}
int qry(int x) {
if(have_ans && lastv == x) return lastans;
count(x);
return lastans;
}
};
const int N = 1e5 + 5;
struct ele {
int h, gs;
bool operator < (const ele A) const { return h < A.h; }
} a[N];
int ch[N][2], val[N], tag[N], siz[N], n, root;
inline void update(int u) { siz[u] = siz[ch[u][0]] + siz[ch[u][1]] + 1; }
inline void add_tag(int u, int x) { if(u <= 100000) val[u] += x; tag[u] += x; }
inline void pushdown(int u) {
if(tag[u]) {
if(ch[u][0]) add_tag(ch[u][0], tag[u]);
if(ch[u][1]) add_tag(ch[u][1], tag[u]);
tag[u] = 0;
}
}
inline void rotate(int &u, int d) {
int tmp = ch[u][d];
ch[u][d] = ch[tmp][d ^ 1];
ch[tmp][d ^ 1] = u;
update(u); update(tmp);
u = tmp;
}
void splay(int &u, int k) {
pushdown(u);
int ltree = siz[ch[u][0]];
if(ltree + 1 == k) return;
int d = k > ltree;
pushdown(ch[u][d]);
int k2 = d ? k - ltree - 1 : k;
int ltree2 = siz[ch[ch[u][d]][0]];
if(ltree2 + 1 != k2) {
int d2 = k2 > ltree2;
splay(ch[ch[u][d]][d2], d2 ? k2 - ltree2 - 1 : k2);
if(d == d2) rotate(u, d); else rotate(ch[u][d], d2);
}
rotate(u, d);
}
int find(int u, int x) {
pushdown(u);
if(!u) return 0;
if(x > val[u]) return siz[ch[u][0]] + 1 + find(ch[u][1], x);
return find(ch[u][0], x);
}
void insert(int &u, int x, int y) {
splay(u, x + 1); splay(ch[u][0], x);
ch[ch[u][0]][1] = y; update(ch[u][0]); update(u);
}
ll ans;
void dfs(int u) {
if(!u) return;
pushdown(u);
dfs(ch[u][0]);
if(u <= 100000) ans += 1ll * val[u] * (val[u] - 1) / 2;
dfs(ch[u][1]);
}
int main() {
read(n);
for(register int i = 1; i <= n; i++) read(a[i].h), read(a[i].gs);
sort(a + 1, a + n + 1); int now = 1; root = 1;
ch[root][0] = 100001; ch[root][1] = 100002;
val[100001] = -1; val[100002] = 100002;
update(100001); update(100002); update(root);
for(register int i = 1; i <= n; i++) {
while(now < a[i].h) {
++now; update(now);
insert(root, 1, now);
}
splay(root, a[i].gs + 2); splay(ch[root][0], a[i].gs + 1);
int left = find(ch[root][0], val[ch[root][0]]), v = val[ch[root][0]], l = ch[root][0];
ch[root][0] = 0; update(root);
int right = find(root, v + 1);
if(!right) {
add_tag(l, 1);
ch[root][0] = l;
update(root);
} else {
// fprintf(stderr, "left = %d\n", left);
splay(l, left + 1);
int ll = ch[l][0]; ch[l][0] = 0; update(l);
add_tag(ll, 1); add_tag(l, 1);
insert(root, right, l);
splay(root, 1); ch[root][0] = ll; update(root);
}
}
dfs(root);
cout << ans << endl;
return 0;
}