记录编号 |
575382 |
评测结果 |
AAAAAAAAAA |
题目名称 |
[Tyvj 1728]普通平衡树 |
最终得分 |
100 |
用户昵称 |
lihaoze |
是否通过 |
通过 |
代码语言 |
C++ |
运行时间 |
0.705 s |
提交时间 |
2022-09-13 10:50:29 |
内存使用 |
2.94 MiB |
显示代码纯文本
#include <bits/stdc++.h>
// One treap node: a BST on `key` combined with a max-heap on `prior`.
// Equal keys are not stored as separate nodes; they are folded into `cnt`.
struct item {
    int key = 0;        // value stored in this node
    int prior = 0;      // heap priority (random for nodes built from a key)
    int cnt = 1;        // multiplicity of `key`
    int size = 1;       // total multiplicity of the subtree rooted here
    item* l = nullptr;  // left child (smaller keys)
    item* r = nullptr;  // right child (larger keys)

    // Default node: a well-defined childless leaf. Every field is initialised,
    // so a default-constructed node never holds indeterminate values.
    item() = default;

    // Leaf holding a single occurrence of `key`, with a priority drawn from
    // std::rand().
    item(int key) : key(key), prior(std::rand()) { }
};
// Shorthand for a raw pointer to a tree node; a null pointer denotes an empty subtree.
using pitem = item*;
// Root pointer of the single global tree that main() passes to the tree
// operations and that getPre/getNxt search directly. Starts out null.
pitem root = nullptr;
// Recomputes x->size from the node's own multiplicity and the cached sizes of
// its children. `x` must be non-null; a missing child contributes zero.
void update(pitem& x) {
    int total = x->cnt;
    if (x->l) {
        total += x->l->size;
    }
    if (x->r) {
        total += x->r->size;
    }
    x->size = total;
}
// Right rotation: lifts the left child of `x` into x's place and makes the old
// subtree root its right child. `x` is updated in place to point at the new
// subtree root. Requires x and x->l to be non-null.
void zig(pitem& x) {
    pitem y = x->l;
    x->l = y->r;
    y->r = x;
    x = y;
    // The demoted node (now x->r) must be refreshed first: the new root's size
    // is derived from it, so updating the root first would bake in a stale value.
    update(x->r);
    update(x);
}
// Left rotation: lifts the right child of `x` into x's place and makes the old
// subtree root its left child. `x` is updated in place to point at the new
// subtree root. Requires x and x->r to be non-null.
void zag(pitem& x) {
    pitem y = x->r;
    x->r = y->l;
    y->l = x;
    x = y;
    // The demoted node (now x->l) must be refreshed first: the new root's size
    // is derived from it, so updating the root first would bake in a stale value.
    update(x->l);
    update(x);
}
// Inserts one occurrence of `y` into the subtree rooted at `x`.
// An existing key only has its multiplicity bumped; a new key becomes a leaf
// and is rotated upward while its priority exceeds its parent's.
void insert(pitem& x, int y) {
    if (x == nullptr) {
        x = new item(y);
        return;
    }
    if (y == x->key) {
        x->cnt += 1;
        update(x);
        return;
    }

    const bool goLeft = y < x->key;
    pitem& child = goLeft ? x->l : x->r;
    insert(child, y);

    // Restore the heap order on priorities between x and the child we descended into.
    if (child->prior > x->prior) {
        if (goLeft) {
            zig(x);
        } else {
            zag(x);
        }
    }
    update(x);
}
// Removes one occurrence of `y` from the subtree rooted at `x`.
//
// - If `y` is not present (including an empty subtree) this is a no-op.
// - If the key is stored more than once, only its multiplicity is decremented.
// - Otherwise the node is rotated downward, always lifting the child with the
//   higher priority so the heap order is preserved, until it has at most one
//   child; it is then unlinked and freed.
//
// Nodes are released with `delete`, matching the `new item(...)` used by insert.
void remove(pitem& x, int y) {
    if (!x) {
        return;  // key not found: nothing to remove
    }

    if (y < x->key) {
        remove(x->l, y);
    } else if (y > x->key) {
        remove(x->r, y);
    } else if (x->cnt > 1) {
        --x->cnt;
    } else if (!x->l || !x->r) {
        // At most one child: splice it into the node's place and free the node.
        pitem victim = x;
        x = x->l ? x->l : x->r;
        delete victim;
    } else if (x->l->prior > x->r->prior) {
        zig(x);            // left child goes up, target is now x->r
        remove(x->r, y);
    } else {
        zag(x);            // right child goes up, target is now x->l
        remove(x->l, y);
    }

    if (x) {
        update(x);
    }
}
// Finds the predecessor of `v` in the global tree: the node with the largest
// key strictly less than `v`.
//
// Returns a pointer to that node, or to a shared sentinel whose key is -1e9
// when no stored key greater than -1e9 is smaller than `v`. The returned
// pointer is non-owning: callers must not delete it. The sentinel is a
// function-local static, so no node is allocated (or leaked) per call.
pitem getPre(int v) {
    static item none(-1000000000);
    pitem ans = &none;
    pitem x = root;
    while (x) {
        if (x->key < v) {
            // Candidate; anything better can only lie to the right.
            if (x->key > ans->key) {
                ans = x;
            }
            x = x->r;
        } else {
            // Keys >= v cannot be a predecessor; smaller keys are to the left.
            x = x->l;
        }
    }
    return ans;
}
// Finds the successor of `v` in the global tree: the node with the smallest
// key strictly greater than `v`.
//
// Returns a pointer to that node, or to a shared sentinel whose key is 1e9
// when no stored key smaller than 1e9 is greater than `v`. The returned
// pointer is non-owning: callers must not delete it. The sentinel is a
// function-local static, so no node is allocated (or leaked) per call.
pitem getNxt(int v) {
    static item none(1000000000);
    pitem ans = &none;
    pitem x = root;
    while (x) {
        if (x->key > v) {
            // Candidate; anything better can only lie to the left.
            if (x->key < ans->key) {
                ans = x;
            }
            x = x->l;
        } else {
            // Keys <= v cannot be a successor; larger keys are to the right.
            x = x->r;
        }
    }
    return ans;
}
// Returns the key occupying 1-based position `rank` in the in-order sequence
// of the subtree rooted at `x`, counting each key as many times as it is
// stored. Falls back to 1e9 when the walk runs off the tree. The tree is not
// modified.
int getValByRank(pitem& x, int rank) {
    pitem node = x;
    while (node != nullptr) {
        const int leftSize = node->l ? node->l->size : 0;
        if (leftSize >= rank) {
            // Target lies among the smaller keys.
            node = node->l;
            continue;
        }
        if (leftSize + node->cnt >= rank) {
            return node->key;
        }
        // Skip the left subtree and this node's occurrences, then go right.
        rank = rank - leftSize - node->cnt;
        node = node->r;
    }
    return static_cast<int>(1e9);
}
// Returns the rank of `v` within the subtree rooted at `x`: the number of
// stored values (counting multiplicity) strictly smaller than `v`, plus one.
// The definition is the same whether or not `v` is actually stored. The tree
// is not modified.
int getRankByVal(pitem& x, int v) {
    // Empty subtree: nothing smaller here, so the rank contribution is the
    // "+1" itself. Returning 0 would make absent values rank one too low.
    if (!x) return 1;
    const int leftSize = x->l ? x->l->size : 0;
    if (v == x->key) return leftSize + 1;
    if (v < x->key) return getRankByVal(x->l, v);
    // Everything in the left subtree and every occurrence of x->key is smaller than v.
    return getRankByVal(x->r, v) + leftSize + x->cnt;
}
// Reads the operation count and then that many "op x" pairs from phs.in,
// applies each to the global tree, and writes query answers to phs.out:
//   1 x  insert x              2 x  remove one occurrence of x
//   3 x  print rank of x       4 x  print the value at rank x
//   5 x  print predecessor     6 x  print successor
// Returns 1 if either file cannot be opened, 0 otherwise.
int main() {
    if (!freopen("phs.in", "r", stdin) || !freopen("phs.out", "w", stdout)) {
        return 1;
    }
    // Only iostream is used for the actual I/O, so C stdio synchronisation
    // and the cin/cout tie are unnecessary overhead.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // The tree starts empty (root is null). No sentinel nodes are planted:
    // every tree operation already handles a null subtree, and the queries
    // supply their own fallback values.
    int n = 0;
    if (!(std::cin >> n)) {
        return 0;
    }
    while (n-- > 0) {
        int op = 0, x = 0;
        if (!(std::cin >> op >> x)) {
            break;  // truncated input: stop instead of acting on garbage
        }
        switch (op) {
        case 1:
            insert(root, x);
            break;
        case 2:
            remove(root, x);
            break;
        case 3:
            std::cout << getRankByVal(root, x) << '\n';
            break;
        case 4:
            std::cout << getValByRank(root, x) << '\n';
            break;
        case 5:
            std::cout << getPre(x)->key << '\n';
            break;
        case 6:
            std::cout << getNxt(x)->key << '\n';
            break;
        }
    }
    return 0;
}