记录编号 613045 评测结果 AAAAA
题目名称 十二重计数法(第一关) 最终得分 100
用户昵称 终焉折枝 是否通过 通过
代码语言 C++ 运行时间 0.286 s
提交时间 2026-02-27 16:26:09 内存使用 19.78 MiB
显示代码纯文本
#include<iostream>
using namespace std;

typedef long long ll;                 // 64-bit type for modular products (MOD^2 < 2^63)
const int MOD = 998244353;            // modulus; inverses are taken as x^(MOD-2), so it must be prime
const int MAXN = 1000000 + 5;         // size of the factorial / Stirling tables
const int G = 3;                      // root used by the forward NTT
const int Gi = 332748118;             // G^{-1} mod MOD (3 * 332748118 == MOD + 1)
ll n, m;                              // input sizes (balls n, boxes m)
ll fact[MAXN], inv[MAXN];             // fact[i] = i!, inv[i] = (i!)^{-1}: inverse FACTORIALS, not 1/i
int rev[MAXN << 2];                   // bit-reversal table filled by NTT_init
ll S2[MAXN], Polya[MAXN << 2], Polyb[MAXN << 2];                 // S2[k] = S(n, k); Polya/Polyb: scratch for init
ll tmp_inv[MAXN << 2], ln_tmp1[MAXN << 2], ln_tmp2[MAXN << 2];   // scratch for PolyInv / PolyLn
ll exp_tmp[MAXN << 2], PolyG[MAXN << 2], PolyF[MAXN << 2];       // PolyExp scratch, its input and its output

// Fast modular exponentiation: returns a^b mod MOD, as a value in [0, MOD).
// The base is first reduced into [0, MOD). Without that step, a base >= ~3.04e9
// overflows a * a, and a negative base yields a negative result.
// Precondition: b >= 0 (a negative b never reaches 0 under >>=).
ll qpow(ll a, ll b){
	a %= MOD;
	if(a < 0) a += MOD;
	ll res = 1;
	while(b){
		if(b & 1) res = res * a % MOD;
		a = a * a % MOD;
		b >>= 1;
	}
	return res;
}

// Binomial coefficient C(n, m) mod MOD, read from the factorial tables
// (fact[k] = k!, inv[k] = (k!)^{-1}). Returns 0 when m lies outside [0, n].
// NOTE(review): assumes n < MAXN so that the table lookups stay in bounds.
ll C(ll n, ll m){
	if(m >= 0 && m <= n){
		const ll denom = inv[m] * inv[n - m] % MOD;
		return fact[n] * denom % MOD;
	}
	return 0;
}

// In-place iterative radix-2 number-theoretic transform modulo MOD.
//   A   : coefficient array, transformed in place over indices [0, len).
//   len : transform length; assumed to be a power of two, with rev[] already
//         filled by NTT_init(len). This function only reads rev.
//   op  : 1 = forward transform using root G. Any other value uses Gi
//         (= G^{-1} mod MOD); op == -1 additionally scales every entry by
//         len^{-1}, which gives the inverse transform.
// NOTE(review): assumes the entries of A are already reduced into [0, MOD). The
// butterfly relies on this to avoid overflow and to keep results non-negative.
void NTT(ll A[], int len, int op){
	// Bit-reversal reorder, so the butterflies below can work in place.
	for(int i = 0;i < len;i ++) if(i < rev[i]) swap(A[i], A[rev[i]]);
	// Each pass merges blocks of size mid into blocks of size 2*mid.
	for(int mid = 1;mid < len;mid <<= 1){
		// Wn = root^((MOD-1)/(2*mid)) is this pass's twiddle step. It is a
		// primitive (2*mid)-th root of unity, assuming G is a primitive root of MOD.
		ll Wn = qpow(op == 1 ? G : Gi, (MOD - 1) / (mid << 1));
		for(int i = 0;i < len;i += (mid << 1)){
			ll wk = 1;
			for(int j = 0;j < mid;j ++, wk = wk * Wn % MOD){
				// Butterfly: (x, y) -> (x + w*y, x - w*y) mod MOD.
				ll x = A[i + j], y = wk * A[i + j + mid] % MOD;
				A[i + j] = (x + y) % MOD;
				A[i + j + mid] = (x - y + MOD) % MOD;
			}
		}
	}
	if(op == -1){
		// Inverse transform: divide every coefficient by len.
		ll invL = qpow(len, MOD - 2);
		for(int i = 0;i < len;i ++) A[i] = A[i] * invL % MOD;
	}
}

// Prepares the bit-reversal table rev[0..len) used by NTT. len is assumed to be
// a power of two, len == 2^L; rev[i] is i with its low L bits reversed.
// When len <= 1 (L == 0) the general formula would shift by L - 1 == -1, which
// is undefined behaviour. Only rev[0] = 0 is needed in that case.
void NTT_init(int len){
	int L = 0;
	while((1 << L) < len) L ++;
	rev[0] = 0;
	if(L == 0) return;
	for(int i = 1;i < len;i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
}

void init(){
	fact[0] = 1;
	for(int i = 1;i < MAXN;i ++){
		fact[i] = fact[i - 1] * i % MOD;
	}
	inv[MAXN - 1] = qpow(fact[MAXN - 1], MOD - 2);
	for(int i = MAXN - 2;i >= 0;i --){
		inv[i] = inv[i + 1] * (i + 1) % MOD;
	}
	int len = 1;
	while(len <= 2 * m) len <<= 1;
	NTT_init(len);
	for(int i = 0;i < len;i ++) Polya[i] = Polyb[i] = 0;
	ll flag = 1;
	for(int i = 0;i <= m;i ++){
		Polya[i] = inv[i] * flag % MOD;
		Polyb[i] = qpow(i, n) * inv[i] % MOD;
		flag = (MOD - flag);
	}
	NTT(Polya, len, 1);
	NTT(Polyb, len, 1);
	for(int i = 0;i < len;i ++){
		Polya[i] = Polya[i] * Polyb[i] % MOD;
	}
	NTT(Polya, len, -1);
	for(int i = 0;i <= m;i ++){
		S2[i] = Polya[i];
	}
//	cout << S2[m] << '\n';
}

// Computes B = A^{-1} mod x^N by Newton iteration:
//   B_{2k} = B_k * (2 - A * B_k)   (mod x^{2k}).
// Requires A[0] to be invertible mod MOD. A and B must not alias.
// On return, B[0..N) holds the inverse and B[N..len) is zero, where len is the
// smallest power of two >= 2N. Scratch: tmp_inv; overwrites rev via NTT_init.
void PolyInv(ll A[], ll B[], int N){
	if(N == 1){
		B[0] = qpow(A[0], MOD - 2);
		return;
	}
	const int half = (N + 1) >> 1;
	PolyInv(A, B, half);
	int len = 1;
	while(len < (N << 1)) len <<= 1;
	// Only B[0..half) is meaningful after the recursive call, which zeroes just
	// its own (possibly shorter) window. Clear the whole tail here rather than
	// rely on the caller having zero-filled B up to len.
	for(int i = half;i < len;i ++) B[i] = 0;
	NTT_init(len);
	for(int i = 0;i < N;i ++) tmp_inv[i] = A[i];
	for(int i = N;i < len;i ++) tmp_inv[i] = 0;
	NTT(tmp_inv, len, 1);
	NTT(B, len, 1);
	// Pointwise Newton step; the degree stays below 2N <= len, so there is no wrap-around.
	for(int i = 0;i < len;i ++){
		B[i] = B[i] * (2 - tmp_inv[i] * B[i] % MOD + MOD) % MOD;
	}
	NTT(B, len, -1);
	for(int i = N;i < len;i ++) B[i] = 0;
}

// Computes B = ln(A) mod x^N as the integral of A' / A. Expects A[0] == 1.
// Writes B[0..sz) completely: B[0] = 0, B[1..N) = the log coefficients, and
// B[N..sz) = 0, where sz is the smallest power of two >= 2N.
// Scratch: ln_tmp1 (holds 1/A), ln_tmp2 (holds A'); PolyInv also uses tmp_inv and rev.
void PolyLn(ll A[], ll B[], int N){
	int sz = 1;
	while(sz < 2 * N) sz *= 2;
	for(int k = 0;k < sz;k ++){
		ln_tmp1[k] = 0;
		ln_tmp2[k] = 0;
	}
	PolyInv(A, ln_tmp1, N);
	// Derivative: coefficient k of A' is (k + 1) * A[k + 1].
	for(int k = 0;k + 1 < N;k ++) ln_tmp2[k] = A[k + 1] * (k + 1) % MOD;
	NTT_init(sz);
	NTT(ln_tmp1, sz, 1);
	NTT(ln_tmp2, sz, 1);
	for(int k = 0;k < sz;k ++) ln_tmp1[k] = ln_tmp1[k] * ln_tmp2[k] % MOD;
	NTT(ln_tmp1, sz, -1);
	// Integrate: coefficient k of the result is (A'/A)[k - 1] / k. Everything past N is zero.
	B[0] = 0;
	for(int k = 1;k < sz;k ++){
		B[k] = (k < N) ? ln_tmp1[k - 1] * qpow(k, MOD - 2) % MOD : 0;
	}
}

// Computes B = exp(A) mod x^N by Newton iteration:
//   B_{2k} = B_k * (1 - ln(B_k) + A)   (mod x^{2k}).
// Assumes A[0] == 0 (the base case fixes B[0] = 1). A and B must not alias.
// On return, B[0..N) holds the result and B[N..len) is zero, where len is the
// smallest power of two >= 2N. Scratch: exp_tmp (plus ln_tmp1/ln_tmp2/tmp_inv via
// PolyLn); overwrites rev.
void PolyExp(ll A[], ll B[], int N){
	if(N == 1){
		B[0] = 1;
		return;
	}
	const int half = (N + 1) >> 1;
	PolyExp(A, B, half);
	int len = 1;
	while(len < (N << 1)) len <<= 1;
	// Only B[0..half) is meaningful after the recursive call. Clear the rest of
	// this level's transform window instead of relying on it still being zero.
	for(int i = half;i < len;i ++) B[i] = 0;
	// PolyLn writes exp_tmp[0..len) completely (zeroing its own tail), so no pre-clear is needed.
	PolyLn(B, exp_tmp, N);
	NTT_init(len);
	// exp_tmp = 1 - ln(B) + A  (mod x^N)
	for(int i = 0;i < N;i ++) exp_tmp[i] = (A[i] - exp_tmp[i] + MOD) % MOD;
	exp_tmp[0] = (exp_tmp[0] + 1) % MOD;
	for(int i = N;i < len;i ++) exp_tmp[i] = 0;
	NTT(B, len, 1);
	NTT(exp_tmp, len, 1);
	for(int i = 0;i < len;i ++) B[i] = B[i] * exp_tmp[i] % MOD;
	NTT(B, len, -1);
	for(int i = N;i < len;i ++) B[i] = 0;
}

void EXP_init(){
	for(int i = 1;i <= m;i ++){
		for(int j = 1;i * j <= n;j ++){
			PolyG[i * j] = (PolyG[i * j] + qpow(j, MOD - 2)) % MOD;
		}
	}
	PolyExp(PolyG, PolyF, n + 1);
}


void solve1(){
	cout << qpow(m, n) << '\n';
}

void solve2(){
	if(n > m) cout << 0 << '\n';
	else cout << C(m, n) * fact[n] % MOD << '\n';
}

void solve3(){
	ll ans = 0;
	ll flag = 1;
	for(int i = 0;i <= m;i ++){
		ll cnt = flag * C(m, i) % MOD;
		cnt = cnt * qpow(m - i, n) % MOD;
		ans = (ans + cnt + MOD) % MOD;
		flag = (MOD - flag);
	}
	cout << ans << '\n';
}

void solve4(){
	ll ans = 0;
	for(int i = 1;i <= m;i ++){
		ans = (ans + S2[i]) % MOD;
	}
	cout << ans << '\n';
}

// Case 5: distinct balls, identical boxes, at most one ball per box.
// There is exactly one arrangement when n <= m, otherwise none.
void solve5(){
	const int ways = (n <= m) ? 1 : 0;
	cout << ways << '\n';
}

void solve6(){
	cout << S2[m] << '\n';
}

void solve7(){
	cout << C(n + m - 1, m - 1) << '\n';
}

void solve8(){
	cout << C(m, n) << '\n';
}

void solve9(){
	cout << C(n - 1, m - 1) << '\n';
}

void solve10(){
	cout << PolyF[n] << '\n';
}

// Case 11: identical balls, identical boxes, at most one ball per box.
// There is exactly one arrangement when n <= m, otherwise none.
void solve11(){
	if(n <= m) cout << 1 << '\n';
	else cout << 0 << '\n';
}

// Case 12: identical balls, identical boxes, every box non-empty.
// Removing one ball from each box leaves a partition of n - m into at most m
// parts, i.e. PolyF[n - m]. The answer is 0 only when n < m. For n == m there is
// exactly one arrangement (one ball per box), and PolyF[0] == 1 reports it. The
// old `n <= m` test printed 0 in that case.
void solve12(){
	if(n < m) cout << 0 << '\n';
	else cout << PolyF[n - m] << '\n';
}

int main(){
	cin >> n >> m;
	init();
	EXP_init();
	solve1();
	solve2();
	solve3();
	solve4();
	solve5();
	solve6();
	solve7();
	solve8();
	solve9();
	solve10();
	solve11();
	solve12();
	return 0;
}