记录编号 |
402917 |
评测结果 |
AAAAAAAAAAAA |
题目名称 |
增强的乘法问题 |
最终得分 |
100 |
用户昵称 |
sxysxy |
是否通过 |
通过 |
代码语言 |
C++ |
运行时间 |
0.079 s |
提交时间 |
2017-05-08 13:47:36 |
内存使用 |
32.45 MiB |
显示代码纯文本
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cstdarg>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;
// Greatest common divisor of a and b, computed with the iterative form of
// Euclid's algorithm: (a, b) -> (b, a % b) until the second value is zero.
int gcd(int a, int b){
    while(b != 0){
        const int rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Minimal complex number used by the FFT: `a` is the real part, `b` the
// imaginary part.
//
// All arithmetic operators are const members taking const references, so
// they accept temporaries and const operands (e.g. `(x + y) * z`).
struct virt{
    double a, b;

    // Builds a + b*i; both parts default to zero.
    virt(double x = 0, double y = 0){a = x, b = y;}

    // Assigns a purely real integer value and clears the imaginary part.
    virt &operator=(int p){a = p, b = 0.0; return *this;}

    // Component-wise sum.
    virt operator+(const virt &r) const {return virt(a+r.a, b+r.b);}

    // Component-wise difference.
    virt operator-(const virt &r) const {return virt(a-r.a, b-r.b);}

    // Scales both parts by the real factor p.
    virt operator*(double p) const {return virt(a*p, b*p);}

    // Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
    virt operator*(const virt &r) const {return virt(a*r.a-b*r.b, a*r.b+b*r.a);}
};
// Statically allocated FFT work buffers, (1<<20)+1 complex slots each.
// mul() loads its first operand into pa and its second into pb.
virt pa[(1 << 20) + 1], pb[(1 << 20) + 1];

// The constant pi, obtained as arccos(-1).
const double PI = acos(-1.0);
// Reorders a[0..n-1] in place into bit-reversed index order, the input
// ordering required by the iterative butterflies in fft().
// n is expected to be a power of two (that is what fft's callers pass).
void rader(virt *a, int n){
    const int top = n >> 1;   // value of the most significant index bit
    int rev = top;            // bit-reversed image of i (starts at the image of 1)
    for(int i = 1; i < n - 1; ++i){
        // Swap each pair once only: when the index is below its image.
        if(i < rev){
            std::swap(a[i], a[rev]);
        }
        // Advance `rev` as a mirrored counter: clear set bits from the top
        // downwards, then set the first clear bit that is reached.
        int bit = top;
        while(rev >= bit){
            rev -= bit;
            bit >>= 1;
        }
        rev += bit;   // rev < bit always holds once the loop above exits
    }
}
// In-place iterative radix-2 FFT over a[0..n-1].
//
//   a    : buffer of n complex samples, overwritten with the transform.
//   n    : transform length; the butterfly loops below only cover the
//          buffer correctly when n is a power of two.
//   idft : false -> forward transform, using roots e^{-2*pi*i/len};
//          true  -> inverse transform, using roots e^{+2*pi*i/len} and
//                   dividing the result by n.
//
// The inverse pass normalises BOTH the real and the imaginary component.
// Previously only the real part was divided by n, which left the imaginary
// part of the result off by a factor of n.
void fft(virt *a, int n, bool idft = false){
    rader(a, n);   // bit-reversal reordering required by the butterflies
    const double signed_pi = idft ? PI : -PI;
    for(int len = 2; len <= n; len <<= 1){
        const int half = len >> 1;
        const double theta = 2 * signed_pi / len;
        virt step(std::cos(theta), std::sin(theta));   // primitive len-th root of unity
        for(int start = 0; start < n; start += len){
            virt w(1, 0);   // running power of `step`
            for(int k = start; k < start + half; k++){
                virt u = a[k];
                virt v = w * a[k + half];
                a[k] = u + v;
                a[k + half] = u - v;
                w = w * step;
            }
        }
    }
    if(idft){
        for(int i = 0; i < n; i++){
            a[i].a /= n;
            a[i].b /= n;
        }
    }
}
void mul(vector<unsigned> &a, vector<unsigned> &b){
int la = 0, lb = 0;
int s = 0;
for(vector<unsigned>::iterator it = a.begin(); it != a.end(); ++it)
pa[la++] = *it;
for(vector<unsigned>::iterator it = b.begin(); it != b.end(); ++it)
pb[lb++] = *it;
while((1<<s) <= la+lb)s++; s = 1<<s;
for(int i = la; i < s; i++)pa[i] = 0;
for(int i = lb; i < s; i++)pb[i] = 0;
fft(pa, s); fft(pb, s);
for(int i = 0; i < s; i++)pa[i] = pa[i]*pb[i];
fft(pa, s, true);
a.resize(la+lb-1);
for(int i = 0; i <= la+lb-2; i++)
a[i] = (unsigned)(pa[i].a+0.5);
}
// Radix of one big-number limb: every limb holds four decimal digits.
static const int BASE = 10000;

// BPOW[k] == 10^k for k = 0..3: the place value of the k-th decimal digit
// (counted from the least significant one) inside a limb.
static const int BPOW[] = {1, 10, 100, 1000};
// Non-negative arbitrary-precision integer.
//
// The value is stored little-endian in `bits`: bits[0] is the least
// significant limb and, once normalised, every limb lies in [0, BASE).
// An empty `bits` is treated as zero by is_zero() and print().
struct bignum{
    std::vector<unsigned> bits;

    // Appends `v` as a new most-significant limb. Callers clear `bits`
    // first when the number should become exactly `v`.
    inline void init(int v){
        bits.push_back(v);
    }

    // Number of limbs currently stored.
    inline int length() const {return (int)bits.size();}

    // Removes zero limbs from the most-significant end, always keeping at
    // least one limb, so the top limb is non-zero unless the value is 0.
    inline void trim(){
        while(bits.size() > 1 && bits.back() == 0)bits.pop_back();
    }

    // this += o. `o` is left untouched (it used to be zero-extended, which
    // corrupted its top limb for later is_zero() checks). Carries are fully
    // propagated, including a new top limb when needed. Safe for a.addw(a).
    void addw(const bignum &o){
        const int other_len = o.length();
        if(length() < other_len)bits.resize(other_len, 0);
        unsigned long long carry = 0;
        for(int i = 0; i < length(); i++){
            unsigned long long sum = carry + bits[i];
            if(i < other_len)sum += o.bits[i];
            bits[i] = (unsigned)(sum % BASE);
            carry = sum / BASE;
        }
        while(carry != 0){
            bits.push_back((unsigned)(carry % BASE));
            carry /= BASE;
        }
    }

    // True when the value is zero: every limb is zero, or there are no
    // limbs at all. Scans from the top so non-zero numbers return at once.
    inline bool is_zero() const {
        for(int i = length() - 1; i >= 0; i--)
            if(bits[i] != 0)return false;
        return true;
    }

    // this *= o, using the FFT-based mul(). `o` may be *this (squaring).
    void mulw(bignum &o){
        if(is_zero() || o.is_zero()){
            // Canonical zero: exactly one limb holding 0.
            bits.clear();
            init(0);
            return;
        }
        mul(bits, o.bits);
        // Normalise every limb into [0, BASE). The running carry is kept in
        // 64 bits so adding it to a large limb cannot wrap around.
        unsigned long long carry = 0;
        for(int i = 0; i < length(); i++){
            const unsigned long long cur = carry + bits[i];
            bits[i] = (unsigned)(cur % BASE);
            carry = cur / BASE;
        }
        while(carry != 0){
            bits.push_back((unsigned)(carry % BASE));
            carry /= BASE;
        }
        trim();
    }

    // Writes the decimal representation followed by '\n' to stdout.
    // The top limb is printed without padding, every lower limb as exactly
    // four digits. Leading zero limbs are skipped; an empty number prints "0".
    void print( ){
        int top = length() - 1;
        while(top > 0 && bits[top] == 0)top--;
        if(top < 0){
            std::printf("0\n");
            return;
        }
        std::printf("%u", bits[top]);
        for(int i = top - 1; i >= 0; i--)std::printf("%04u", bits[i]);
        std::putchar('\n');
    }

    // Replaces the value with the next whitespace-delimited token read from
    // stdin. The token is assumed to consist of decimal digits only (no
    // sign); at most 150001 characters are consumed. If no token can be
    // read the value becomes 0.
    void get(){
        static char buf[150002];
        bits.clear();
        // The field width must stay equal to sizeof(buf) - 1 so scanf can
        // never write past the buffer.
        if(std::scanf("%150001s", buf) != 1){
            init(0);
            return;
        }
        unsigned limb = 0;   // limb being assembled
        int placed = 0;      // decimal digits already stored in `limb`
        // Walk from the last (least significant) character backwards,
        // packing four decimal digits into each limb.
        for(int i = (int)std::strlen(buf) - 1; i >= 0; i--){
            limb += BPOW[placed] * (buf[i] - '0');
            if(++placed == 4){
                bits.push_back(limb);
                limb = 0;
                placed = 0;
            }
        }
        if(placed > 0)bits.push_back(limb);
        if(bits.empty())init(0);
        trim();   // "000123" must not end up with a zero top limb
    }
};
// Raises `a` to the power `x` in place by binary exponentiation
// (square-and-multiply).
//
//   a : base on entry, a^x on return.
//   x : exponent. x == 0 yields 1. A negative exponent has no integer
//       result, so `a` is left unchanged; previously a negative x never
//       reached zero under `>>=` and the loop did not terminate.
//
// The running square is only advanced while exponent bits remain, which
// skips the final (and largest) squaring whose result was never used.
void fast_pow(bignum &a, int x){
    if(x < 0)return;
    bignum square = a;   // a^(2^k) for the exponent bit currently examined
    a.bits.clear();
    a.init(1);
    while(x > 0){
        if(x & 1)a.mulw(square);
        x >>= 1;
        if(x > 0)square.mulw(square);
    }
}
// Reads two non-negative decimal integers from "mul.in" and writes their
// product to "mul.out".
// Returns 0 on success, 1 if either file cannot be opened.
int main(){
    //freopen("test_data.out", "r", stdin);
    // Redirect the standard streams to the judge's files. Stop early on
    // failure instead of reading from or writing to an invalid stream.
    if(std::freopen("mul.in", "r", stdin) == NULL){
        std::perror("mul.in");
        return 1;
    }
    if(std::freopen("mul.out", "w", stdout) == NULL){
        std::perror("mul.out");
        return 1;
    }
    bignum a, b;
    a.get();
    b.get();
    a.mulw(b);
    a.print();
    return 0;
}