题意

给定一个长为\(N\)的序列\(a\),请求出所有满足下列条件的三元组\(<x,y,z>\)

  • \(1\leq x < y < z \leq N\)
  • \(a_x \oplus a_y < a_y \oplus a_z\)

这里的\(\oplus\)运算指按位异或


解法

异或\(+\)序列问题\(\to\) 01Trie

很自然的想到枚举\(y\),对\(1\to y-1\)\(y+1\to N\)分别维护一颗Trie,计算符合条件的\(<x,z>\)二元组个数

我们可以发现,如果按位考虑的话,比较\(x,\)\(z\)\(y\)的异或值大小,实际上起到决定作用的是\(x,z\)的二进制位从高到低第一个不同的位

暴力Trie上DFS的复杂度显然是不对的,可能达到指数级别

由于动态维护Trie每次只增删一条链,考虑统计这条链所带来的影响

\(f[i][0/1]\)为从高位到低位考虑,前\(i-1\)位相同,第\(i\)位不同的\(<x,z>\)对数。其中,第二维为\(0\)代表\(x\)的第\(i\)位为\(0\)\(z\)的第\(i\)位为\(1\)(所以我们的Trie是由高位到低位建立的)

这样我们只需在Trie上动态维护\(f\)数组,统计答案时直接用\(f\)数组更新即可


代码

#include <cstdio>
#include <cctype>
#include <cstring>

using namespace std;

const int MAX_N = 1e5 + 10;
const int lg = 30;

int read();

int a[MAX_N];

long long ans;
long long f[lg + 1][2];

struct Trie {

    int root, cnt;

    struct node {
        int sum;
        int ch[2];
    } t[MAX_N * lg];

    void clear() {
        for (int i = 1; i <= cnt; ++i)
            t[i].ch[0] = t[i].ch[1] = t[i].sum = 0;
        cnt = root = 1;
    }

    void ins(int x) {
        int p = root;
        for (int i = lg; i >= 0; --i) {
            int c = x >> i & 1;
            if (!t[p].ch[c])
                t[p].ch[c] = ++cnt;
            p = t[p].ch[c];
            t[p].sum++;
        }
    }

    void era(int x) {
        int p = root;
        for (int i = lg; i >= 0; --i) {
            int c = x >> i & 1;
            p = t[p].ch[c];
            t[p].sum--;
        }
    }

} tr_a, tr_b;

// f[x][0] : pre 0 suf 1
// f[x][1] : pre 1 suf 0
void add(int x) {
    int p = 1;
    for (int i = lg; i >= 0; --i) {
        int c = x >> i & 1;
        f[i][c] += tr_b.t[tr_b.t[p].ch[c ^ 1]].sum;
        p = tr_b.t[p].ch[c];
    }
}

void del(int x) {
    int p = 1;
    for (int i = lg; i >= 0; --i) {
        int c = x >> i & 1;
        f[i][c ^ 1] -= tr_a.t[tr_a.t[p].ch[c ^ 1]].sum;
        p = tr_a.t[p].ch[c];
    }
}

int main() {

//  freopen("xyz.in", "r", stdin);
//  freopen("xyz.out", "w", stdout);

    int T = read();

    while (T--) {

        int N = read();
        for (int i = 1; i <= N; ++i)  a[i] = read();

        ans = 0;
        tr_a.clear(), tr_b.clear();
        memset(f, 0, sizeof f);

        for (int i = 2; i <= N; ++i)  tr_b.ins(a[i]);

        for (int i = 2; i < N; ++i) {
            del(a[i]), tr_b.era(a[i]);
            add(a[i - 1]), tr_a.ins(a[i - 1]);
            for (int j = lg; j >= 0; --j)
                ans += f[j][a[i] >> j & 1];
        }

        printf("%lld\n", ans);
    }

    return 0;
}

int read() {
    int x = 0, c = getchar();
    while (!isdigit(c))  c = getchar();
    while (isdigit(c))   x = x * 10 + c - 48, c = getchar();
    return x;
}
02-10 04:28