/*
 * Online-judge submission record (pasted from the judge page):
 * 比赛 组合计数1 评测结果 AAAAAAAAAAAAAAAAWWWW
 * 题目名称 组合数问题 最终得分 80
 * 用户昵称 xuyuqing 运行时间 2.067 s
 * 代码语言 C++ 内存使用 31.54 MiB
 * 提交时间 2026-02-26 12:16:58
 * 显示代码纯文本
 */
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>

using namespace std;

// Capacity of the factorial / inverse-factorial tables below.
constexpr int N = 1919810;

// Problem parameters (filled from the input); p is the modulus used by every
// "% p" in this file.
long long n, p, k, r;

// fact[i] is meant to hold i! mod p and inv[i] its modular inverse.
long long fact[N], inv[N];

// Final answer, reduced mod p.
long long res;

long long po (long long a, long long n) {
    long long ans = 1;
    while (n) {
        if (n & 1) {
            ans = ans * a % p;
        }
        a = a * a % p;
        n >>= 1;
    }
    return ans;
}

/**
 * Binomial coefficient "n choose m" modulo the global p. Note the argument
 * order: the lower index m comes first.
 *
 * Preconditions: fact[i] == i! mod p and inv[i] == (i!)^{-1} mod p for every
 * index this call reads, i.e. 0 <= i <= min(N, p) - 1.
 * When n >= p, reading fact[n] directly would go past that filled range (or
 * past the array itself). Lucas' theorem is therefore used to bring both
 * arguments below p first.
 * NOTE(review): the inverse tables and Lucas' theorem are only valid for
 * prime p. If p > N, callers must also keep n < N.
 */
long long C (long long m, long long n) {
    // Outside 0 <= m <= n the coefficient is 0. This also keeps inv[] from
    // being indexed with a negative value.
    if (m < 0 || m > n) {
        return 0;
    }
    if (p >= 2 && n >= p) {
        // Lucas: C(n, m) == C(n mod p, m mod p) * C(n / p, m / p)  (mod p).
        return C (m % p, n % p) * C (m / p, n / p) % p;
    }
    return fact[n] * inv[m] % p * inv[n - m] % p;
}

/**
 * Multiplies two coefficient vectors as polynomials in Z_mod[x] / (x^len - 1),
 * where len == a.size() == b.size(). Exponents wrap around cyclically.
 * Every entry of a and b must already be in [0, mod). Each partial sum is
 * reduced immediately, so mod * mod must fit in a long long.
 */
static std::vector<long long> cyclic_multiply (const std::vector<long long>& a,
                                               const std::vector<long long>& b,
                                               long long mod) {
    const std::size_t len = a.size ();
    std::vector<long long> out (len, 0);
    for (std::size_t i = 0; i < len; ++i) {
        if (a[i] == 0) {
            continue;
        }
        for (std::size_t j = 0; j < len; ++j) {
            // i + j < 2 * len, so a single wrap-around is enough.
            const std::size_t idx = (i + j < len) ? i + j : i + j - len;
            out[idx] = (out[idx] + a[i] * b[j]) % mod;
        }
    }
    return out;
}

/**
 * Returns the coefficients of (1 + x)^e reduced modulo (x^len - 1), with
 * coefficients taken mod `mod`. Entry t equals
 *     sum over all j with j % len == t of C(e, j)   (mod `mod`).
 * Only additions and multiplications are used, so `mod` need not be prime.
 * Cost: O(len^2 * log e). Requires len >= 1, e >= 0, mod >= 1.
 */
static std::vector<long long> binomial_residue_sums (long long e, long long len, long long mod) {
    std::vector<long long> result (static_cast<std::size_t> (len), 0);
    result[0] = 1 % mod;                          // (1 + x)^0 == 1

    std::vector<long long> base (static_cast<std::size_t> (len), 0);
    base[0] = 1 % mod;
    base[1 % len] = (base[1 % len] + 1) % mod;    // 1 + x; for len == 1 both land on x^0

    while (e > 0) {
        if (e & 1) {
            result = cyclic_multiply (result, base, mod);
        }
        e >>= 1;
        if (e > 0) {
            base = cyclic_multiply (base, base, mod);
        }
    }
    return result;
}

/**
 * Reads n, p, k, r from problem.in and writes
 *     sum_{i >= 0} C(n * k, i * k + r)  mod p
 * to problem.out.
 *
 * That sum is the coefficient of x^r in (1 + x)^(n*k) reduced modulo
 * x^k - 1, which binomial_residue_sums computes by repeated squaring.
 * No modular inverses are needed, so this works for composite p and for
 * n * k >= p, and it handles every k. The factorial tables and the
 * special-cased branches are no longer needed for the answer.
 * NOTE(review): assumes the input guarantees 0 <= r < k, and that p * p fits
 * in a long long. TODO confirm against the problem's stated bounds.
 */
int main () {

    freopen ("problem.in", "r", stdin);
    freopen ("problem.out", "w", stdout);

    // Reject unreadable input and parameters outside the supported range
    // instead of indexing out of bounds or printing a wrong sum.
    if (!(cin >> n >> p >> k >> r) || k < 1 || p < 1 || r < 0 || r >= k) {
        return 1;
    }

    const std::vector<long long> sums = binomial_residue_sums (n * k, k, p);
    res = sums[static_cast<std::size_t> (r)];

    cout << res << '\n';

    return 0;
}