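/*
比赛 20191022轻松模拟测试 评测结果 WWAWWTWWWW
题目名称 原谅 最终得分 10
用户昵称 djj 运行时间 3.380 s
代码语言 C++ 内存使用 59.44 MiB
提交时间 2019-10-22 17:26:02
显示代码纯文本

Judge record (English gloss): contest "20191022 easy mock test",
per-test verdicts WWAWWTWWWW; problem "原谅" ("Forgiveness"), final score 10;
user djj; run time 3.380 s; language C++; memory 59.44 MiB;
submitted 2019-10-22 17:26:02. The last line is leftover page text
("show code as plain text").
*/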
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;
const int maxn = 1000000 + 10;  // array capacity: 10^6 plus a little slack

/*
 * Reads the next integer from stdin.
 *
 * Any non-digit characters before the number are skipped; if a '-' is among
 * them the result is negated. If end of input is reached before a digit,
 * returns 0 instead of looping forever. `c` is an int so that EOF is
 * distinguishable from every real character.
 */
inline int read() {
    int c = getchar(), x = 0, f = 1;
    for (; c != EOF && (c > '9' || c < '0'); c = getchar()) if (c == '-') f = -1;
    for (; c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
    return x * f;
}

// Adjacency lists stored in arrays; each undirected edge uses two slots.
// Slot 0 is never used, so 0 marks "no edge".
int ver[maxn << 1];  // ver[e]: node that edge slot e leads to
int nxt[maxn << 1];  // nxt[e]: next edge slot in the same list
int head[maxn];      // head[u]: first edge slot of node u
int tot;             // number of edge slots handed out so far

int d[maxn];   // d[i]: last position of the run of equal values containing sorted position i
int z[maxn];   // z[i]: first position of that same run
int aa[maxn];  // aa[u]: value of node u, indexed by node id
int du[maxn];  // degree counter indexed by sorted position
int n;         // number of nodes
int k;         // upper bound on the chosen set's size
int ans;       // best size found so far

// A node prepared for sorting: its value (a_) and its 1-based node id.
struct Point {
    int a_;
    int id;
};
Point a[maxn];  // a[1..n], ordered with cmp after input is read

// Strict weak order for Point: larger value first; equal values by ascending id.
bool cmp (const Point &a, const Point &b) {
    if (a.a_ != b.a_)
        return a.a_ > b.a_;
    return a.id < b.id;
}

int fa[maxn];  // disjoint-set parent links; fa[x] == x marks a representative
/*
 * Returns the representative of x's set, then points every node on the walked
 * path directly at it. This is full path compression and leaves fa[] in the
 * same state as the classic recursive version.
 *
 * Iterative on purpose: without balanced unions a chain can be as long as
 * the node count (around 10^6 here). One stack frame per link could then
 * overflow the call stack.
 */
inline int find (int x) {
    int root = x;
    while (fa[root] != root)
        root = fa[root];
    while (fa[x] != root) {
        const int up = fa[x];
        fa[x] = root;
        x = up;
    }
    return root;
}

// Stores the undirected edge u-v as two directed slots: v is pushed onto the
// front of u's list, then u onto the front of v's list.
void add (int u, int v) {
    ++tot;
    ver[tot] = v;
    nxt[tot] = head[u];
    head[u] = tot;

    ++tot;
    ver[tot] = u;
    nxt[tot] = head[v];
    head[v] = tot;
}

void djj_lxy () {
    n = read(), k = read();
    if (n == k) {
        printf ("%d\n", n);
        return ;
    }
    for (register int i = 1; i <= n; i ++)
        a[i].a_ = aa[i] = read(), a[i].id = i;
    for (register int i = 1; i < n; i ++)
        add (read () + 1, read () + 1);
    sort (a + 1, a + n + 1, cmp);
    d[n] = n;
    for (register int i = n - 1; i >= 1; i --)
        d[i] = (a[i].a_ == a[i + 1].a_ ? d[i + 1] : i);
    z[1] = 1;
    for (register int i = 2; i <= n; i ++)
        z[i] = (a[i].a_ == a[i - 1].a_ ? z[i - 1] : i);
    int l = 1, r = n;
    for (; l < r; ) {
        int mid = l + r >> 1;
        mid = d[mid];
        memset (du, 0, sizeof du);
        for (register int i = 1; i <= n; i ++)
            fa[i] = i;
        for (register int i = 1; i <= mid; i ++)
            for (register int j = head[a[i].id]; j; j = nxt[j])
                if (aa[ver[j]] >= a[mid].a_)
                    fa[find (a[i].id)] = find (ver[j]), du[i] ++;
        int root = find (a[1].id), s = 0;
        for (register int i = 1; i <= mid; i ++)
            s += find (a[i].id) == root;
        bool is = 0;
        for (register int i = 2; a[i].a_ > a[mid].a_; i ++)
            if (a[i].a_ == a[i - 1].a_ && find (a[i].id) != find (a[i - 1].id)) {
                is = 1;
                break ;
            }
        if (is) {
            r = mid - 1;
            continue ;
        }
        for (register int i = 1; i <= mid; i ++)
            if (a[i].a_ > a[mid].a_ && find (a[i].id) != root) {
                is = 1;
                break ;
            }
        if (is) {
            r = mid - 1;
            continue ;
        }
        if (s > k)
            for (register int i = 1; i <= mid; i ++) {
                if (du[i] == 1 && a[i].a_ == a[mid].a_)
                    s --;
                if (s == k)
                    break ;
            }
        if (s == k) {
            printf ("%d\n", k);
            return ;
        }
        if (s > k) r = z[mid];
        else l = d[mid] + 1, ans = max (ans, s);
    }
    printf ("%d\n", ans);
}

int main () {
    // All I/O goes through the fixed files green.in / green.out,
    // presumably as the judge requires.
    freopen ("green.in", "r", stdin);
    freopen ("green.out", "w", stdout);

    djj_lxy ();
    return 0;
}

/*
Sample input:
7 5
6 2 7 5 6 5 2
3 1
1 0
0 2
2 4
4 5
4 6

Expected output:
4
*/