Splay 几乎全 TLE 求条
查看原帖
Splay 几乎全 TLE 求条
867577
lucasincyber楼主2025/1/24 12:47

insert 的 else 分支出了问题，但我不会改，后面的代码可能也有错

求 dalao 调一下

#include <bits/stdc++.h>
using namespace std;

const int N = 1e5 + 5, inf = 1e7;

int n, idx, root;

struct Splay
{
	struct node
	{
		int s[2], fa, val, cnt, sz;
		void init(int f, int v)
		{
			val = v, fa = f, cnt = sz = 1;
		}
	} tr[N];
	void push_up(int p)
	{
		tr[p].sz = tr[tr[p].s[0]].sz + tr[tr[p].s[1]].sz + tr[p].cnt;
	}
	void rotate(int x)
	{
		int y = tr[x].fa, z = tr[y].fa;
		int k = (tr[y].s[1] == x);
		tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].fa = y;
		tr[x].s[k ^ 1] = y, tr[y].fa = x;
		tr[z].s[(tr[z].s[1]) == y] = x, tr[x].fa = z;
		push_up(y); push_up(x);
	}
	void splay(int x, int k)
	{
		while (tr[x].fa != k)
		{
			int y = tr[x].fa, z = tr[y].fa;
			if (z != k)
			{
				int t1 = (tr[y].s[0] == x), t2 = (tr[z].s[0] == y);
				if (t1 ^ t2) rotate(x);
				else rotate(y);
			}
			rotate(x);
		}
		if (!k) root = x;
	}
	void insert(int val)
	{
		int f = 0, x = root;
		while (x && tr[x].val != val)
		{
			f = x;
			x = tr[x].s[val > tr[x].val];
		}
		if (x) tr[x].cnt++;
		else
		{
			int p = ++idx;
			tr[f].s[val > tr[f].val] = p;
			tr[p].init(f, val);
		}
		splay(x, 0);
	}
	void find(int val)
	{
		int x = root;
		while (tr[x].s[val > tr[x].val] && tr[x].val != val) x = tr[x].s[val > tr[x].val];
		splay(x, 0);
	}
	int queryPre(int val)
	{
		find(val);
		int x = root;
		if (tr[x].val < val) return x;
		x = tr[x].s[0];
		while (tr[x].s[1]) x = tr[x].s[1];
		splay(x, 0);
		return x;
	}
	int querySuc(int val)
	{
		find(val);
		int x = root;
		if (tr[x].val > val) return x;
		x = tr[x].s[1];
		while (tr[x].s[0]) x = tr[x].s[0];
		splay(x, 0);
		return x;
	}
	void remove(int val)
	{
		int pre = queryPre(val), suc = querySuc(val);
		splay(pre, 0), splay(suc, pre);
		int x = tr[suc].s[0];
		if (tr[x].cnt > 1) tr[x].cnt--, splay(x, 0);
		else tr[suc].s[0] = 0, splay(suc, 0);
	}
	int queryRank(int val)
	{
		insert(val);
		int res = tr[tr[root].s[0]].sz;
		remove(val);
		return res;
	}
	int queryVal(int rk)
	{
		int x = root;
		while (true)
		{
			int y = tr[x].s[0];
			if (tr[y].sz + tr[x].cnt < rk)
			{
				rk -= tr[y].sz + tr[x].cnt;
				x = tr[x].s[1];
			}
			else if (tr[y].sz >= rk) x = y;
			else break;
		}
		splay(x, 0);
		return tr[x].val;
	}
} spl;

// Reads n operations of the form "opt x" from stdin and answers them using the
// global splay tree `spl`:
//   1 insert x, 2 remove one x, 3 print rank of x, 4 print the x-th smallest,
//   5 print predecessor of x, anything else print successor of x.
int main()
{
	// Sentinels guarantee that every real key has a predecessor and a successor.
	spl.insert(-inf); spl.insert(inf);
	if (scanf("%d", &n) != 1) return 0;
	while (n--)
	{
		int opt, x;
		if (scanf("%d%d", &opt, &x) != 2) break;
		if (opt == 1) spl.insert(x);
		else if (opt == 2) spl.remove(x);
		else if (opt == 3) printf("%d\n", spl.queryRank(x));
		// +1 skips the -inf sentinel, which always occupies rank 1.
		else if (opt == 4) printf("%d\n", spl.queryVal(x + 1));
		// queryPre/querySuc return a node index, so print the key stored there.
		else if (opt == 5) printf("%d\n", spl.tr[spl.queryPre(x)].val);
		else printf("%d\n", spl.tr[spl.querySuc(x)].val);
	}
	return 0;
}
2025/1/24 12:47
加载中...