玄关求条,WA on #2
查看原帖
玄关求条,WA on #2
1024254
CloverLava楼主2024/12/12 19:32
#include<bits/stdc++.h>
using namespace std;
// N   : capacity of the node pool g[] (index 0 is reserved as the null node).
// INF : magnitude of the two sentinel keys (+INF / -INF) that main() inserts.
const int N = 100005, INF = 2000000000;

// One splay-tree node. Nodes are addressed by their index in g[];
// index 0 means "no node", and g[0] stays all-zero so that its size is 0.
struct Node
{
	int size, cnt, v;	// subtree total (see pushup), multiplicity of v, key
	int p, s[2];		// parent index, children: s[0] = left, s[1] = right

	// Turns this slot into a fresh leaf holding one copy of _v under parent _p.
	// cnt and the child links are set as well, so the node satisfies
	// size == size(left) + size(right) + cnt right after initialisation.
	void init(int _v, int _p)
	{
		v = _v, p = _p;
		cnt = 1;
		size = 1;
		s[0] = s[1] = 0;
	}
}g[N];	// static storage: every field starts at 0

// n: number of operations still to read, root: index of the current root,
// cntt: number of pool slots handed out so far.
int n, root, cntt;
void pushup(int x)
{
	g[x].size = g[g[x].s[0]].size + g[g[x].s[1]].size + g[x].cnt;
}
// Rotates x one level up: x takes the place of its parent y under the
// grandparent z, y becomes a child of x on the opposite side, and the
// subtree of x that sat between them is handed over to y.
// Index 0 is the null node and is never written to, so an absent
// grandparent or an empty moved subtree leaves g[0] untouched.
void rotate(int x)
{
	int y = g[x].p, z = g[y].p;
	int k = g[y].s[1] == x;		// side of y on which x hangs
	int mid = g[x].s[k ^ 1];	// subtree of x that moves over to y (may be 0)

	if(z) g[z].s[g[z].s[1] == y] = x;	// y was the root when z == 0
	g[x].p = z;

	g[y].s[k] = mid;
	if(mid) g[mid].p = y;

	g[x].s[k ^ 1] = y, g[y].p = x;

	// y is now below x, so its size has to be refreshed first
	pushup(y), pushup(x);
}
// Rotates x upwards until its parent is k. With k == 0 the node ends up
// without a parent and is recorded as the new root.
void splay(int x, int k)
{
	for(int y = g[x].p; y != k; y = g[x].p)
	{
		int z = g[y].p;
		if(z != k)
		{
			bool xRight = g[y].s[1] == x;
			bool yRight = g[z].s[1] == y;
			// x and y on different sides: rotate x twice;
			// same side: rotate the parent first, then x
			rotate(xRight != yRight ? x : y);
		}
		rotate(x);
	}
	if(k == 0) root = x;
}
// Walks down from the root looking for key v and splays the node where the
// walk stops to the root: the node holding v if there is one, otherwise the
// last node reached before the search would step onto a missing child.
void upper(int v)
{
	int u = root;
	for(;;)
	{
		if(g[u].v == v) break;
		int child = g[u].s[v > g[u].v];
		if(!child) break;
		u = child;
	}
	splay(u, 0);
}
// Returns the index of the node with the largest key strictly below v.
// After upper(v) the root is either that node itself (its key is below v)
// or the answer is the right-most node of the root's left subtree.
int get_prev(int v)
{
	upper(v);
	if(v > g[root].v) return root;
	int u = g[root].s[0];
	for(int r = g[u].s[1]; r; r = g[u].s[1]) u = r;
	return u;
}
// Returns the index of the node with the smallest key strictly above v.
// After upper(v) the root is either that node itself (its key is above v)
// or the answer is the left-most node of the root's right subtree.
int get_next(int v)
{
	upper(v);
	if(v < g[root].v) return root;
	int u = g[root].s[1];
	for(int l = g[u].s[0]; l; l = g[u].s[0]) u = l;
	return u;
}
// Returns 1 + the number of stored keys (counting duplicates and any
// sentinel) that are strictly smaller than v. v itself need not be stored.
int get_rank_by_val(int v)
{
	upper(v);
	int smaller = g[g[root].s[0]].size;
	// When v is absent, upper() may have stopped on a key below v;
	// every copy held by the root is then smaller than v as well.
	if(g[root].v < v) smaller += g[root].cnt;
	return smaller + 1;
}
// Returns the k-th smallest stored key (1-based, duplicates counted) and
// splays its node to the root. Returns -1 when k is not in [1, size of tree].
int get_val_by_rank(int k)
{
	// A non-positive rank would satisfy "size >= k" even at the null node
	// and keep the loop below running forever.
	if(k < 1) return -1;
	int u = root;
	while(g[u].size >= k)
	{
		int left = g[g[u].s[0]].size;
		if(left >= k) u = g[u].s[0];
		else if(left + g[u].cnt >= k)
		{
			splay(u, 0);
			return g[u].v;
		}
		else k -= left + g[u].cnt, u = g[u].s[1];	// skip left subtree and u itself
	}
	return -1;
}
void insert(int v)
{
	int u = root, p = 0;
	while(u && g[u].v != v) p = u, u = g[u].s[v > g[u].v];
	if(u) g[u].cnt++;
	else
	{
		u = ++cntt;
		if(p) g[p].s[v > g[p].v] = u;
		g[u] = {1, 1, v, p};
	}
	splay(u, 0);
}
void remove(int v)
{
	int prev = get_prev(v), next = get_next(v);
	splay(prev, 0), splay(next, prev);
	int w = g[next].s[0];
	if(g[w].cnt > 1) g[w].cnt--, splay(w, 0);
	else g[next].s[0] = 0, splay(next, 0);
}
void output(int u)
{
	if(g[u].s[0]) output(g[u].s[0]);
	if(g[u].v != -INF && g[u].v != INF) printf("%d ", g[u].v);
	if(g[u].s[1]) output(g[u].s[1]); 
}
// Reads the number of operations, then one "op x" pair per operation and
// dispatches it. Two sentinel keys are inserted first; the -1 / +1 on
// operations 3 and 4 compensate for the -INF sentinel.
int main()
{
	int op, x;
	insert(INF);
	insert(-INF);
	scanf("%d", &n);
	while(n-- != 0)
	{
		scanf("%d%d", &op, &x);
		if(op == 1) insert(x);
		else if(op == 2) remove(x);
		else if(op == 3) printf("%d\n", get_rank_by_val(x) - 1);
		else if(op == 4) printf("%d\n", get_val_by_rank(x+1));
		else if(op == 5) printf("%d\n", g[get_prev(x)].v);
		else if(op == 6) printf("%d\n", g[get_next(x)].v);
	}
	return 0;
}
2024/12/12 19:32
加载中...