比赛 20241128 评测结果 AAAAAAAAAAAAAAAAAAAA
题目名称 猴猴的比赛 最终得分 100
用户昵称 darkMoon 运行时间 1.757 s
代码语言 C++ 内存使用 15.46 MiB
提交时间 2024-11-28 09:25:52
显示代码纯文本
#include<bits/stdc++.h>
#define int long long
using namespace std;
// Redirect stdin/stdout to the task's files. These are namespace-scope
// initializers, so they run before main and, being defined first, before the
// `n = mread()` initializer below. The returned FILE* values are stored but
// never checked.
auto IN = freopen("monkeyclim.in", "r", stdin);
auto OUT = freopen("monkeyclim.out", "w", stdout);
// Reads one integer from stdin with scanf and returns it. "%lld" matches
// because `int` is #defined to `long long` in this file. The scanf result is
// not checked, so the returned value is indeterminate if the read fails.
auto mread = [](){int x;scanf("%lld", &x);return x;};
// Capacity of the per-vertex arrays (1e5 plus a small margin).
const int N = 1e5 + 5;
// Vertex count, read from the input during static initialization (before main).
int n = mread();
// Adjacency lists indexed by vertex: v1 holds the first tree's edges, v2 the
// second tree's (both are filled in main).
vector<int> v1[N], v2[N];
// A pair of per-vertex values, one for each of the two trees:
// `a` is the value in the first tree (written by dfs1),
// `b` is the value in the second tree (written by dfs2).
struct node{
    int a;
    int b;
};
// dfn[v]: preorder index of vertex v in each tree.
node dfn[N];
// siz[v]: number of vertices in the subtree of v in each tree.
node siz[N];
// One prefix-count query of the offline sweep in main.
// Field order matters: main fills these with brace initialization {p, li, form}.
struct Q{
    int p;     // sweep position at which the query is answered
    int li;    // upper limit passed to query(): counts stored values <= li
    int form;  // sign (+1 or -1) with which the count enters the answer
};
// Query buffer, used 1-indexed; main creates four queries per vertex.
Q q[4 * N];
// Shared running counter: preorder clock inside dfs1/dfs2, later reused by
// main as the number of queries stored in q.
int idx;
// a[i] = second-tree preorder index of the vertex whose first-tree preorder
// index is i (filled by dfs2).
int a[N];
// Depth-first walk over the first tree (adjacency v1) starting at x, where
// fa is the vertex we arrived from (skipped so the walk never goes back up).
// Assigns dfn[x].a from the global counter idx in preorder and stores the
// subtree size of x in siz[x].a.
void dfs1(int x, int fa){
    dfn[x].a = ++idx;
    int &subtree = siz[x].a;
    subtree = 1;  // x itself
    for(const auto child : v1[x]){
        if(child != fa){
            dfs1(child, x);
            subtree += siz[child].a;
        }
    }
}
// Depth-first walk over the second tree (adjacency v2) starting at x, where
// fa is the vertex we arrived from. Assigns dfn[x].b from the global counter
// idx in preorder, stores the subtree size of x in siz[x].b, and records in
// a[] the second-tree index at the position given by x's first-tree index
// (so dfn[].a must already be filled in for every vertex visited).
void dfs2(int x, int fa){
    const auto order = ++idx;
    dfn[x].b = order;
    a[dfn[x].a] = order;
    int &subtree = siz[x].b;
    subtree = 1;  // x itself
    for(const auto child : v2[x]){
        if(child != fa){
            dfs2(child, x);
            subtree += siz[child].b;
        }
    }
}
// Fenwick (binary indexed) tree storage, used 1-indexed; updated through add().
int s[N];
// Returns the sum of all values added so far at positions 1..x.
// A non-positive x denotes an empty prefix and yields 0; the loop stops as
// soon as x is no longer positive, so it never reads below s[1].
int query(int x){
    int ans = 0;
    for(; x > 0; x -= x & -x){  // strip the lowest set bit each step
        ans += s[x];
    }
    return ans;
}
// Adds k to position x of the Fenwick tree stored in s.
// Valid positions are 1..n. A position outside that range is ignored: in
// particular x == 0 would otherwise loop forever (0 & -0 == 0 never advances)
// and a negative x would write outside the array.
void add(int x, int k){
    for(; x > 0 && x <= n; x += x & -x){  // climb to the next covering node
        s[x] += k;
    }
}
signed main(){
    for(int i = 1, x, y; i < n; i ++){
        cin >> x >> y;
        v1[x].push_back(y);
        v1[y].push_back(x);
    }
    for(int i = 1, x, y; i < n; i ++){
        cin >> x >> y;
        v2[x].push_back(y);
        v2[y].push_back(x);
    }
    dfs1(1, 0);
    idx = 0;
    dfs2(1, 0);
    int ans = 0;
    idx = 0;
    for(int i = 1; i <= n; i ++){
        idx ++;
        q[idx] = {dfn[i].a, dfn[i].b, 1};
        idx ++;
        q[idx] = {dfn[i].a, dfn[i].b + siz[i].b - 1, -1};

        idx ++;
        q[idx] = {dfn[i].a + siz[i].a - 1, dfn[i].b, -1};
        idx ++;
        q[idx] = {dfn[i].a + siz[i].a - 1, dfn[i].b + siz[i].b - 1, 1};
    }
    sort(q + 1, q + 1 + idx, [](Q a, Q b){
        return a.p < b.p;
    });
    int now = 1;
    for(int i = 1; i <= n; i ++){
        add(a[i], 1);
        while(now <= idx && q[now].p == i){
            ans += q[now].form * query(q[now].li);
            now ++;
        }
    }
    // for(int i = 1; i <= n; i ++){
    //     for(int j = dfn[i].a + 1; j <= dfn[i].a + siz[i].a - 1; j ++){
    //         if(a[j] > dfn[i].b && a[j] <= dfn[i].b + siz[i].b - 1){
    //             ans ++;
    //         }
    //     }
    // }
    printf("%lld", ans);
    return 0;
}