splay求条
查看原帖
splay求条
730195
Little_Cabbage楼主2025/1/20 19:44
#include <bits/stdc++.h>
#define int long long
using namespace std;

// Capacity of the node pool. Slot 0 is never handed out: a link equal to 0
// means "no node" everywhere in this file.
const int N = 1000000 + 10;

// Index of the current root in tr[] (0 while the tree is empty).
int root;
// Number of pool slots handed out so far; the newest node is tr[tot].
int tot;

// One node of the array-backed splay tree. All links are indices into tr[].
struct Splay {
    int val;  // key stored in the node
    int fa;   // parent index, 0 for the root
    int ls;   // left child index, 0 if absent
    int rs;   // right child index, 0 if absent
    int num;  // extra copies of val beyond the first one
    int ln;   // weight of the left subtree, maintained by Insert/rotations/Delete
    int rn;   // weight of the right subtree, maintained the same way
};
Splay tr[N];

// Debug switch; only referenced from commented-out tracing code.
bool debug;

// Right rotation: lifts x above its parent y. x must be y's left child.
// x's former right subtree becomes y's left subtree, y becomes x's right
// child, and x takes y's place under y's former parent (if any).
// The ln/rn subtree weights of both nodes are rebuilt afterwards.
void zig(int x) {
    int y = tr[x].fa;
    int g = tr[y].fa;   // grandparent, 0 when y was the root
    int mid = tr[x].rs; // subtree that moves from x over to y

    tr[y].ls = mid;
    if (mid) tr[mid].fa = y;

    tr[x].fa = g;
    if (g) {
        if (y == tr[g].ls) tr[g].ls = x;
        else tr[g].rs = x;
    }

    tr[x].rs = y;
    tr[y].fa = x;

    // y's left side is now x's old right subtree.
    tr[y].ln = tr[x].rn;
    // x's right side is now the whole subtree rooted at y. That weight must
    // count y's duplicate copies (num) as well, exactly like Delete does when
    // it re-attaches a subtree; leaving num out makes ranks drift after
    // rotations over duplicated keys.
    tr[x].rn = tr[y].ln + tr[y].rn + 1 + tr[y].num;
}

// Left rotation: lifts x above its parent y. x must be y's right child.
// x's former left subtree becomes y's right subtree, y becomes x's left
// child, and x takes y's place under y's former parent (if any).
// The ln/rn subtree weights of both nodes are rebuilt afterwards.
void zag(int x) {
    int y = tr[x].fa;
    int g = tr[y].fa;   // grandparent, 0 when y was the root
    int mid = tr[x].ls; // subtree that moves from x over to y

    tr[y].rs = mid;
    if (mid) tr[mid].fa = y;

    tr[x].fa = g;
    if (g) {
        if (y == tr[g].ls) tr[g].ls = x;
        else tr[g].rs = x;
    }

    tr[x].ls = y;
    tr[y].fa = x;

    // y's right side is now x's old left subtree.
    tr[y].rn = tr[x].ln;
    // x's left side is now the whole subtree rooted at y. That weight must
    // count y's duplicate copies (num) as well, exactly like Delete does when
    // it re-attaches a subtree; leaving num out makes ranks drift after
    // rotations over duplicated keys.
    tr[x].ln = tr[y].ln + tr[y].rn + 1 + tr[y].num;
}

// Rotates node x upwards until it has no parent, then records it as root.
// With a grandparent present, two rotations are done per step: same-side
// chains rotate the parent first, opposite-side chains rotate x twice.
// Without a grandparent a single rotation finishes the job.
void splay(int x) {
    for (int parent = tr[x].fa; parent; parent = tr[x].fa) {
        const int grand = tr[parent].fa;
        const bool xIsLeft = (tr[parent].ls == x);

        if (!grand) {
            // Parent is the top node: one rotation and we are done.
            if (xIsLeft) zig(x);
            else zag(x);
            break;
        }

        const bool parentIsLeft = (tr[grand].ls == parent);
        if (xIsLeft && parentIsLeft) {
            zig(parent);
            zig(x);
        } else if (xIsLeft) {
            zig(x);
            zag(x);
        } else if (parentIsLeft) {
            zag(x);
            zig(x);
        } else {
            zag(parent);
            zag(x);
        }
    }
    root = x;
}

// Searches the tree for key x by ordinary BST descent from root.
// On a hit the matching node is splayed to the root and true is returned;
// on a miss the tree is left untouched and false is returned.
bool find(int x) {
    for (int cur = root; cur != 0;
         cur = (x < tr[cur].val) ? tr[cur].ls : tr[cur].rs) {
        if (tr[cur].val == x) {
            splay(cur);
            return true;
        }
    }
    return false;
}

// Returns the value holding 1-based rank x in sorted order, where every
// stored copy of a value (1 + num per node) occupies its own rank.
// Returns -1 when the tree is empty or x lies outside [1, total size].
// The tree is not restructured by this query.
int Find(int x) {
    int p = root;
    // Empty tree: slot 0 is the null sentinel and must not be read as a node.
    if (!p || x < 1) return -1;
    // Total size has to include the root's own duplicate copies.
    if (tr[p].ln + tr[p].rn + 1 + tr[p].num < x) return -1;

    while (p) {
        // Ranks ln+1 .. ln+1+num all belong to the copies stored at p.
        if (tr[p].ln + 1 <= x && x <= tr[p].ln + tr[p].num + 1) return tr[p].val;
        if (tr[p].ln >= x) {
            p = tr[p].ls;
        } else {
            // Skip the left subtree and every copy stored at p.
            x = x - tr[p].ln - 1 - tr[p].num;
            p = tr[p].rs;
        }
    }
    // Only reachable if the subtree weights are inconsistent; fail instead
    // of spinning forever on the null slot.
    return -1;
}

// Adds a brand-new node with key x and splays it to the root.
// Callers handle duplicates themselves (by bumping num on the existing
// node), so this routine always allocates the next pool slot.
// While descending, every node on the path gets the weight of the side the
// new key goes to (ln for left, rn for right) increased by one.
void Insert(int x) {
    int f = 0;  // last real node visited; stays 0 when the tree is empty
    for (int p = root; p; ) {
        f = p;
        if (x <= tr[p].val) {
            tr[p].ln ++ ;
            p = tr[p].ls;
        } else {
            tr[p].rn ++ ;
            p = tr[p].rs;
        }
    }

    tot ++ ;
    tr[tot].val = x;
    tr[tot].ls = tr[tot].rs = 0;
    tr[tot].num = tr[tot].ln = tr[tot].rn = 0;
    tr[tot].fa = f;

    // The tree can become empty again after deletions, so emptiness has to
    // be decided from the descent (no parent found), not from tot == 1.
    if (!f) {
        root = tot;
        return ;
    }

    if (x <= tr[f].val) tr[f].ls = tot;
    else tr[f].rs = tot;
    splay(tot);
}

// Removes the node holding key x from the tree (the whole node; callers
// deal with duplicate copies via num before calling this).
// Does nothing when x is not present.
// Strategy: splay x to the root, detach it, then splay the largest node of
// the left subtree to the top of that subtree and hang the right subtree
// under it.
void Delete(int x) {
    // Without this check a missing key would remove whatever node happens
    // to be the current root.
    if (!find(x)) return ;

    int p = root;
    int ls = tr[p].ls, rs = tr[p].rs;

    // At most one child: that child (or nothing) becomes the new root.
    if (!ls || !rs) {
        root = ls ? ls : rs;
        if (root) tr[root].fa = 0;
        return ;
    }

    // Cut the left subtree loose so splay() stops at its top.
    p = ls;
    tr[p].fa = 0;
    while (tr[p].rs) p = tr[p].rs;
    splay(p);  // p is the maximum of the left subtree, so it has no right child

    tr[p].rs = rs;
    tr[rs].fa = p;
    // Weight of the re-attached subtree: both sides, the node itself and its
    // duplicate copies.
    tr[p].rn = tr[rs].ln + tr[rs].rn + 1 + tr[rs].num;
}

// Returns the largest stored value strictly below the root's key after
// looking x up (find splays x to the root when it is present), or -1 when
// the root has no left subtree.
int Pred(int x) {
    find(x);
    int cur = tr[root].ls;
    // Walk to the rightmost node of the left subtree.
    while (cur && tr[cur].rs) cur = tr[cur].rs;
    return cur ? tr[cur].val : -1;
}

// Returns the smallest stored value in the root's right subtree after
// looking x up (find splays x to the root when it is present), or -1 when
// the root has no right subtree.
int Succ(int x) {
    find(x);
    int cur = tr[root].rs;
    // Walk to the leftmost node of the right subtree.
    while (cur && tr[cur].ls) cur = tr[cur].ls;
    return cur ? tr[cur].val : -1;
}

int n;

// Reads n operations "op x" from standard input and answers them:
//   1 x  insert one copy of x
//   2 x  remove one copy of x (ignored when x is not stored)
//   3 x  print the rank of x (number of smaller elements + 1)
//   4 x  print the value at rank x
//   5 x  print the predecessor of x
//   any other op  print the successor of x
signed main() {
    cin >> n;
    while (n -- ) {
        int op = 0, x = 0;
        // Stop on truncated or malformed input instead of acting on garbage.
        if (!(cin >> op >> x)) break;

        if (op == 1) {
            // A repeated key only bumps the copy counter of its node.
            if (find(x)) tr[root].num ++ ;
            else Insert(x);
        } else if (op == 2) {
            // Removing an absent key must not touch an unrelated root.
            if (!find(x)) continue;
            if (tr[root].num >= 1) tr[root].num -- ;
            else Delete(x);
        } else if (op == 4) {
            cout << Find(x) << "\n";
        } else {
            // Rank / predecessor / successor are answered relative to a node
            // holding x at the root; when x is absent it is inserted for the
            // duration of the query and removed again afterwards.
            bool flag = !find(x);
            if (flag) Insert(x);

            if (op == 3) cout << tr[root].ln + 1 << "\n";
            else if (op == 5) cout << Pred(x) << "\n";
            else cout << Succ(x) << "\n";

            if (flag) Delete(x);
        }
    }
    return 0;
}
2025/1/20 19:44
加载中...