记录编号 |
274393 |
评测结果 |
AAAAAAAAAA |
题目名称 |
超强的乘法问题 |
最终得分 |
100 |
用户昵称 |
Chenyao2333 |
是否通过 |
通过 |
代码语言 |
C++ |
运行时间 |
1.356 s |
提交时间 |
2016-06-28 16:37:27 |
内存使用 |
15.17 MiB |
显示代码纯文本
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity of every working buffer below. A transform of length 2^19 fits,
// including the extra slot calc_w writes at index len in w[0].
const int kN = 600000 + 10;
// Prime modulus of the number-theoretic transform and the generator used to
// derive roots of unity in calc_w.
const int P = 998244353;
const int G = 3;

// Decimal operands as read by scanf (most significant digit first).
char numa[kN];
char numb[kN];
// Product digits after carry propagation (least significant digit first).
int ans[kN];
// strlen of numa and numb.
int alen;
int blen;
// Transformed coefficient vectors of the two operands.
int pra[kN];
int prb[kN];
// Scratch coefficients: filled by fill_src, fed to fft, and reused as the
// inverse-transform output.
int src[kN];
// Root tables built by calc_w: w[0] for the forward transform, w[1] for the inverse.
int w[2][kN];
// Modular exponentiation by repeated squaring: returns d^k mod P
// (returns 1 when k == 0). The exponent is halved with truncating
// division, so a negative k behaves like |k|.
int pow(int d, int k) {
    int result = 1;
    int base = d;
    while (k != 0) {
        if (k % 2 != 0) {
            result = 1ll * result * base % P;
        }
        base = 1ll * base * base % P;
        k /= 2;
    }
    return result;
}
// Loads the decimal string s (len characters, most significant digit first)
// into the global src buffer as digit values, least significant digit at
// index 0. Positions len .. tot_len-1 are then cleared to zero.
void fill_src(char s[], int len, int tot_len) {
    int pos = 0;
    for (; pos < len; ++pos) {
        src[pos] = s[len - 1 - pos] - '0';
    }
    for (; pos < tot_len; ++pos) {
        src[pos] = 0;
    }
}
// Builds the root tables for a transform of size 2^n.
//   root = G^((P-1) / 2^n)
//   w[0][i] = root^i      for i in [0, 2^n]   (note: 2^n + 1 entries)
//   w[1][i] = w[0][2^n - i] for i in [0, 2^n) (the inverse powers)
// Because w[0][2^n] is written, the buffer must be able to hold index 2^n.
void calc_w(int n) {
    const int size = 1 << n;
    const int root = pow(G, (P - 1) / size);
    long long cur = 1;
    for (int i = 0; i <= size; ++i) {
        w[0][i] = static_cast<int>(cur);
        cur = cur * root % P;
    }
    for (int i = 0; i < size; ++i) {
        w[1][i] = w[0][size - i];
    }
}
// Iterative radix-2 number-theoretic transform modulo P on 2^scale values.
//   src   : input coefficients (read only).
//   dst   : output buffer. It first receives src in bit-reversed order, then
//           the butterflies run on it in place. The permutation copy reads
//           src while writing other dst slots, so callers pass distinct buffers.
//   scale : log2 of the transform length. Must match the n given to calc_w,
//           because roots are read from w with stride len / m.
//   flag  : 0 = forward transform using w[0];
//           1 = inverse transform using w[1], then scaling by len^{-1} mod P.
void fft(int src[], int dst[], int scale, int flag = 0) {
    const int len = 1 << scale;

    // Bit-reversal permutation: j is the bit-reversed counterpart of i,
    // advanced by a reversed binary increment.
    int j = 0;
    for (int i = 0; i < len; i++) {
        dst[j] = src[i];
        for (int l = len >> 1; (j ^= l) < l; l >>= 1);
    }

    // Butterfly passes over blocks of size m = 2, 4, ..., len.
    for (int d = 1; d <= scale; d++) {
        const int m = 1 << d;
        const int half = m >> 1;
        const int step = len / m;  // stride into the size-len root table
        for (int i = 0; i < len; i += m) {
            for (int k = 0; k < half; k++) {
                const int u = dst[i + k];
                const int v = 1ll * w[flag][k * step] * dst[i + k + half] % P;
                dst[i + k] = (u + v) % P;
                dst[i + k + half] = (u - v + P) % P;
            }
        }
    }

    // Only the inverse transform needs len^{-1} (computed as len^(P-2) mod P),
    // so the exponentiation is skipped for forward transforms.
    if (flag == 1) {
        const int inv = pow(len, P - 2);
        for (int i = 0; i < len; i++) dst[i] = (1ll * dst[i] * inv) % P;
    }
}
int main() {
freopen("bettermul.in","r",stdin);
freopen("bettermul.out","w",stdout);
scanf("%s %s", numa, numb);
alen = strlen(numa);
blen = strlen(numb);
int n = 1, len = 2;
while (len < 2 * max(alen, blen)) len *= 2, n++;
calc_w(n);
fill_src(numa, alen, len);
fft(src, pra, n);
fill_src(numb, blen, len);
fft(src, prb, n);
for (int i = 0; i < (1 << n); i++) {
pra[i] = 1ll * pra[i] * prb[i] % P;
}
fft(pra, src, n, 1);
int t = 0;
int ans_len = 0;
for (int i = 0; i < len; i++) {
ans[i] = src[i];
ans[i] += t;
t = ans[i] / 10;
ans[i] %= 10;
if (ans[i]) ans_len = i;
}
for (int i = ans_len; i >= 0; i--) {
printf("%d", ans[i]);
}
printf("\n");
return 0;
}