题意
给定一个长为\(N\)的序列\(a\),请求出所有满足下列条件的三元组\(<x,y,z>\)
- \(1\leq x < y < z \leq N\)
- \(a_x \oplus a_y < a_y \oplus a_z\)
这里的\(\oplus\)运算指按位异或
解法
异或\(+\)序列问题\(\to\) 01Trie
很自然的想到枚举\(y\),对\(1\to y-1\)与\(y+1\to N\)分别维护一颗Trie,计算符合条件的\(<x,z>\)二元组个数
我们可以发现,如果按位考虑的话,比较\(x,\)\(z\)与\(y\)的异或值大小,实际上起到决定作用的是\(x,z\)的二进制位从高到低第一个不同的位
暴力Trie上DFS的复杂度显然是不对的,可能达到指数级别
由于动态维护Trie每次只增删一条链,考虑统计这条链所带来的影响
设\(f[i][0/1]\)为从高位到低位考虑,前\(i-1\)位相同,第\(i\)位不同的\(<x,z>\)对数。其中,第二维为\(0\)代表\(x\)的第\(i\)位为\(0\),\(z\)的第\(i\)位为\(1\)(所以我们的Trie是由高位到低位建立的)
这样我们只需在Trie上动态维护\(f\)数组,统计答案时直接用\(f\)数组更新即可
代码
#include <cstdio>
#include <cctype>
#include <cstring>
using namespace std;
const int MAX_N = 1e5 + 10;
const int lg = 30;
int read();
int a[MAX_N];
long long ans;
long long f[lg + 1][2];
struct Trie {
int root, cnt;
struct node {
int sum;
int ch[2];
} t[MAX_N * lg];
void clear() {
for (int i = 1; i <= cnt; ++i)
t[i].ch[0] = t[i].ch[1] = t[i].sum = 0;
cnt = root = 1;
}
void ins(int x) {
int p = root;
for (int i = lg; i >= 0; --i) {
int c = x >> i & 1;
if (!t[p].ch[c])
t[p].ch[c] = ++cnt;
p = t[p].ch[c];
t[p].sum++;
}
}
void era(int x) {
int p = root;
for (int i = lg; i >= 0; --i) {
int c = x >> i & 1;
p = t[p].ch[c];
t[p].sum--;
}
}
} tr_a, tr_b;
// f[x][0] : pre 0 suf 1
// f[x][1] : pre 1 suf 0
void add(int x) {
int p = 1;
for (int i = lg; i >= 0; --i) {
int c = x >> i & 1;
f[i][c] += tr_b.t[tr_b.t[p].ch[c ^ 1]].sum;
p = tr_b.t[p].ch[c];
}
}
void del(int x) {
int p = 1;
for (int i = lg; i >= 0; --i) {
int c = x >> i & 1;
f[i][c ^ 1] -= tr_a.t[tr_a.t[p].ch[c ^ 1]].sum;
p = tr_a.t[p].ch[c];
}
}
int main() {
// freopen("xyz.in", "r", stdin);
// freopen("xyz.out", "w", stdout);
int T = read();
while (T--) {
int N = read();
for (int i = 1; i <= N; ++i) a[i] = read();
ans = 0;
tr_a.clear(), tr_b.clear();
memset(f, 0, sizeof f);
for (int i = 2; i <= N; ++i) tr_b.ins(a[i]);
for (int i = 2; i < N; ++i) {
del(a[i]), tr_b.era(a[i]);
add(a[i - 1]), tr_a.ins(a[i - 1]);
for (int j = lg; j >= 0; --j)
ans += f[j][a[i] >> j & 1];
}
printf("%lld\n", ans);
}
return 0;
}
int read() {
int x = 0, c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
return x;
}