(注:为了方便,本文将 $S$ 中 $1$ 的个数限制设为 $K$)
这种计数类的问题大概率是 DP,可以往这个方面想。
考虑状态的设计,由于这道题存在进位的问题,而且进位是从低到高的,所以可以按二进制位从低到高考虑。
那么状态里肯定要有两维:$(i,j)$,分别表示当前的位数,和已经确定的 $a$ 中元素的数量。
但是题中对二进制 $1$ 的个数限制为 $K$,而且进位相当烦人,这时就可以考虑直接把它们设进状态。
毕竟这题数据范围不大,就算不是正解也能拿不少分。
由此,设计出一个 DP:
设 $f(i,j,k,q)$ 表示 $0\sim i-1$ 位已经考虑过,当前考虑第 $i$ 位,$a$ 中已经有 $j$ 个元素确定,目前 $S$ 中有 $k$ 个 $1$,且从 $0\sim i-1$ 推过来的进位数为 $q$ 时的权值和。
初始状态:$f(0,0,0,0)=1$。
发现这个状态并不是很好从前面转移来,那么我们就用已有的状态往后转移(刷表)。
考虑在第 $i$ 位放 $t(0\le t\le n-j)$ 个 $a$ 中的元素,那么 $S$ 中 $1$ 的个数会变成 $k + ((t + q)\bmod 2)$,向第 $i+1$ 位进 $\lfloor \frac{t+q}{2} \rfloor$ 个 $1$。
那么接下来的状态就是 $f(i+1,j+t,k+((t+q)\bmod 2),\lfloor \frac{t+q}{2} \rfloor)$。
现在来算一算这次转移的贡献,直接放式子:
$$f(i,j,k,q) \times \mathrm C_{n-j}^t \times v_i^t$$
这个式子并不难理解,就是在 $a$ 剩下的 $n-j$ 个元素中选 $t$ 个,会产生 $v_i^t$ 的权值。
剩余要注意的就是统计答案。累加上所有的 $f(m+1,n,k,q)$。
但因为这题二进制 $1$ 的个数至多为 $K$,而且 $m+1$ 位及以后显然还会有进位产生的 $1$,不难发现,这个状态的 $S$ 中真正的 $1$ 的个数是 $k+\text{popcount}(q)$。
所以还要在枚举时判断一下 $k+\text{popcount}(q) \le K$。
那么这道题就做完了。时间复杂度 $O(mn^4)$,卡得很紧,组合数和 $v_i^t$ 都要预处理出来。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int maxn = 35;
const int maxm = 105;
ll C[maxn][maxn];
int n,m,K;
ll v[maxm],pw[maxm][maxn],popcnt[maxn];
int lowbit(int x) {
return x & -x;
}
int popcount(int x) {
int ans = 0;
for(;x;x -= lowbit(x))++ ans;
return ans;
}
ll f[maxm][maxn][maxn][maxn >> 1];
int main() {
freopen("sequence.in","r",stdin);
freopen("sequence.out","w",stdout);
scanf("%d %d %d",&n,&m,&K);
for(int i = 0;i <= m;++ i) {
scanf("%lld",&v[i]);
pw[i][0] = 1ll;
for(int j = 1;j <= n;++ j)pw[i][j] = pw[i][j - 1] * v[i] % mod;
}
for(int i = 0;i <= n;++ i) {
C[i][0] = 1ll;
for(int j = 1;j <= i;++ j) {
C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
}
popcnt[i] = popcount(i);
}
f[0][0][0][0] = 1ll;
for(int i = 0;i <= m;++ i) {
for(int j = 0;j <= n;++ j) {
for(int k = 0;k <= K;++ k) {
for(int q = 0;q <= (n >> 1);++ q) {
for(int t = 0;t <= n - j;++ t) {
(f[i + 1][j + t][k + (t + q & 1)][t + q >> 1] += f[i][j][k][q] * C[n - j][t] % mod * pw[i][t] % mod) %= mod;
}
}
}
}
}
ll ans = 0;
for(int k = 0;k <= K;++ k) {
for(int q = 0;q <= (n >> 1);++ q) {
if(k + popcnt[q] <= K) {
(ans += f[m + 1][n][k][q]) %= mod;
}
}
}
printf("%lld\n",ans);
return 0;
}