50 pts 可持久化 0-1 trie 求调
查看原帖
50 pts 可持久化 0-1 trie 求调
868864
zjinze楼主2025/1/21 17:02
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 5e5 + 7;
int n, k, a[N], tot = 1, rt[N], sum[N], ans;
struct node {
    int l, r, siz, id;
} trie[N * 35 * 2];
int copy(int x) {
    ++tot;
    trie[tot] = trie[x];
    return tot;
}
void insert(int lstx, int nowx, int j, int val, int id) {
    if (j == -1) {
        trie[nowx].id = id;
        return ;
    }

    if ((val >> j) & 1) {
        trie[nowx].l = trie[lstx].l;
        trie[nowx].r = copy(trie[lstx].r);
        trie[trie[nowx].r].siz++;
        trie[nowx].id = id;
        insert(trie[lstx].r, trie[nowx].r, j - 1, val, id);
    } else {
        trie[nowx].r = trie[lstx].r;
        trie[nowx].l = copy(trie[lstx].l);
        trie[trie[nowx].l].siz++;
        trie[nowx].id = id;
        insert(trie[lstx].l, trie[nowx].l, j - 1, val, id);
    }


    return ;
}
int query(int lstx, int nowx, int j, int val) {
    if (j == -1)
        return trie[nowx].id;

    if ((val >> j) & 1) {
        if (trie[trie[nowx].l].siz - trie[trie[lstx].l].siz > 0) {
            return query(trie[lstx].l, trie[nowx].l, j - 1, val);
        } else {
            return query(trie[lstx].r, trie[nowx].r, j - 1, val);
        }
    } else {
        if (trie[trie[nowx].r].siz - trie[trie[lstx].r].siz > 0) {
            return query(trie[lstx].r, trie[nowx].r, j - 1, val);
        } else {
            return query(trie[lstx].l, trie[nowx].l, j - 1, val);
        }
    }

    return 0;
}
struct Q {
    int l, r, st, ed, val;
    bool friend operator < (Q a, Q b) {
        return a.val < b.val;
    }
    Q(int VAL, int L, int R, int ST, int ED) {
        val = VAL, l = L, r = R, st = ST, ed = ED;
    }
};
priority_queue<Q>que;
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n >> k;

    for (int i = 1; i <= n; i++)
        cin >> a[i], sum[i] = sum[i - 1] ^ a[i];

    rt[0] = ++tot;
    insert(0, rt[0], 33, 0, 0);

    for (int i = 1; i <= n; i++) {
        rt[i] = ++tot;
        insert(rt[i - 1], rt[i], 33, sum[i], i);
        int j = query(0, rt[i - 1], 33, sum[i]);
        que.push(Q{sum[i]^sum[j], j, i, 0, i - 1});
    }

    while (k--) {
        ans += que.top().val;
        int l = que.top().l, r = que.top().r, st = que.top().st, ed = que.top().ed;
        que.pop();

        if (l != st) {
            int j = query(st == 0 ? 0 : rt[st], rt[l - 1], 33, sum[r]);
            que.push(Q{sum[r]^sum[j], j, r, st, l - 1});
        }

        if (l != ed) {
            int j = query(rt[l + 1], rt[ed], 33, sum[r]);
            que.push(Q{sum[r]^sum[j], j, r, l + 1, ed});
        }
    }


    cout << ans << "\n";
    return 0;
}
2025/1/21 17:02
加载中...