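/*
 * Submission record (copied from the online judge's result page):
 *   比赛     (contest):       2023级模拟测试1 (grade-2023 mock test 1)
 *   评测结果 (verdict):       AAAAAAAAAA (per-test-case results)
 *   题目名称 (problem):       Sum of k_mex
 *   最终得分 (final score):   100
 *   用户昵称 (user):          zxhhh
 *   运行时间 (run time):      3.321 s
 *   代码语言 (language):      C++
 *   内存使用 (memory):        13.05 MiB
 *   提交时间 (submitted):     2023-09-05 20:54:59
 *   显示代码纯文本            ("show code as plain text" page link)
 */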
#include <bits/stdc++.h>
using namespace std;
// Array capacity. Positions and values are 1-based, and main() also reads
// p[n + 1] as a zero sentinel, so n must stay below N - 1.
const int N = 100005;
// Modulus applied to every accumulated sum.
const int mod = 998244353;
typedef long long ll;

int n;      // number of elements
int a[N];   // a[i]: value stored at position i (presumably a permutation of 1..n)
int p[N];   // p[v]: position of value v (inverse of a); untouched entries stay 0
ll ans;     // running answer, kept reduced modulo mod
struct segment_tree {
ll sl[N<<2], sr[N<<2], slr[N<<2], ml[N<<2], mr[N<<2];
void upd (int p, int l, int r, int op, ll k) {
if (op == 1) {
ml[p] = k; slr[p] = sr[p]*k%mod;
sl[p] = k*(r-l+1)%mod;
}
else {
mr[p] = k; slr[p] = sl[p]*k%mod;
sr[p] = k*(r-l+1)%mod;
}
}
void push_down (int p, int l, int r) {
int mid = l+r>>1;
if (ml[p]) upd(p<<1, l, mid, 1, ml[p]);
if (mr[p]) upd(p<<1, l, mid, 2, mr[p]);
if (ml[p]) upd(p<<1|1, mid+1, r, 1, ml[p]);
if (mr[p]) upd(p<<1|1, mid+1, r, 2, mr[p]);
ml[p] = mr[p] = 0;
}
void push_up (int p) {
sl[p] = (sl[p<<1]+sl[p<<1|1])%mod;
sr[p] = (sr[p<<1]+sr[p<<1|1])%mod;
slr[p] = (slr[p<<1]+slr[p<<1|1])%mod;
}
void update (int p, int l, int r, int idx, ll nl, ll nr) {
if (l == r) {
sl[p] = nl; sr[p] = nr;
slr[p] = nl*nr%mod;
return;
}
int mid = l+r>>1; push_down(p, l, r);
if (idx <= mid) update(p<<1, l, mid, idx, nl, nr);
else update(p<<1|1, mid+1, r, idx, nl, nr);
push_up(p);
}
ll query (int p, int l, int r, int idx, int op) {
if (l == r) {
if (op == 1) return sl[p];
else return sr[p];
}
int mid = l+r>>1; push_down(p, l, r);
if (idx <= mid) return query(p<<1, l, mid, idx, op);
return query(p<<1|1, mid+1, r, idx, op);
}
void add (int p, int l, int r, int tl, int tr, int op, ll k) {
if (l >= tl && r <= tr) {
if (op == 1) upd(p, l, r, op, k);
else upd(p, l, r, op, k);
return;
}
int mid = l+r>>1; push_down(p, l, r);
if (tl <= mid) add(p<<1, l, mid, tl, tr, op, k);
if (tr > mid) add(p<<1|1, mid+1, r, tl, tr, op, k);
push_up (p);
}
ll ql (int p, int l, int r, int tl, int tr, int op) {
if (tl <= l && r <= tr) {
if(op == 1) return sl[p];
if (op == 2) return sr[p];
return slr[p];
}
int mid = l+r>>1; push_down(p, l, r); ll res = 0;
if (tl <= mid) res += ql(p<<1, l, mid, tl, tr, op); res %= mod;
if (tr > mid) res += ql(p<<1|1, mid+1, r, tl, tr, op); res %= mod;
return res;
}
}seg;
int main () {
freopen("k_mex.in", "r", stdin);
freopen("k_mex.out", "w", stdout);
scanf("%d", &n);
for (int i = 1;i <= n;i++) scanf("%d", &a[i]), p[a[i]] = i;
ans += (ll)(p[1]-1)*p[1]/2%mod+(ll)(n-p[1])*(n-p[1]+1)/2%mod; ans %= mod;
// cout << ans << endl;
for (int i = 1;i <= n;i++) {
// cout << i << endl;
seg.update(1, 1, n, i, p[i], p[i]);
// cout << 1 << endl;
if (i+1 <= n) ans += ((ll)(p[i+1]-1)*p[i+1]/2%mod+(ll)(n-p[i+1])*(n-p[i+1]+1)/2%mod)%mod*(i+1)%mod;
// else break;
int l = 1, r = i;
while (l < r) {
int mid = (l+r)>>1;
if (seg.query(1, 1, n, mid, 1) >= p[i]) r = mid;
else l = mid+1;
}
// cout << l << endl;
seg.add(1, 1, n, l, i, 1, p[i]);
l = 1, r = i;
while (l <r) {
int mid = (l+r)>>1;
if(seg.query(1, 1, n, mid, 2) <= p[i]) r = mid;
else l = mid+1;
}
seg.add(1, 1, n, l, i, 2, p[i]);
l = 1, r = i; int tl, tr;
while (l < r) {
int mid = (l+r)>>1;
tl = seg.query(1, 1, n, mid, 1), tr = seg.query(1, 1, n, mid, 2);
if (tl > p[i+1] || tr < p[i+1]) r = mid;
else l = mid+1;
}
tl = seg.query(1, 1, n, l, 1), tr = seg.query(1, 1, n, l, 2);
// cout << tl <<" " << tr <<endl;
ll k = i-l+1, sl = seg.ql(1, 1, n, l, i, 1), sr = seg.ql(1, 1, n, l, i, 2), slr = seg.ql(1, 1, n, l, i, 3), g;
// cout << sl << " " << k << " " << sr <<' ' <<slr << endl;
if (tl > p[i+1]) {
g = (n*sl%mod)-slr+sl-k*n%mod*p[i+1]%mod+sr*p[i+1]%mod-k*p[i+1]%mod;
}
else g = sl*p[i+1]%mod-slr;
g %= mod; g *= i+1; g %= mod;
ans += g; ans %= mod;
// for (int j = i;j >= 1;j--) {
// l[j] = min(l[j], l[i]), r[j] = max(r[j], r[i]);
// if (r[j] < p[i+1]) ans += (ll)(i+1)*l[j]%mod*(p[i+1]-r[j])%mod, ans %= mod;
// else if (l[j] > p[i+1]) ans += (ll)(i+1)*(l[j]-p[i+1])%mod*(n-r[j]+1)%mod, ans %= mod;
// }
// cout << ans <<endl;
}
ans %= mod; ans = (ans+mod)%mod;
cout << ans;
return 0;
}
/*
Sample input (first line n, second line the values a[1..n]):
10
1 3 9 4 8 7 6 10 2 5
*/