Yuzhen's Blog

Yuzhen Qin

数据结构 #2:替罪羊树

10
2024-06-18

概念

当发现某个子树很不平衡时(即 $\max(size_{L}, size_{R}) > \alpha \cdot size_{x}$,其中 $size_{L}$、$size_{R}$ 分别为当前节点 $x$ 的左、右子树大小,$size_{x}$ 为以 $x$ 为根的子树大小;$\alpha \in (0.5, 1)$,一般取 $\alpha = 0.7$):

  1. 中序遍历拉平该子树(二叉搜索树的性质保证拉平后有序)

  2. 以有序序列的中点为根,递归地重建一棵平衡的二叉搜索树

实现

P3369 【模板】普通平衡树 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

#include <bits/stdc++.h>
using namespace std;

constexpr int MAXN = 1e6 + 10;

// Scapegoat tree over int keys with multiplicities, stored in fixed-size
// arrays indexed by node slot. Slot 0 is the null sentinel (all fields 0) and
// slot 1 is the root; new slots are handed out by ++cnt and never reused.
//
// Balancing: after every insert/remove, each node on the search path (deepest
// first) is checked, and a subtree whose heavier child holds more than ALPHA
// of its total multiplicity is flattened in key order and rebuilt balanced.
// Removal is lazy: only the node's counter N[] is decremented; nodes whose
// counter is 0 are dropped when an enclosing subtree is next rebuilt.
//
// NOTE: an instance holds five int[MAXN] arrays (~20 MB); give it static
// storage or allocate it on the heap rather than on the stack.
class scapegoat {
   private:
    // Scratch buffers filled by flatten(): slot, key and multiplicity of each
    // live node of the subtree being rebuilt, in ascending key order.
    vector<int> fp, fn, fv;

   protected:
    // L/R: child slots (0 = none); N: multiplicity of val[x]; size: sum of
    // multiplicities in the subtree of x; cnt: highest slot handed out.
    int L[MAXN], R[MAXN], N[MAXN], val[MAXN], size[MAXN], cnt = 1;

    // Imbalance threshold alpha, expected to lie in (0.5, 1).
    static constexpr double ALPHA = 0.7;

    // True if slot x holds a live key or has at least one child.
    bool exist(int x) { return !(N[x] == 0 && L[x] == 0 && R[x] == 0); }

    // Removes one copy of v from the subtree rooted at pos. Returns false and
    // leaves every counter untouched if v is not present there, so a bad
    // removal can neither corrupt the sentinel slot 0 nor make N[] negative.
    bool erase(int v, int pos) {
        if (pos == 0) return false;  // fell off the tree: v is absent

        bool found;
        if (v < val[pos])
            found = erase(v, L[pos]);
        else if (v > val[pos])
            found = erase(v, R[pos]);
        else {
            // A lazily deleted node (N == 0) still carries its key.
            found = N[pos] > 0;
            if (found) N[pos]--;
        }

        if (found) {
            size[pos]--;
            refactor(pos);
        }
        return found;
    }

   public:
    // Appends the live nodes (N != 0) of the subtree at pos to fp/fn/fv by an
    // in-order traversal, i.e. in ascending key order. Returns the index pos
    // occupies in fp; that entry really is pos only when N[pos] != 0.
    int flatten(int pos) {
        if (exist(L[pos])) flatten(L[pos]);

        int id = fp.size();

        if (N[pos] != 0) {
            fp.push_back(pos);
            fv.push_back(val[pos]);
            fn.push_back(N[pos]);
        }

        if (exist(R[pos])) flatten(R[pos]);

        return id;
    }

    // Builds a balanced subtree rooted at slot pos from entries l..r of
    // fp/fn/fv (r == -1 means "last entry"): the middle entry's key/count go
    // into pos, and each half is rebuilt under the slot stored at that half's
    // own middle index.
    void rebuild(int pos, int l = 0, int r = -1) {
        if (r == -1) r = fp.size() - 1;

        int mid = (l + r) / 2, sz1 = 0, sz2 = 0;

        if (l < mid) {
            L[pos] = fp[(l + mid - 1) / 2];
            rebuild(L[pos], l, mid - 1);
            sz1 = size[L[pos]];
        } else
            L[pos] = 0;

        if (mid < r) {
            R[pos] = fp[(mid + 1 + r) / 2];
            rebuild(R[pos], mid + 1, r);
            sz2 = size[R[pos]];
        } else
            R[pos] = 0;

        N[pos] = fn[mid];
        val[pos] = fv[mid];
        size[pos] = sz1 + sz2 + N[pos];
    }

    // Rebuilds the subtree rooted at pos if its heavier child holds more than
    // ALPHA of its total multiplicity.
    void refactor(int pos) {
        if (size[pos] <= 0) return;  // empty subtree: nothing to balance

        double k = max(size[L[pos]], size[R[pos]]) / double(size[pos]);
        if (k > ALPHA) {
            fp.clear();
            fn.clear();
            fv.clear();
            int id = flatten(pos);
            // rebuild() uses `pos` itself as the slot of the middle entry and
            // fp[] slots for everything else, so a collected pos must sit at
            // the middle index or its slot would be used twice. A lazily
            // deleted pos (N == 0) is not in fp at all -- and fp[id] may then
            // lie past the end -- so there is nothing to move.
            if (N[pos] != 0) swap(fp[id], fp[(fp.size() - 1) / 2]);
            rebuild(pos);
        }
    }

    // Inserts one copy of v into the subtree rooted at pos (default: root).
    void insert(int v, int pos = 1) {
        size[pos]++;

        if (!exist(pos)) {
            val[pos] = v;
            N[pos] = 1;
        } else if (v < val[pos]) {
            if (L[pos] == 0) L[pos] = ++cnt;

            insert(v, L[pos]);
        } else if (v > val[pos]) {
            if (R[pos] == 0) R[pos] = ++cnt;

            insert(v, R[pos]);
        } else
            N[pos]++;

        refactor(pos);
    }

    // Removes one copy of v from the subtree rooted at pos (default: root);
    // does nothing if v is not stored.
    void remove(int v, int pos = 1) { erase(v, pos); }

    // Number of stored keys (with multiplicity) strictly less than v.
    int countl(int v, int pos = 1) {
        if (v < val[pos])
            return L[pos] ? countl(v, L[pos]) : 0;
        else if (v > val[pos])
            return size[L[pos]] + N[pos] + (R[pos] ? countl(v, R[pos]) : 0);
        else
            return size[L[pos]];
    }

    // Number of stored keys (with multiplicity) strictly greater than v.
    int countg(int v, int pos = 1) {
        if (v > val[pos])
            return R[pos] ? countg(v, R[pos]) : 0;
        else if (v < val[pos])
            return size[R[pos]] + N[pos] + (L[pos] ? countg(v, L[pos]) : 0);
        else
            return size[R[pos]];
    }

    // 1 + number of stored keys strictly less than v.
    int rank(int v) { return countl(v) + 1; }

    // k-th smallest stored key (1-based, counting duplicates). Requires
    // 1 <= k <= size[pos]; for any other k the recursion reaches slot 0 and
    // never terminates.
    int kth(int k, int pos = 1) {
        if (size[L[pos]] + 1 > k)
            return kth(k, L[pos]);
        else if (size[L[pos]] + N[pos] < k)
            return kth(k - size[L[pos]] - N[pos], R[pos]);
        else
            return val[pos];
    }

    // Largest stored key strictly less than v; such a key must exist.
    int pre(int v) { return kth(countl(v)); }

    // Smallest stored key strictly greater than v; such a key must exist.
    int suc(int v) { return kth(size[1] - countg(v) + 1); }

    scapegoat() : L(), R(), N(), val(), size() {}

    // Builds a tree holding every element of v (duplicates kept).
    scapegoat(const vector<int>& v) : L(), R(), N(), val(), size() {
        for (int i : v) insert(i);
    }
} tr;

// Reads n operations "opt x" and applies them to the global tree `tr`:
// 1 insert x, 2 remove x, 3 print rank of x, 4 print k-th smallest (k = x),
// 5 print predecessor of x, any other opt prints the successor of x.
int main(void) {
    // One answer line per query: avoid per-line flushes (endl) and the
    // C-stdio synchronisation overhead.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 0;  // stays 0 (no operations) if the header cannot be read
    cin >> n;

    while (n--) {
        int opt, x;
        cin >> opt >> x;

        if (opt == 1)
            tr.insert(x);
        else if (opt == 2)
            tr.remove(x);
        else if (opt == 3)
            cout << tr.rank(x) << '\n';
        else if (opt == 4)
            cout << tr.kth(x) << '\n';
        else if (opt == 5)
            cout << tr.pre(x) << '\n';
        else
            cout << tr.suc(x) << '\n';
    }

    return 0;
}