Splay 模板

 

Splay 数组实现模板

#include <iostream>
using namespace std;

// Capacity of the node pool: up to 1e5 elements plus a little slack.
const int N = 100000 + 10;

/**
 * One splay-tree node, stored by index inside the global pool `Tree`.
 * Index 0 plays the role of the null node.
 */
struct Node {
   int son[2];  // son[0]: left child (smaller keys), son[1]: right child; 0 = none
   int p;       // index of the parent node; 0 for the root
   int sz;      // total multiplicity stored in this node's subtree
   int cnt;     // multiplicity of `val` at this node
   int val;     // key

   /// Turn this slot into a fresh leaf holding one copy of `v` under parent `par`.
   void init(int v, int par) {
      son[0] = 0;
      son[1] = 0;
      p = par;
      val = v;
      cnt = 1;
      sz = 1;
   }
} Tree[N];

// Index in Tree of the current root node; 0 means the tree is empty.
int root = 0;

// Splay

/// Recompute Tree[p].sz as the node's own multiplicity plus the sizes of its
/// existing children.
void update_size(int p) {
   const int left = Tree[p].son[0];
   const int right = Tree[p].son[1];

   int total = Tree[p].cnt;
   if (left != 0) total += Tree[left].sz;
   if (right != 0) total += Tree[right].sz;

   Tree[p].sz = total;
}

/**
 * Rotate node `x` one level up, over its parent.
 *
 * `x` must have a parent. The subtree that sat on the inner side of `x`
 * is handed to the old parent, the old parent becomes a child of `x`,
 * and `x` is linked in below the old grandparent (or becomes `root` when
 * there is none). Subtree sizes of the two moved nodes are refreshed.
 */
void _rotate(int x) {
   const int parent = Tree[x].p;
   const int grand = Tree[parent].p;
   const int side = (Tree[parent].son[1] == x) ? 1 : 0;  // which child x is
   const int inner = Tree[x].son[side ^ 1];              // subtree changing owner

   // Hand the inner subtree over to the old parent.
   Tree[parent].son[side] = inner;
   if (inner != 0) Tree[inner].p = parent;

   // Put the old parent underneath x.
   Tree[x].son[side ^ 1] = parent;
   Tree[parent].p = x;
   Tree[x].p = grand;

   // The parent is now the lower node, so its size must be refreshed first.
   update_size(parent);
   update_size(x);

   if (grand == 0) {
      root = x;
   } else {
      const int slot = (Tree[grand].son[1] == parent) ? 1 : 0;
      Tree[grand].son[slot] = x;
   }
}

// Index of the most recently allocated node; insert() takes ++id, so slot 0 is never handed out.
int id = 0;

/**
 * Rotate `cur` upwards until its parent is `tar`.
 *
 * With `tar == 0` the node ends up as the tree root. While a grandparent
 * below `tar` exists, a double rotation is used: if `cur` and its parent
 * are children on the same side, the parent is rotated first; otherwise
 * `cur` is rotated twice.
 */
void splay(int cur, int tar) {
   for (;;) {
      const int parent = Tree[cur].p;
      if (parent == tar) break;

      const int grand = Tree[parent].p;
      if (grand != tar) {
         const bool parentIsRight = (Tree[grand].son[1] == parent);
         const bool curIsRight = (Tree[parent].son[1] == cur);
         _rotate(parentIsRight == curIsRight ? parent : cur);
      }
      _rotate(cur);
   }

   if (tar == 0) root = cur;
}

/**
 * Insert one copy of `v`.
 *
 * If a node with key `v` already exists its multiplicity is bumped;
 * otherwise a new leaf is allocated from the pool and attached where the
 * search fell off the tree. The touched node is then splayed to the root,
 * which also refreshes the subtree sizes of the nodes it rotates past.
 */
void insert(int v) {
   int parent = 0;
   int node = root;

   // Plain BST descent: stop on a match or when falling off the tree.
   while (node != 0) {
      if (Tree[node].val == v) break;
      parent = node;
      node = Tree[node].son[Tree[node].val < v];
   }

   if (node == 0) {
      node = ++id;
      Tree[node].init(v, parent);
      if (parent != 0) Tree[parent].son[Tree[parent].val < v] = node;
   } else {
      ++Tree[node].cnt;
      ++Tree[node].sz;
   }

   splay(node, 0);
}

/**
 * Find the node with the largest key strictly smaller than `v` in the
 * subtree rooted at `cur`.
 *
 * @param cur  index of the subtree root (0 for an empty subtree)
 * @param v    key to compare against
 * @return     index of the predecessor node, or 0 if no key is smaller than `v`
 *
 * Does not splay. Written as a loop so that the stack depth stays constant
 * even when the tree is a long chain.
 */
int find_pre(int cur, int v) {
   int best = 0;
   while (cur != 0) {
      if (Tree[cur].val < v) {
         // Candidate; anything better can only be further to the right.
         best = cur;
         cur = Tree[cur].son[1];
      } else {
         cur = Tree[cur].son[0];
      }
   }
   return best;
}

/**
 * Find the node with the smallest key strictly greater than `v` in the
 * subtree rooted at `cur`.
 *
 * @return index of the successor node, or 0 if no key is greater than `v`.
 *         Does not splay.
 */
int find_nxt(int cur, int v) {
   int best = 0;
   for (int node = cur; node != 0;) {
      if (Tree[node].val > v) {
         // Candidate; anything better can only be further to the left.
         best = node;
         node = Tree[node].son[0];
      } else {
         node = Tree[node].son[1];
      }
   }
   return best;
}

/**
 * Remove one copy of `v` from the tree; does nothing if `v` is absent.
 *
 * The strict predecessor of `v` is splayed to the root and the strict
 * successor is splayed to be the root's right child. Every key strictly
 * between the two, which can only be `v`, then sits in the successor's
 * left subtree, so that subtree is either empty or the single node for `v`.
 *
 * Both neighbours must exist; main() guarantees this with the +/-1e9
 * sentinels. If either one is missing the function returns without
 * removing anything, because splaying the null node would reset `root`
 * or rotate slot 0 and corrupt the tree.
 */
void erase(int v) {
   const int lower = find_pre(root, v);
   if (lower == 0) return;  // no strict predecessor: nothing safe to do
   splay(lower, 0);

   const int upper = find_nxt(root, v);
   if (upper == 0) return;  // no strict successor: nothing safe to do
   splay(upper, lower);

   const int target = Tree[upper].son[0];

   if (target == 0 || Tree[target].cnt == 1) {
      // Absent, or last copy: detach the node. The splay below rotates
      // `upper` over `lower`, which recomputes both of their sizes.
      Tree[upper].son[0] = 0;
      splay(upper, 0);
      return;
   }

   // More than one copy: drop one and let the splay refresh the ancestors.
   --Tree[target].cnt;
   --Tree[target].sz;
   splay(target, 0);
}

/**
 * Return the key of rank `rk` (1-based, duplicates counted individually,
 * over every key stored in the tree including any sentinels).
 *
 * The node that holds the answer is splayed to the root.
 *
 * @param rk  1-based rank
 * @return    the key at that rank, or 0 when `rk` is not in
 *            [1, number of stored elements] (including the empty tree)
 *
 * The loop stops as soon as the walk falls off the tree, so an out-of-range
 * rank returns 0 instead of spinning forever on the null node.
 */
int querry_val(int rk) {
   int now = root;

   while (now != 0 && rk > 0) {
      const int left = Tree[now].son[0];
      const int Lsize = left ? Tree[left].sz : 0;

      if (rk <= Lsize) {
         now = left;
      } else if (rk <= Lsize + Tree[now].cnt) {
         splay(now, 0);
         return Tree[now].val;
      } else {
         // Skip the left subtree and this node's copies.
         rk -= Lsize + Tree[now].cnt;
         now = Tree[now].son[1];
      }
   }
   return 0;
}

/**
 * Return 1 + the number of stored elements strictly smaller than `x`
 * (duplicates counted individually, sentinels included).
 *
 * `x` does not have to be present. The node holding `x`, or otherwise the
 * last node visited during the search, is splayed to the root.
 */
int querry_rank(int x) {
   int node = root;
   int last = 0;     // last node visited before falling off / matching
   int smaller = 0;  // elements known to be < x

   while (node != 0 && Tree[node].val != x) {
      last = node;
      if (x < Tree[node].val) {
         node = Tree[node].son[0];
      } else {
         // This node and its whole left subtree are smaller than x.
         const int left = Tree[node].son[0];
         smaller += (left ? Tree[left].sz : 0) + Tree[node].cnt;
         node = Tree[node].son[1];
      }
   }

   if (node != 0) {
      const int left = Tree[node].son[0];
      if (left != 0) smaller += Tree[left].sz;
      splay(node, 0);
   } else if (last != 0) {
      splay(last, 0);
   }
   return smaller + 1;
}

/**
 * Return the smallest key strictly greater than `x` in the subtree rooted
 * at `cur`, or 1e9 if there is none. The node found is splayed to the root.
 */
int nxt(int cur, int x) {
   int best = 0;
   for (int node = cur; node != 0;) {
      if (Tree[node].val > x) {
         best = node;
         node = Tree[node].son[0];
      } else {
         node = Tree[node].son[1];
      }
   }

   if (best == 0) return 1e9;
   splay(best, 0);
   return Tree[best].val;
}

/**
 * Return the largest key strictly smaller than `x` in the subtree rooted
 * at `cur`, or -1e9 if there is none. The node found is splayed to the root.
 */
int pre(int cur, int x) {
   int best = 0;
   for (int node = cur; node != 0;) {
      if (Tree[node].val < x) {
         best = node;
         node = Tree[node].son[1];
      } else {
         node = Tree[node].son[0];
      }
   }

   if (best == 0) return -1e9;
   splay(best, 0);
   return Tree[best].val;
}

/**
 * Driver: reads the number of operations, then `op x` pairs.
 *   1 insert x, 2 erase one x, 3 print rank of x, 4 print key of rank x,
 *   5 print predecessor of x, 6 print successor of x.
 */
int main() {
   std::ios_base::sync_with_stdio(false);
   std::cin.tie(nullptr);

   int t;
   std::cin >> t;

   // Sentinels are mandatory: erase() needs a key on each side of every value.
   insert(1000000000);
   insert(-1000000000);

   while (t--) {
      int op, x;
      std::cin >> op >> x;
      switch (op) {
         case 1:
            insert(x);
            break;
         case 2:
            erase(x);
            break;
         case 3:
            // -1 discounts the lower sentinel.
            std::cout << querry_rank(x) - 1 << '\n';
            break;
         case 4:
            // +1 skips the lower sentinel.
            std::cout << querry_val(x + 1) << '\n';
            break;
         case 5:
            std::cout << pre(root, x) << '\n';
            break;
         case 6:
            std::cout << nxt(root, x) << '\n';
            break;
         default:
            break;
      }
   }
   return 0;
}
algorithm   ACM_template