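/*
 * Online-judge submission record (header copied from the judge's web page):
 *   记录编号 (record ID):         274393
 *   评测结果 (per-test verdicts): AAAAAAAAAA (ten test results, all 'A')
 *   题目名称 (problem):           超强的乘法问题 ("super multiplication problem")
 *   最终得分 (final score):       100
 *   用户昵称 (user):              Chenyao2333
 *   是否通过 (passed):            通过 (yes)
 *   代码语言 (language):          C++
 *   运行时间 (run time):          1.356 s
 *   提交时间 (submitted):         2016-06-28 16:37:27
 *   内存使用 (memory):            15.17 MiB
 */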
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>

using namespace std;
// ---- Problem-wide limits and NTT parameters ----
constexpr int kN = 600000 + 10;   // length of every global buffer below
constexpr int P = 998244353;      // NTT prime: P - 1 = 119 * 2^23
constexpr int G = 3;              // base used to derive the roots of unity mod P

// ---- Global working storage ----
// Decimal operands exactly as read from input (first character = most significant digit).
char numa[kN];
char numb[kN];
// Product digits after carry propagation, least significant digit at index 0.
int ans[kN];
// Lengths (in characters) of numa and numb.
int alen;
int blen;
// Forward transforms of the two operands; pra is reused for the pointwise product.
int pra[kN];
int prb[kN];
// Scratch buffer: digit sequence fed into a transform / inverse-transform output.
int src[kN];
// Twiddle tables filled by calc_w: row 0 = forward roots, row 1 = inverse roots.
int w[2][kN];

// Modular exponentiation by repeated squaring: returns d^k reduced modulo P.
// For d >= 0 and k >= 0 the result lies in [0, P); pow(d, 0) == 1.
// Intermediate products are taken in long long, so any int d is safe.
// NOTE: this shares its name with std::pow (visible via `using namespace std`);
// calls with two int arguments resolve to this exact-match non-template overload.
int pow(int d, int k) {
    long long result = 1;
    long long base = d;
    while (k != 0) {
        if (k % 2 != 0) {
            result = result * base % P;
        }
        base = base * base % P;
        k /= 2;
    }
    return static_cast<int>(result);
}

// Loads the decimal string s (its first `len` characters) into the global
// `src` buffer as digit values in reverse order, so s's last character lands
// in src[0]. Entries from len up to tot_len - 1 are zero-filled. If len
// exceeds tot_len, all len digits are still written.
void fill_src(char s[], int len, int tot_len) {
    const int limit = len > tot_len ? len : tot_len;
    for (int k = 0; k < limit; ++k) {
        src[k] = k < len ? s[len - 1 - k] - '0' : 0;
    }
}

// Builds the twiddle tables for a transform of length size = 2^n modulo P.
// With g = G^((P - 1) / size):
//   w[0][k] = g^k          for k in [0, size]   (forward roots)
//   w[1][k] = w[0][size-k] for k in [0, size)   (inverse roots, i.e. g^(-k)
//                                                when g is a size-th root of unity)
// Writes size + 1 entries of w[0], so the table must hold at least that many.
void calc_w(int n) {
    const int size = 1 << n;
    const int root = pow(G, (P - 1) / size);

    long long power = 1;
    for (int k = 0; k <= size; ++k) {
        w[0][k] = static_cast<int>(power);
        power = power * root % P;
    }

    for (int k = 0; k < size; ++k) {
        w[1][k] = w[0][size - k];
    }
}

// Iterative radix-2 number-theoretic transform modulo P over 2^scale entries.
//   src  - input coefficients (read only)
//   dst  - output buffer; must not alias src, because the bit-reversal copy
//          writes dst while src is still being read
//   flag - 0 selects the forward table w[0]; 1 selects the inverse table w[1]
//          and additionally multiplies the result by (2^scale)^(-1) mod P.
// Twiddles are indexed as w[flag][k * (len / blockSize)], so the tables must
// have been built by calc_w(scale).
void fft(int src[], int dst[], int scale, int flag = 0) {
    const int len = 1 << scale;

    // Bit-reversal permutation: `rev` is the bit-reversed counterpart of `pos`.
    for (int pos = 0, rev = 0; pos < len; ++pos) {
        dst[rev] = src[pos];
        int bit = len >> 1;
        while ((rev ^= bit) < bit) {
            bit >>= 1;
        }
    }

    // Butterfly passes; blocks of size 2 * half, twiddle step len / (2 * half).
    for (int half = 1; half < len; half <<= 1) {
        const int blockSize = half << 1;
        const int stride = len / blockSize;
        for (int base = 0; base < len; base += blockSize) {
            for (int k = 0; k < half; ++k) {
                const int a = dst[base + k];
                const int b = static_cast<int>(1ll * w[flag][k * stride] * dst[base + k + half] % P);
                dst[base + k] = (a + b) % P;
                dst[base + k + half] = (a - b + P) % P;
            }
        }
    }

    // Inverse transform: divide by len via its modular inverse (Fermat).
    if (flag == 1) {
        const int inv = pow(len, P - 2);
        for (int k = 0; k < len; ++k) {
            dst[k] = static_cast<int>(1ll * dst[k] * inv % P);
        }
    }
}

int main() {
    freopen("bettermul.in","r",stdin);
    freopen("bettermul.out","w",stdout);

    scanf("%s %s", numa, numb);

    alen = strlen(numa);
    blen = strlen(numb);

    int n = 1, len = 2;
    while (len < 2 * max(alen, blen)) len *= 2, n++;

    calc_w(n);

    fill_src(numa, alen, len);
    fft(src, pra, n);
    fill_src(numb, blen, len);
    fft(src, prb, n);

    for (int i = 0; i < (1 << n); i++) {
        pra[i] = 1ll * pra[i] * prb[i] % P;
    }
    fft(pra, src, n, 1);
    
    int t = 0;
    int ans_len = 0;
    for (int i = 0; i < len; i++) {
        ans[i] = src[i];
        ans[i] += t;
        t = ans[i] / 10;
        ans[i] %= 10;
        if (ans[i]) ans_len = i;
    }

    for (int i = ans_len; i >= 0; i--) {
        printf("%d", ans[i]);
    }
    printf("\n");
    return 0;
}