| 记录编号 |
613045 |
评测结果 |
AAAAA |
| 题目名称 |
十二重计数法(第一关) |
最终得分 |
100 |
| 用户昵称 |
终焉折枝 |
是否通过 |
通过 |
| 代码语言 |
C++ |
运行时间 |
0.286 s |
| 提交时间 |
2026-02-27 16:26:09 |
内存使用 |
19.78 MiB |
显示代码纯文本
#include<iostream>
using namespace std;
// 64-bit integer type used for every modular quantity in this file.
// A typedef (instead of `#define ll long long`) is a real, scoped type name
// that the preprocessor cannot splice into unrelated tokens.
typedef long long ll;

const int MOD = 998244353;   // NTT-friendly prime: 119 * 2^23 + 1
const int MAXN = 1e6 + 5;    // capacity of the factorial and Stirling tables
const int G = 3;             // root used by NTT for forward transforms
const int Gi = 332748118;    // inverse of G modulo MOD (3 * 332748118 = MOD + 1)

ll n, m;                     // problem input, read in main
ll fact[MAXN], inv[MAXN];    // fact[i] = i!, inv[i] = (i!)^{-1} mod MOD (filled by init)
int rev[MAXN << 2];          // bit-reversal permutation for the current NTT length (NTT_init)
// S2[k] = S(n, k), Stirling numbers of the second kind for 0 <= k <= m (init);
// Polya / Polyb are the two convolution operands used to build that row.
ll S2[MAXN], Polya[MAXN << 2], Polyb[MAXN << 2];
// Scratch buffers of PolyInv (tmp_inv), PolyLn (ln_tmp1, ln_tmp2) and
// PolyExp (exp_tmp). Transform lengths are the next power of two >= 2N,
// which stays below 4N, hence the MAXN << 2 sizing.
ll tmp_inv[MAXN << 2], ln_tmp1[MAXN << 2], ln_tmp2[MAXN << 2];
// PolyG = ln of the partition generating function, PolyF = its exp (EXP_init).
ll exp_tmp[MAXN << 2], PolyG[MAXN << 2], PolyF[MAXN << 2];
// Computes a^b mod MOD by binary exponentiation.
// The base is first reduced into [0, MOD), so any integer base (negative or
// >= MOD) is accepted and a * a can never overflow 64 bits.
// Requires b >= 0; qpow(x, 0) == 1, including x == 0.
ll qpow(ll a, ll b){
    a %= MOD;
    if(a < 0) a += MOD;
    ll res = 1;
    // `b > 0` rather than `b`: a negative exponent would otherwise spin
    // forever because arithmetic right shift keeps it at -1.
    while(b > 0){
        if(b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}
// Binomial coefficient C(n, m) modulo MOD; returns 0 when m < 0 or m > n.
// Uses the precomputed tables fact / inv whenever n fits in them. Otherwise
// (e.g. solve7's C(n + m - 1, m - 1) with large n and m) it falls back to
// n * (n-1) * ... * (n-k+1) / k! with k = min(m, n - m), instead of reading
// past fact[MAXN - 1]. Assumes n < MOD so that k! is invertible.
ll C(ll n, ll m){
    if(m < 0 || m > n) return 0;
    if(n < MAXN) return fact[n] * inv[m] % MOD * inv[n - m] % MOD;
    const ll k = (m < n - m) ? m : n - m;
    ll num = 1, den = 1;
    for(ll i = 0;i < k;i ++){
        num = num * ((n - i) % MOD) % MOD;
        den = den * ((i + 1) % MOD) % MOD;
    }
    return num * qpow(den, MOD - 2) % MOD;
}
// In-place iterative number-theoretic transform of A[0..len) modulo MOD.
// op == 1 uses root G (forward); any other op uses Gi, and op == -1 also
// scales by len^{-1} so that it inverts the forward transform.
// rev[] must hold the bit-reversal permutation for len (see NTT_init).
void NTT(ll A[], int len, int op){
    // Reorder into bit-reversed index order.
    for(int i = 0;i < len;i ++){
        if(i < rev[i]) swap(A[i], A[rev[i]]);
    }
    const ll root = (op == 1) ? G : Gi;
    for(int half = 1;half < len;half <<= 1){
        const int span = half << 1;
        const ll step = qpow(root, (MOD - 1) / span);
        for(int base = 0;base < len;base += span){
            ll w = 1;
            for(int k = 0;k < half;k ++){
                ll &lo = A[base + k];
                ll &hi = A[base + k + half];
                const ll u = lo;
                const ll v = w * hi % MOD;
                lo = (u + v) % MOD;
                hi = (u - v + MOD) % MOD;
                w = w * step % MOD;
            }
        }
    }
    if(op != -1) return;
    // Inverse transform: divide every entry by len.
    const ll scale = qpow(len, MOD - 2);
    for(int i = 0;i < len;i ++) A[i] = A[i] * scale % MOD;
}
// Fills rev[0..len) with the bit-reversal permutation for a transform of
// length len (assumed to be a power of two), as consumed by NTT().
// For len == 1 the permutation is just rev[0] = 0. Starting the recurrence at
// i = 1 avoids evaluating `x << (L - 1)` with L == 0, which is a negative shift
// count (undefined behaviour) that init() reaches when m == 0.
void NTT_init(int len){
    int L = 0;
    while((1 << L) < len) L ++;
    rev[0] = 0;
    for(int i = 1;i < len;i ++){
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
    }
}
// Precomputes factorials / inverse factorials on [0, MAXN), then the row of
// Stirling numbers of the second kind S2[k] = S(n, k) for 0 <= k <= m via
//   S(n, k) = sum_{j=0}^{k} ((-1)^{k-j} / (k-j)!) * (j^n / j!),
// i.e. the first m + 1 coefficients of a polynomial product done with NTT.
void init(){
    // Factorials upward, inverse factorials downward from a single inversion.
    fact[0] = 1;
    for(int i = 1;i < MAXN;i ++) fact[i] = fact[i - 1] * i % MOD;
    inv[MAXN - 1] = qpow(fact[MAXN - 1], MOD - 2);
    for(int i = MAXN - 1;i > 0;i --) inv[i - 1] = inv[i] * i % MOD;

    // The product has degree 2m; a length above 2m keeps the cyclic
    // convolution from wrapping into the coefficients read back below.
    int size = 1;
    while(size <= 2 * m) size <<= 1;
    NTT_init(size);
    for(int i = 0;i < size;i ++){
        if(i > m){
            Polya[i] = Polyb[i] = 0;
            continue;
        }
        // Polya[i] = (-1)^i / i!,  Polyb[i] = i^n / i!
        Polya[i] = (i & 1) ? (MOD - inv[i]) % MOD : inv[i];
        Polyb[i] = qpow(i, n) * inv[i] % MOD;
    }
    NTT(Polya, size, 1);
    NTT(Polyb, size, 1);
    for(int i = 0;i < size;i ++) Polya[i] = Polya[i] * Polyb[i] % MOD;
    NTT(Polya, size, -1);
    for(int k = 0;k <= m;k ++) S2[k] = Polya[k];
}
// Computes B = A^{-1} mod x^N (the first N coefficients of 1 / A) by Newton
// iteration: from B0 = A^{-1} mod x^{ceil(N/2)}, B = B0 * (2 - A * B0).
// Requires A[0] != 0 (mod MOD).
// Uses the global scratch buffer tmp_inv and overwrites rev[] via NTT_init.
// NOTE(review): the forward NTT of B below transforms B[0..len), so entries
// of B beyond those written by deeper recursion levels must already be zero;
// PolyLn clears its output buffer ln_tmp1 before calling this.
// On return B[0..N) holds the result and B[N..len) is zeroed.
void PolyInv(ll A[], ll B[], int N){
if(N == 1){
// Base case: inverse of the constant term.
B[0] = qpow(A[0], MOD - 2);
return;
}
// B0 = A^{-1} mod x^{ceil(N/2)}, left in B.
PolyInv(A, B, (N + 1) >> 1);
// Transform length >= 2N: enough to hold A * B0 * B0 without wrap-around.
int len = 1;
while(len < (N << 1)) len <<= 1;
NTT_init(len);
// tmp_inv = A mod x^N, zero-padded to len.
for(int i = 0;i < N;i ++) tmp_inv[i] = A[i];
for(int i = N;i < len;i ++) tmp_inv[i] = 0;
NTT(tmp_inv, len, 1);
NTT(B, len, 1);
// Pointwise Newton step B0 * (2 - A * B0).
for(int i = 0;i < len;i ++){
B[i] = B[i] * (2 - tmp_inv[i] * B[i] % MOD + MOD) % MOD;
}
NTT(B, len, -1);
// Truncate back to N coefficients.
for(int i = N;i < len;i ++) B[i] = 0;
}
// Computes B = ln(A) mod x^N as the integral of A' / A.
// Assumes A[0] == 1, so ln(A) has a zero constant term.
// Uses the global scratch buffers ln_tmp1 / ln_tmp2 (and tmp_inv through
// PolyInv) and overwrites rev[]. On return B[0..N) holds the result and
// B[N..size) is zeroed, size being the smallest power of two >= 2N.
void PolyLn(ll A[], ll B[], int N){
    int size = 1;
    while(size < 2 * N) size <<= 1;
    // PolyInv expects its output buffer to be zero up to the transform length.
    for(int i = 0;i < size;i ++){
        ln_tmp1[i] = 0;
        ln_tmp2[i] = 0;
    }
    PolyInv(A, ln_tmp1, N);                                            // 1 / A
    for(int k = 0;k + 1 < N;k ++) ln_tmp2[k] = A[k + 1] * (k + 1) % MOD;  // A'
    NTT_init(size);
    NTT(ln_tmp1, size, 1);
    NTT(ln_tmp2, size, 1);
    for(int i = 0;i < size;i ++) ln_tmp1[i] = ln_tmp1[i] * ln_tmp2[i] % MOD;
    NTT(ln_tmp1, size, -1);
    // Integrate term by term; everything outside [1, N) is zero.
    for(int i = 0;i < size;i ++){
        if(i == 0 || i >= N) B[i] = 0;
        else B[i] = ln_tmp1[i - 1] * qpow(i, MOD - 2) % MOD;
    }
}
// Computes B = exp(A) mod x^N by Newton iteration: from
// B0 = exp(A) mod x^{ceil(N/2)}, B = B0 * (1 - ln(B0) + A).
// Assumes A[0] == 0 (the base case fixes B[0] = 1 regardless of A[0]).
// Uses the global scratch buffer exp_tmp (and ln_tmp1, ln_tmp2, tmp_inv
// through PolyLn / PolyInv) and overwrites rev[].
// NOTE(review): the forward NTT of B transforms B[0..len), so B must be zero
// beyond the entries written by deeper levels; a fresh global array (as PolyF
// is in EXP_init) satisfies that.
// On return B[0..N) holds the result and B[N..len) is zeroed.
void PolyExp(ll A[], ll B[], int N){
if(N == 1){
// exp(A) has constant term 1.
B[0] = 1;
return;
}
// B0 = exp(A) mod x^{ceil(N/2)}, left in B.
PolyExp(A, B, (N + 1) >> 1);
// Clear the buffer that will receive ln(B0).
for(int i = 0;i < (N << 1);i ++) exp_tmp[i] = 0;
PolyLn(B, exp_tmp, N);
int len = 1;
while(len < (N << 1)) len <<= 1;
// (Re)build the bit-reversal table for this transform length.
NTT_init(len);
// exp_tmp = 1 - ln(B0) + A  (mod x^N), zero-padded to len.
for(int i = 0;i < N;i ++) exp_tmp[i] = (A[i] - exp_tmp[i] + MOD) % MOD;
exp_tmp[0] = (exp_tmp[0] + 1) % MOD;
for(int i = N;i < len;i ++) exp_tmp[i] = 0;
// B = B0 * exp_tmp via NTT.
NTT(B, len, 1);
NTT(exp_tmp, len, 1);
for(int i = 0;i < len;i ++) B[i] = B[i] * exp_tmp[i] % MOD;
NTT(B, len, -1);
// Truncate back to N coefficients.
for(int i = N;i < len;i ++) B[i] = 0;
}
void EXP_init(){
for(int i = 1;i <= m;i ++){
for(int j = 1;i * j <= n;j ++){
PolyG[i * j] = (PolyG[i * j] + qpow(j, MOD - 2)) % MOD;
}
}
PolyExp(PolyG, PolyF, n + 1);
}
void solve1(){
cout << qpow(m, n) << '\n';
}
void solve2(){
if(n > m) cout << 0 << '\n';
else cout << C(m, n) * fact[n] % MOD << '\n';
}
void solve3(){
ll ans = 0;
ll flag = 1;
for(int i = 0;i <= m;i ++){
ll cnt = flag * C(m, i) % MOD;
cnt = cnt * qpow(m - i, n) % MOD;
ans = (ans + cnt + MOD) % MOD;
flag = (MOD - flag);
}
cout << ans << '\n';
}
// Case 4 (distinct balls, identical boxes, no restriction): the number of set
// partitions of the n balls into at most m blocks, sum_{i=0}^{m} S(n, i).
// The i = 0 term is included: S(n, 0) = [n == 0] (init stores S2[0] = 0^n),
// so the single empty placement is counted when n == 0. Starting at i = 1
// printed 0 in that case; for n >= 1 the extra term is 0 and nothing changes.
void solve4(){
    ll ans = 0;
    for(int i = 0;i <= m;i ++){
        ans = (ans + S2[i]) % MOD;
    }
    cout << ans << '\n';
}
// Case 5 (distinct balls, identical boxes, at most one ball per box): there is
// exactly one placement when n <= m and none otherwise.
void solve5(){
    const int ways = (m >= n) ? 1 : 0;
    cout << ways << '\n';
}
void solve6(){
cout << S2[m] << '\n';
}
void solve7(){
cout << C(n + m - 1, m - 1) << '\n';
}
void solve8(){
cout << C(m, n) << '\n';
}
void solve9(){
cout << C(n - 1, m - 1) << '\n';
}
void solve10(){
cout << PolyF[n] << '\n';
}
// Case 11 (identical balls, identical boxes, at most one ball per box): one
// placement if n <= m, otherwise none.
void solve11(){
    if(n <= m) cout << 1 << '\n';
    else cout << 0 << '\n';
}
// Case 12 (identical balls, identical boxes, no empty box): partitions of n
// into exactly m parts. Taking one ball out of every box leaves n - m balls
// in m identical boxes with empties allowed, i.e. partitions of n - m into at
// most m parts, which is PolyF[n - m]. When n == m that is PolyF[0] = 1 (one
// ball per box); the previous `n <= m` test wrongly printed 0 in that case.
void solve12(){
    if(n < m) cout << 0 << '\n';
    else cout << PolyF[n - m] << '\n';
}
// Reads n and m, builds the shared tables, then prints the twelve answers in order.
int main(){
    cin >> n >> m;
    init();       // factorial tables + Stirling row S2[0..m]
    EXP_init();   // partition generating function PolyF[0..n]
    void (*const answers[])() = {
        solve1, solve2, solve3, solve4, solve5, solve6,
        solve7, solve8, solve9, solve10, solve11, solve12,
    };
    const int count = sizeof(answers) / sizeof(answers[0]);
    for(int k = 0;k < count;k ++) answers[k]();
    return 0;
}