Yuzhen's Blog

Yuzhen Qin

数据结构 #1:二叉搜索树

4
2024-06-18

定义

二叉搜索树是一种二叉树的树形数据结构,其定义如下:

  1. 空树是二叉搜索树。

  2. 若二叉搜索树的左子树不为空,则其左子树上所有点的附加权值均小于其根节点的值。

  3. 若二叉搜索树的右子树不为空,则其右子树上所有点的附加权值均大于其根节点的值。

  4. 二叉搜索树的左右子树均为二叉搜索树。

操作种类

在理想情况下,二叉搜索树可以以 O(\log n) 的复杂度完成以下操作:

  1. 插入一个数

  2. 删除一个数

  3. 查询某数的排名

  4. 查询指定排名的数

  5. 求某数的前驱

  6. 求某数的后继

实现

constexpr int MAXN = 1e6 + 10;

class bst {
   protected:
    int L[MAXN], R[MAXN], N[MAXN], val[MAXN], size[MAXN], cnt = 1;

   public:
    void insert(int v, int pos = 1) {
        size[pos]++;

        if (N[pos] == 0 && L[pos] == 0 && R[pos] == 0) {
            val[pos] = v;
            N[pos] = 1;
        } else if (v < val[pos]) {
            if (L[pos] == 0) L[pos] = ++cnt;

            insert(v, L[pos]);
        } else if (v > val[pos]) {
            if (R[pos] == 0) R[pos] = ++cnt;

            insert(v, R[pos]);
        } else
            N[pos]++;
    }

    void remove(int v, int pos = 1) {
        size[pos]--;

        if (v < val[pos])
            remove(v, L[pos]);
        else if (v > val[pos])
            remove(v, R[pos]);
        else
            N[pos]--;
    }

    int countl(int v, int pos = 1) {
        if (v < val[pos])
            return L[pos] ? countl(v, L[pos]) : 0;
        else if (v > val[pos])
            return size[L[pos]] + N[pos] + (R[pos] ? countl(v, R[pos]) : 0);
        else
            return size[L[pos]];
    }

    int countg(int v, int pos = 1) {
        if (v > val[pos])
            return R[pos] ? countg(v, R[pos]) : 0;
        else if (v < val[pos])
            return size[R[pos]] + N[pos] + (L[pos] ? countg(v, L[pos]) : 0);
        else
            return size[R[pos]];
    }

    int rank(int v) { return countl(v) + 1; }

    int kth(int k, int pos = 1) {
        if (size[L[pos]] + 1 > k)
            return kth(k, L[pos]);
        else if (size[L[pos]] + N[pos] < k)
            return kth(k - size[L[pos]] - N[pos], R[pos]);
        else
            return val[pos];
    }

    int pre(int v) { return kth(countl(v)); }

    int suc(int v) { return kth(size[1] - countg(v) + 1); }

    bst(vector<int> v) : L(), R(), N(), val(), size() {
        for (int i : v) insert(i);
    }
};