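// Contest: Class of 2023 Mock Test 1    Judge result: AAAAAAAAAA
// Problem: Sum of k_mex    Final score: 100
// User: zxhhh    Run time: 3.321 s
// Language: C++    Memory usage: 13.05 MiB
// Submitted: 2023-09-05 20:54:59
// (Page controls: "Show code" / "Plain text")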
#include <bits/stdc++.h>

using namespace std;
// N: capacity of the 1-indexed arrays below; mod: modulus applied to every stored sum.
const int N = 1e5+5, mod = 998244353;
typedef long long ll;
// n: sequence length read by main; a[i]: i-th input value; p[v]: index i with a[i] == v.
int n, a[N], p[N];
// Running answer accumulated by main.
ll ans;

/*
 * Segment tree over indices [1, n] in which every leaf j holds a pair (L_j, R_j).
 * Each node keeps three sums over its range, all reduced modulo `mod`:
 *   sl  = sum of L_j,   sr = sum of R_j,   slr = sum of L_j * R_j.
 * Supported operations:
 *   update : set one leaf to (nl, nr);
 *   add    : range-ASSIGN (despite the name) L_j = k (op == 1) or R_j = k (otherwise);
 *   query  : read L_j (op == 1) or R_j (otherwise) of one leaf;
 *   ql     : range sum of sl (op == 1), sr (op == 2) or slr (any other op).
 * In every method `p` is the node index (root is 1) and [l, r] is the range that
 * node covers; callers always pass (1, 1, n, ...).
 *
 * Pending assignments are tracked with explicit flags instead of treating the
 * value 0 as "no tag", so assigning k == 0 is propagated like any other value.
 * Assigned values are expected to lie in [0, mod): point queries return the
 * stored value reduced modulo `mod`.
 * The arrays are large; the type is meant to be instantiated with static storage
 * (as `seg` is), which also zero-initialises all sums and flags.
 */
struct segment_tree {
    ll sl[N<<2], sr[N<<2], slr[N<<2], ml[N<<2], mr[N<<2];
    // has_ml[p] / has_mr[p]: node p has an assignment of ml[p] / mr[p] that has
    // not yet been pushed to its children.
    bool has_ml[N<<2], has_mr[N<<2];

    // Applies "assign L = k" (op == 1) or "assign R = k" (otherwise) to the whole
    // range [l, r] of node p and records it as pending for p's children.
    void upd (int p, int l, int r, int op, ll k) {
        const ll len = r - l + 1;
        if (op == 1) {
            ml[p] = k; has_ml[p] = true;
            slr[p] = sr[p] * k % mod;   // every L_j equals k, so sum(L*R) = k * sum(R)
            sl[p] = k * len % mod;
        }
        else {
            mr[p] = k; has_mr[p] = true;
            slr[p] = sl[p] * k % mod;   // every R_j equals k, so sum(L*R) = k * sum(L)
            sr[p] = k * len % mod;
        }
    }

    // Pushes the pending assignments of node p (covering [l, r]) to both children.
    // Must only be called on internal nodes (l < r).
    void push_down (int p, int l, int r) {
        const int mid = (l + r) >> 1;
        const int lc = p << 1, rc = p << 1 | 1;
        if (has_ml[p]) {
            upd(lc, l, mid, 1, ml[p]);
            upd(rc, mid + 1, r, 1, ml[p]);
        }
        if (has_mr[p]) {
            upd(lc, l, mid, 2, mr[p]);
            upd(rc, mid + 1, r, 2, mr[p]);
        }
        has_ml[p] = has_mr[p] = false;
    }

    // Recomputes the three sums of node p from its children.
    void push_up (int p) {
        const int lc = p << 1, rc = p << 1 | 1;
        sl[p] = (sl[lc] + sl[rc]) % mod;
        sr[p] = (sr[lc] + sr[rc]) % mod;
        slr[p] = (slr[lc] + slr[rc]) % mod;
    }

    // Sets leaf `idx` to (L, R) = (nl, nr). Requires l <= idx <= r.
    void update (int p, int l, int r, int idx, ll nl, ll nr) {
        if (l == r) {
            sl[p] = nl; sr[p] = nr;
            slr[p] = nl * nr % mod;
            return;
        }
        const int mid = (l + r) >> 1;
        push_down(p, l, r);
        if (idx <= mid) update(p << 1, l, mid, idx, nl, nr);
        else update(p << 1 | 1, mid + 1, r, idx, nl, nr);
        push_up(p);
    }

    // Returns L (op == 1) or R (otherwise) of leaf `idx`. Requires l <= idx <= r.
    ll query (int p, int l, int r, int idx, int op) {
        if (l == r) return op == 1 ? sl[p] : sr[p];
        const int mid = (l + r) >> 1;
        push_down(p, l, r);
        if (idx <= mid) return query(p << 1, l, mid, idx, op);
        return query(p << 1 | 1, mid + 1, r, idx, op);
    }

    // Assigns L_j = k (op == 1) or R_j = k (otherwise) for every j in [tl, tr].
    // An empty or non-overlapping target range is a no-op.
    void add (int p, int l, int r, int tl, int tr, int op, ll k) {
        if (tr < l || r < tl) return;   // no overlap: also stops recursion on empty ranges
        if (tl <= l && r <= tr) {
            upd(p, l, r, op, k);
            return;
        }
        const int mid = (l + r) >> 1;
        push_down(p, l, r);
        if (tl <= mid) add(p << 1, l, mid, tl, tr, op, k);
        if (tr > mid) add(p << 1 | 1, mid + 1, r, tl, tr, op, k);
        push_up(p);
    }

    // Returns, modulo `mod`, the sum over j in [tl, tr] of L_j (op == 1),
    // R_j (op == 2) or L_j * R_j (any other op). Empty ranges yield 0.
    ll ql (int p, int l, int r, int tl, int tr, int op) {
        if (tr < l || r < tl) return 0; // no overlap: also stops recursion on empty ranges
        if (tl <= l && r <= tr) {
            if (op == 1) return sl[p];
            if (op == 2) return sr[p];
            return slr[p];
        }
        const int mid = (l + r) >> 1;
        push_down(p, l, r);
        ll res = 0;
        if (tl <= mid) res += ql(p << 1, l, mid, tl, tr, op);
        if (tr > mid) res += ql(p << 1 | 1, mid + 1, r, tl, tr, op);
        return res % mod;
    }
}seg;

int main () {
    freopen("k_mex.in", "r", stdin); 
    freopen("k_mex.out", "w", stdout); 
    scanf("%d", &n);    
    for (int i = 1;i <= n;i++) scanf("%d", &a[i]), p[a[i]] = i;
    ans += (ll)(p[1]-1)*p[1]/2%mod+(ll)(n-p[1])*(n-p[1]+1)/2%mod; ans %= mod;
//    cout << ans << endl;
    for (int i = 1;i <= n;i++) {
//        cout << i << endl;
        seg.update(1, 1, n, i, p[i], p[i]); 
//        cout << 1 << endl;
        if (i+1 <= n) ans += ((ll)(p[i+1]-1)*p[i+1]/2%mod+(ll)(n-p[i+1])*(n-p[i+1]+1)/2%mod)%mod*(i+1)%mod; 
//        else break;
        int l = 1, r = i;
        while (l < r) {
            int mid = (l+r)>>1; 
            if (seg.query(1, 1, n, mid, 1) >= p[i]) r = mid;
            else l = mid+1; 
        } 
//        cout << l << endl;
        seg.add(1, 1, n, l, i, 1, p[i]); 
        l = 1, r = i;
        while (l <r) {
            int mid = (l+r)>>1;
            if(seg.query(1, 1, n, mid, 2) <= p[i]) r = mid;
            else l = mid+1; 
        }
        seg.add(1, 1, n, l, i, 2, p[i]); 
        l = 1, r = i; int tl, tr;
        while (l < r) {
            int mid = (l+r)>>1;
            tl = seg.query(1, 1, n, mid, 1), tr = seg.query(1, 1, n, mid, 2); 
            if (tl > p[i+1] || tr < p[i+1]) r = mid;
            else l = mid+1; 
        }
        tl = seg.query(1, 1, n, l, 1), tr = seg.query(1, 1, n, l, 2); 
//        cout << tl <<" " << tr <<endl; 
        ll k = i-l+1, sl = seg.ql(1, 1, n, l, i, 1), sr = seg.ql(1, 1, n, l, i, 2), slr = seg.ql(1, 1, n, l, i, 3), g; 
//        cout << sl << " " << k << " " << sr <<' ' <<slr << endl;
        if (tl > p[i+1]) {
            g = (n*sl%mod)-slr+sl-k*n%mod*p[i+1]%mod+sr*p[i+1]%mod-k*p[i+1]%mod;
        } 
        else g = sl*p[i+1]%mod-slr; 
        g %= mod; g *= i+1; g %= mod;
        ans += g; ans %= mod;
//        for (int j = i;j >= 1;j--) {
//            l[j] = min(l[j], l[i]), r[j] = max(r[j], r[i]);
//            if (r[j] < p[i+1]) ans += (ll)(i+1)*l[j]%mod*(p[i+1]-r[j])%mod, ans %= mod;
//            else if (l[j] > p[i+1]) ans += (ll)(i+1)*(l[j]-p[i+1])%mod*(n-r[j]+1)%mod, ans %= mod;
//        } 
//          cout << ans <<endl;
    } 
    ans %= mod; ans = (ans+mod)%mod;
    cout << ans;
    return 0; 
}
/*
10
1 3 9 4 8 7 6 10 2 5 


*/