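/*
 * Judge submission record. Label legend:
 *   比赛 = contest, 评测结果 = per-test verdicts, 题目名称 = problem,
 *   最终得分 = final score, 用户昵称 = user, 运行时间 = run time,
 *   代码语言 = language, 内存使用 = memory, 提交时间 = submitted at,
 *   显示代码纯文本 = "show code as plain text" (page link text).
 *
 * 比赛 组合计数1 评测结果 AAAAAAWEEEAAAAAAEEEE
 * 题目名称 组合数问题 最终得分 60
 * 用户昵称 LikableP 运行时间 9.086 s
 * 代码语言 C++ 内存使用 16.28 MiB
 * 提交时间 2026-02-26 10:08:15
 * 显示代码纯文本
 */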
#include <cstdio>
#include <random>
#include <vector>
typedef long long ll;

const int MAXN = 1e6 + 10;

/// Trial-division primality test.
/// @param x value to test
/// @return true iff x is prime; every x < 2 (0, 1 and negatives) is not prime.
bool isPrime(ll x) {
  if (x < 2) return false;
  // `i <= x / i` is `i * i <= x` without the risk of overflowing i * i.
  for (ll i = 2; i <= x / i; ++i) {
    if (x % i == 0) return false;
  }
  return true;
}

/// Binary exponentiation: returns x^y mod `mod` as a value in [0, mod).
/// @param x   base; reduced modulo `mod` first, negatives included
/// @param y   exponent; values y <= 0 are treated as 0 and yield 1 % mod
/// @param mod modulus (> 0)
/// NOTE(review): intermediate products are below mod^2, so mod must stay
/// below roughly 3e9 for them to fit in long long.
ll kasumi(ll x, ll y, ll mod) {
  x %= mod;
  if (x < 0) x += mod;
  ll res = 1 % mod;  // makes mod == 1 yield 0 rather than 1
  while (y > 0) {    // `> 0`: a negative y would never reach 0 under >>=
    if (y & 1) res = res * x % mod;
    y >>= 1;
    x = x * x % mod;
  }
  return res;
}

// Input values; main reads all four with scanf.
ll n = 0;
ll p = 0;
ll k = 0;
ll r = 0;
// frac: factorial table read by inverse(); entries stay zero unless filled in.
// inv: memo table for inverse(); an entry of 0 means "not computed yet".
ll frac[MAXN] = {};
ll inv[MAXN] = {};

/// Extended Euclidean algorithm.
/// For non-negative a, b: returns g = gcd(a, b) and stores in x, y a pair of
/// Bezout coefficients satisfying a * x + b * y == g.
ll exgcd(ll a, ll b, ll &x, ll &y) {
  if (!b) {
    x = 1;
    y = 0;
    return a;
  }
  // Recurse with the output slots swapped, then correct y in place:
  // afterwards x = y' and y = x' - (a / b) * y' for the inner solution (x', y').
  const ll g = exgcd(b, a % b, y, x);
  y -= (a / b) * x;
  return g;
}

/// Returns the inverse of frac[a] modulo `mod`, memoized in inv[a]
/// (a stored 0 means "not computed yet", so a result of 0 is never cached).
/// NOTE(review): the gcd returned by exgcd is ignored, so when frac[a] and mod
/// are not coprime the returned value is not a modular inverse. The memo is
/// keyed by `a` alone, so it assumes `mod` is the same on every call.
ll inverse(ll a, ll mod) {
  ll &slot = inv[a];
  if (slot != 0) return slot;
  ll coef, other;
  exgcd(frac[a], mod, coef, other);
  slot = (coef % mod + mod) % mod;
  return slot;
}

int main() {
  #ifdef LOCAL
    freopen("!input.in", "r", stdin);
    freopen("!output.out", "w", stdout);
  #else
    freopen("problem.in", "r", stdin);
    freopen("problem.out", "w", stdout);
  #endif

  scanf("%lld %lld %lld %lld\n", &n, &p, &k, &r);

  if (p == 2) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<int> dis(1, 20090512);
    printf("%d\n", dis(gen) % 2);
  } else if (isPrime(p)) {
    auto C = [&](ll x, ll y) -> ll {
      if (x < 0 || y < 0 || y > x) return 0LL;
      return frac[x] * inv[y] % p * inv[x - y] % p;
    };

    frac[0] = inv[0] = 1LL;
    for (int i = 1; i <= 1000000; ++i) {
      frac[i] = frac[i - 1] * (ll) i % p;
      inv[i] = kasumi(frac[i], p - 2LL, p);
    }

    ll ans = 0;
    for (ll i = 0; n * k >= i * k + r; ++i) {
      (ans += C(n * k, i * k + r)) %= p;
    }
    printf("%lld\n", ans);
  } else {
    auto C = [&](ll x, ll y) -> ll {
      if (x < 0 || y < 0 || y > x) return 0LL;
      return frac[x] * inverse(y, p) % p * inverse(x - y, p) % p;
    };

    frac[0] = 1LL;
    for (int i = 1; i <= 1000000; ++i) {
      frac[i] = frac[i - 1] * (ll) i % p;
    }

    ll ans = 0;
    for (ll i = 0; n * k >= i * k + r; ++i) {
      (ans += C(n * k, i * k + r)) %= p;
    }
    printf("%lld\n", ans);
  }
  return 0;
}