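/*
 * Judge submission header (pasted from the online judge page):
 *   Contest:   CSP2023-S mock contest (CSP2023-S模拟赛)
 *   Result:    AAAAAAAAAAAAAAAAAAAA (all 20 test cases accepted)
 *   Problem:   坡伊踹 ("poitry")      Final score: 100
 *   User:      zxhhh                  Running time: 17.743 s
 *   Language:  C++                    Memory used: 46.36 MiB
 *   Submitted: 2023-10-17 19:09:31
 *   ("显示代码纯文本" = the judge's "show code as plain text" link.)
 */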
#include <bits/stdc++.h>
// Short aliases used throughout the solution.
#define mp make_pair
#define fir first
#define sec second
#define pb push_back

using namespace std;
// Array capacity. Nodes are numbered 1..n (assumes n <= 2e5); index 0 is a
// zero-initialised sentinel that sits "above" the root.
const int N = 2e5+5; 
typedef long long ll;
typedef pair <int, int> pii; 
// n: number of nodes, q: number of queries, a[p]: value of node p (read in main).
// f[p][i]:  the 2^i-th ancestor of p, or 0 once the jump passes the root.
//           Only i = 0..20 is ever used.
// va[p][i]: minimum of a[] over p and its first 2^i ancestors (both ends inclusive).
// d2[p]:    depth of p in nodes (the root is 1 when the tree is rooted via dfs(1, 0)).
int n, q, a[N], f[N][25], va[N][25], d2[N];
// d[p]: sum of edge weights on the path from the root to p.
ll d[N]; 
// e[p]: adjacency list of (neighbour, edge weight) pairs; each edge is stored in both directions.
vector <pii> e[N]; 

// Roots the subtree at p (whose parent is fa) and fills the per-node tables for
// every node reached from p without going back through fa:
//   f[u][i]  - 2^i-th ancestor of u (0 once the jump passes the root),
//   va[u][i] - min of a[] over u and its first 2^i ancestors, inclusive,
//   d2[u]    - d2[parent] + 1,
//   d[c]     - d[parent] + weight of the parent->c edge (d[p] itself is not written).
// The traversal uses an explicit stack instead of recursion. Recursion depth
// would equal the tree height, which reaches ~n (about 2e5) on a path-shaped
// tree and can overflow a default-sized call stack. Each node is pushed only
// after its parent has been fully processed, so every ancestor's f/va row is
// complete before it is read. The resulting tables are therefore identical to
// the recursive version.
void dfs (int p, int fa) {
    std::vector<std::pair<int, int> > pending;  // (node, parent) still to process
    pending.push_back(std::make_pair(p, fa));
    while (!pending.empty()) {
        const int u = pending.back().first;
        const int par = pending.back().second;
        pending.pop_back();

        f[u][0] = par;
        d2[u] = d2[par] + 1;
        va[u][0] = std::min(a[u], a[par]);
        for (int i = 1; i <= 20; i++) {
            const int mid = f[u][i - 1];
            f[u][i] = f[mid][i - 1];
            va[u][i] = std::min(va[u][i - 1], va[mid][i - 1]);
        }

        for (const auto& edge : e[u]) {
            if (edge.first == par) continue;  // don't walk back up the tree
            d[edge.first] = d[u] + edge.second;
            pending.push_back(std::make_pair(edge.first, u));
        }
    }
}

// Lowest common ancestor of x and y via binary lifting. Uses the depth array
// d2[] and the ancestor table f[][] filled by dfs(). Jumps that pass the root
// land on node 0, whose depth d2[0] is 0.
int LCA (int x, int y) {
    int deep = x, shallow = y;
    if (d2[deep] < d2[shallow]) swap(deep, shallow);

    // Raise the deeper node until both are on the same level.
    for (int lv = 20; lv >= 0; --lv) {
        const int up = f[deep][lv];
        if (d2[up] >= d2[shallow]) deep = up;
    }
    if (deep == shallow) return deep;

    // Raise both in lockstep while their ancestors differ. They end up as the
    // two children of the LCA that lie on the respective paths.
    for (int lv = 20; lv >= 0; --lv) {
        const int upDeep = f[deep][lv];
        const int upShallow = f[shallow][lv];
        if (upDeep != upShallow) {
            deep = upDeep;
            shallow = upShallow;
        }
    }
    return f[deep][0];
}

// Feasibility test for candidate answer k on query (u, v), where lca = LCA(u, v).
// Computes the minimum of a[] over every node on the u->v path whose path
// distance from u (sum of edge weights) is at most k, and returns whether that
// minimum is <= k. The result is monotone in k, which is what allows main()
// to binary-search over k.
bool check (int u, int v, int lca, int k) {
    const int topDepth = d2[lca];
    int best = a[u];

    // u-side: climb from u toward lca while staying within distance k of u,
    // folding in the min over each jumped segment.
    int cur = u;
    for (int lv = 20; lv >= 0; --lv) {
        const int up = f[cur][lv];
        if (d2[up] < topDepth) continue;   // would overshoot the lca
        if (d[u] - d[up] > k) continue;    // would leave the distance-k window
        best = min(best, va[cur][lv]);
        cur = up;
    }

    // v-side: distance from u to a node w on the lca->v chain is
    // (d[u] - d[lca]) + (d[w] - d[lca]).
    const long long viaLca = d[u] - d[lca];
    int w = v;
    // Move w to the highest chain node that is still farther than k from u.
    for (int lv = 20; lv >= 0; --lv) {
        const int up = f[w][lv];
        if (d2[up] >= topDepth && viaLca + d[up] - d[lca] > k) w = up;
    }
    // Step one node up so that w is the deepest chain node within distance k.
    if (viaLca + d[w] - d[lca] > k) w = f[w][0];
    // Fold in everything from w up to the lca.
    for (int lv = 20; lv >= 0; --lv) {
        while (d2[f[w][lv]] >= topDepth) {
            best = min(best, va[w][lv]);
            w = f[w][lv];
        }
    }
    return best <= k;
}

int main () {
    // Competition-style file I/O: input from poitry.in, output to poitry.out.
    freopen("poitry.in", "r", stdin); 
    freopen("poitry.out", "w", stdout); 

    scanf("%d%d", &n, &q); 
    for (int i = 1; i <= n; ++i) scanf("%d", a + i); 

    // Read the n-1 weighted edges and store each one in both directions.
    for (int edgesLeft = n - 1; edgesLeft > 0; --edgesLeft) {
        int x, y, v;
        scanf("%d%d%d", &x, &y, &v); 
        e[x].emplace_back(y, v);
        e[y].emplace_back(x, v);
    }
    dfs(1, 0); 

    // For each query, binary-search the smallest k in [0, 1e9] for which check() holds.
    for (; q != 0; --q) {
        int u, v;
        scanf("%d%d", &u, &v); 
        const int anc = LCA(u, v);
        int lo = 0, hi = 1000000000;
        while (lo < hi) {
            const int mid = (lo + hi) / 2;
            if (check(u, v, anc, mid)) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        printf("%d\n", lo);
    }
    return 0;
}