| 记录编号 | 616868 | 评测结果 | AAAAAAAAAA |
|---|---|---|---|
| 题目名称 | 4316.and I am home | 最终得分 | 100 |
| 用户昵称 | RpUtl | 是否通过 | 通过 |
| 代码语言 | C++ | 运行时间 | 0.033 s |
| 提交时间 | 2026-07-02 11:20:37 | 内存使用 | 3.74 MiB |
显示代码纯文本
#include <bits/stdc++.h>
using namespace std;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1, so transform lengths up to
// 2^23 have the required roots of unity.
const int mod = 998244353;
// Capacity of every global coefficient buffer. The transforms in this file use
// the smallest power of two strictly greater than 2n, which is 2^21 for any
// n < 2^20, so the buffers must hold at least 2^21 entries. (The previous value
// 2e6 + 10 overflowed for n >= 524288.)
const int N = (1 << 21) + 10;
typedef long long ll;
// Primitive root of mod; powers of G supply the NTT roots of unity.
const int G = 3;
// Bit-reversal permutation table: written by init(), read by NTT().
ll R[N];

/// Modular exponentiation: returns a^b mod `mod` by binary powering.
///
/// @param a  base; any value (it is reduced into [0, mod) first, so large or
///           negative bases neither overflow nor yield negative results).
/// @param b  exponent; negative values are reduced modulo mod - 1 (Fermat's
///           little theorem), which yields a^b only when a is invertible mod `mod`.
/// @return   a value in [0, mod).
ll ksm(ll a, ll b) {
    a %= mod;
    if (a < 0) a += mod;
    if (b < 0) {
        // A single "+ (mod - 1)" is not enough for b < -(mod - 1); a still-negative
        // b would never reach 0 under arithmetic right shift.
        b %= mod - 1;
        if (b < 0) b += mod - 1;
    }
    ll ans = 1;
    while (b) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
void init(int n) {
for (int i = 0; i < n; i++) {
R[i] = (R[i >> 1] >> 1) + ((i & 1) ? (n >> 1) : 0);
}
}
/// In-place iterative number-theoretic transform of f[0..n-1] modulo `mod`.
///
/// @param f   coefficient array, overwritten with the transform.
/// @param n   transform length; expected to be a power of two, with R already
///            filled for this length by init(n).
/// @param ok  1 for the forward transform, -1 for the inverse transform. The
///            inverse uses inverse roots of unity and also multiplies by n^{-1}.
void NTT(ll *f, int n, int ok) {
    // Permute the input into bit-reversed index order.
    for (int pos = 0; pos < n; ++pos)
        if (R[pos] < pos) swap(f[pos], f[R[pos]]);

    // Butterfly passes over segments of length len = 2, 4, ..., n.
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        // Root of unity of order len (its inverse when ok == -1).
        const int step = ksm(G, (mod - 1) / len * ok);
        for (int base = 0; base < n; base += len) {
            ll root = 1;
            for (int off = 0; off < half; ++off) {
                ll *lo = f + base + off;
                ll *hi = lo + half;
                const ll x = *lo;
                const ll y = *hi * root % mod;
                *lo = (x + y) % mod;
                *hi = (x - y + mod) % mod;
                root = root * step % mod;
            }
        }
    }

    if (ok != -1) return;
    // The inverse transform leaves every entry scaled by n; divide it out.
    const int scale = ksm(n, mod - 2);
    for (int pos = 0; pos < n; ++pos) f[pos] = f[pos] * scale % mod;
}
// Scratch storage for inv(): t holds the truncated input series, h the running
// inverse. m is set by inv() to the smallest power of two strictly greater than n.
ll t[N], h[N], m;

/// Replaces f[0..n-1] with the first n coefficients of the multiplicative
/// inverse of the power series f (coefficients mod `mod`).
///
/// Newton iteration: starting from h = f[0]^{-1}, each round doubles the
/// precision i via h <- h * (2 - t * h), where t = f mod x^i. Products are
/// taken with NTTs of length 2i.
///
/// @param f  series to invert. f[0] must be invertible mod `mod`. Entries
///           f[0..m-1] are read; those at index >= n only affect coefficients
///           that are discarded. Only f[0..n-1] is written.
/// @param n  number of output coefficients (n >= 0).
///
/// The scratch buffers are zero-padded explicitly every round, so the result
/// does not depend on leftovers from earlier rounds or earlier calls.
void inv(ll *f, int n) {
    h[0] = ksm(f[0], mod - 2);
    for (m = 1; m <= n; m <<= 1);
    for (int i = 2, L; i <= m; i <<= 1) {
        L = (i << 1);
        // t = f mod x^i, zero-padded to the transform length L.
        for (int j = 0; j < i; j++) t[j] = f[j];
        std::fill(t + i, t + L, 0LL);
        // h holds the inverse mod x^(i/2); clear everything above it up to L.
        std::fill(h + (i >> 1), h + L, 0LL);
        init(L);
        NTT(t, L, 1), NTT(h, L, 1);
        // Degree of h*h*t is below 2i = L, so the cyclic product does not wrap.
        for (int j = 0; j < L; j++) {
            h[j] = h[j] * (2ll - h[j] * t[j] % mod + mod) % mod;
        }
        NTT(h, L, -1);
    }
    for (int i = 0; i < n; i++) f[i] = h[i];
}
// Global state used by main(): series buffers g and f, the input n, factorial
// and inverse-factorial tables, the final answer, and powers of 4.
ll g[N], f[N], n, fac[N], invf[N], ans, pw[N];

/// Binomial coefficient C(n, m) modulo `mod`, from the fac/invf tables.
///
/// @param n  upper index; fac/invf must already be filled up to n.
/// @param m  lower index.
/// @return   C(n, m) mod `mod`, or 0 when m < 0 or m > n. That case used to
///           index invf out of range.
ll C(ll n, ll m) {
    if (m < 0 || m > n) return 0;
    return fac[n] * invf[m] % mod * invf[n - m] % mod;
}
int main() {
freopen("home.in", "r", stdin);
freopen("home.out", "w", stdout);
cin >> n;
fac[0] = invf[0] = pw[0] = 1;
for (int i = 1; i <= n; i++) {
fac[i] = fac[i - 1] * i % mod;
invf[i] = ksm(fac[i], mod - 2);
pw[i] = pw[i - 1] * 4 % mod;
}
for (int i = 1; i <= n; i++) {
if (i & 1) continue;
g[i] = C(i, i / 2) * C(i, i / 2) % mod;
}
for (int i = 1; i <= n; i++) f[i] = g[i];
f[0]++; inv(f, n); int L;
for (L = 1; L <= (n << 1); L <<= 1);
init(L), NTT(g, L, 1), NTT(f, L, 1);
for (int i = 0; i < L; i++) f[i] = f[i] * g[i] % mod;
NTT(f, L, -1);
ans = (n + 1) * pw[n] % mod;
for (int i = 1; i <= n; i++) {
ans -= f[i] * pw[n - i] % mod * (n - i + 1) % mod;
ans %= mod;
}
ans = (ans % mod + mod) % mod;
cout << ans << '\n';
return 0;
}