Treap 86pts求条
查看原帖
Treap 86pts求条
924402
yangjinqian楼主2025/1/24 20:08

rt,WA on #1 #2

#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10, INF = 0x3f3f3f3f;
int n, root, idx;
struct node{
	int l, r;
	int len, cnt;
	int val, key;
}tr[N];
void pushup(int p){
	tr[p].len = tr[tr[p].l].len + tr[tr[p].r].len + tr[p].cnt;
}
int get_node(int key){
	tr[++idx].key = key;
	tr[idx].val = rand();
	tr[idx].cnt = tr[idx].len = 1;
	return idx;
}
void zig(int &p){
	int q = tr[p].l;
	tr[p].l = tr[q].r, tr[q].r = p, p = q;
	pushup(tr[p].r), pushup(p);
}
void zag(int &p){
	int q = tr[p].r;
	tr[p].r = tr[q].l, tr[q].l = p, p = q;
	pushup(tr[p].l), pushup(p);
}
void build(){
	root = get_node(-INF);
	tr[root].r = get_node(INF);
	pushup(root);
}
void insert(int &p, int key){
	if (!p) p = get_node(key);
	else if (tr[p].key == key) tr[p].cnt++;
	else if (tr[p].key > key){
		insert(tr[p].l, key);
		if (tr[tr[p].l].val > tr[p].val) zig(p);
	}
	else if (tr[p].key < key){
		insert(tr[p].r, key);
		if (tr[tr[p].r].val > tr[p].val) zag(p);
	}
	pushup(p);
}
void remove(int &p, int key){
	if (!p) return;
	if (tr[p].key == key){
		if (tr[p].cnt > 1) tr[p].cnt--;
		else if (tr[p].l || tr[p].r){
			if (!tr[p].r || tr[tr[p].l].val) zig(p), remove(tr[p].r, key);
			else zag(p), remove(tr[p].l, key);
		}
		else p = 0;
	}
	else if (tr[p].key > key) remove(tr[p].l, key);
	else if (tr[p].key < key) remove(tr[p].r, key);
	pushup(p);
}
int get_rank_by_key(int p, int key){
	if (!p) return 0;
	if (tr[p].key > key) return get_rank_by_key(tr[p].l, key);
	else if (tr[p].key == key) return tr[tr[p].l].len + 1;
	return get_rank_by_key(tr[p].r, key) + tr[tr[p].l].len + tr[p].cnt;
}
int get_key_by_rank(int p, int rank){
	if (!p) return INF;
	if (tr[tr[p].l].len >= rank) return get_key_by_rank(tr[p].l, rank);
	else if (tr[tr[p].l].len + tr[p].cnt >= rank) return tr[p].key;
	return get_key_by_rank(tr[p].r, rank - tr[tr[p].l].len - tr[p].cnt);
}
int get_prev(int p, int key){
	if (!p) return -INF;
	if (tr[p].key >= key) return get_prev(tr[p].l, key);
	return max(tr[p].key, get_prev(tr[p].r, key));
}
int get_next(int p, int key){
	if (!p) return INF;
	if (tr[p].key <= key) return get_next(tr[p].r, key);
	return min(tr[p].key, get_next(tr[p].l, key));
}
int main(){
	cin >> n;
	build ();
	while (n--){
		int op, x;
		cin >> op >> x;
		if (op == 1) insert(root, x);
		else if (op == 2) remove(root, x);
		else if (op == 3) cout << get_rank_by_key(root, x) - 1 << endl;
		else if (op == 4) cout << get_key_by_rank(root, x + 1) << endl;
		else if (op == 5) cout << get_prev(root, x) << endl;
		else cout << get_next(root, x) << endl;
	}
	return 0; 
}
2025/1/24 20:08
加载中...