记录编号 402917 评测结果 AAAAAAAAAAAA
题目名称 增强的乘法问题 最终得分 100
用户昵称 Gravatarsxysxy 是否通过 通过
代码语言 C++ 运行时间 0.079 s
提交时间 2017-05-08 13:47:36 内存使用 32.45 MiB
显示代码纯文本
#include <algorithm>
#include <cmath>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <stdexcept>
#include <vector>
using namespace std;
// Greatest common divisor by Euclid's algorithm: repeatedly replaces
// (a, b) with (b, a % b) until b reaches zero, then returns a.
int gcd(int a, int b){
  while(b != 0){
    const int rest = a % b;
    a = b;
    b = rest;
  }
  return a;
}
// Minimal complex number used by the FFT: a is the real part, b the
// imaginary part.
struct virt{
  double a, b;
  // Builds x + y*i; both parts default to zero.
  virt(double x = 0, double y = 0) : a(x), b(y) {}
  // Assigns the purely real value p (the imaginary part is reset to zero).
  virt& operator=(int p){a = p; b = 0.0; return *this;}
  // The arithmetic operators are const and take const references so they
  // also accept const operands and temporaries (e.g. x * (y + z)).
  virt operator+(const virt &r) const {return virt(a+r.a, b+r.b);}
  virt operator-(const virt &r) const {return virt(a-r.a, b-r.b);}
  // Scales both parts by the real factor p.
  virt operator*(double p) const {return virt(a*p, b*p);}
  // Complex product: (a + b*i) * (r.a + r.b*i).
  virt operator*(const virt &r) const {return virt(a*r.a-b*r.b, a*r.b+b*r.a);}
};
// Global FFT work buffers: mul() copies its two operands into pa and pb,
// transforms both in place and reads the product's coefficients back from
// pa. Each array holds (1<<20)+1 complex values.
virt pa[(1<<20)+1];
virt pb[(1<<20)+1];
// Pi, obtained as acos(-1); fft() uses it to build the roots of unity.
const double PI = acos(-1);
// Reorders a[0..n-1] in place into the order expected by the iterative
// butterflies in fft(). `rev` walks through the positions paired with
// idx = 1 .. n-2; each pair is swapped once, when idx < rev.
// fft() calls this with n being its transform length.
void rader(virt *a, int n){
  const int half = n >> 1;
  int rev = half;
  for(int idx = 1; idx < n - 1; ++idx){
    if(idx < rev)std::swap(a[idx], a[rev]);
    // Advance rev: clear set bits from the top down, then set the first
    // clear one (an increment with the carry running toward lower bits).
    int bit = half;
    for(; rev >= bit; bit >>= 1)rev -= bit;
    rev += bit;
  }
}
// In-place iterative FFT over a[0..n-1].
//   a    - data to transform; overwritten with the result.
//   n    - transform length; expected to be a power of two (mul() always
//          passes one).
//   idft - false: forward transform using roots e^(-2*pi*i/h);
//          true:  inverse transform using roots e^(+2*pi*i/h), followed by
//                 division of every element by n.
void fft(virt *a, int n, bool idft = false){
  rader(a, n);
  const double pi = idft ? PI : -PI;
  for(int h = 2; h <= n; h <<= 1){
    const int half = h >> 1;
    // Primitive h-th root of unity for this butterfly width.
    virt omgn(cos(2*pi/h), sin(2*pi/h));
    for(int j = 0; j < n; j += h){
      virt omg(1, 0);
      for(int k = j; k < j + half; k++){
        virt u = a[k], v = omg*a[k+half];
        a[k] = u+v;
        a[k+half] = u-v;
        omg = omg*omgn;
      }
    }
  }
  if(idft){
    // Normalize BOTH components; scaling only the real part would leave the
    // imaginary part n times too large.
    for(int i = 0; i < n; i++){
      a[i].a /= n;
      a[i].b /= n;
    }
  }
}
void mul(vector<unsigned> &a, vector<unsigned> &b){
  int la = 0, lb = 0;
  int s = 0;
  for(vector<unsigned>::iterator it = a.begin(); it != a.end(); ++it)
    pa[la++] = *it;
  for(vector<unsigned>::iterator it = b.begin(); it != b.end(); ++it)
    pb[lb++] = *it;
  while((1<<s) <= la+lb)s++; s = 1<<s;
  for(int i = la; i < s; i++)pa[i] = 0;
  for(int i = lb; i < s; i++)pb[i] = 0;
  fft(pa, s); fft(pb, s);
  for(int i = 0; i < s; i++)pa[i] = pa[i]*pb[i];
  fft(pa, s, true);
  a.resize(la+lb-1);
  for(int i = 0; i <= la+lb-2; i++)
    a[i] = (unsigned)(pa[i].a+0.5);
}
// Radix of one bignum limb: each element of bignum::bits holds four
// decimal digits.
static const int BASE = 10000;
// BPOW[k] == 10^k; bignum::get() uses it to place the k-th decimal digit
// inside its limb.
static const int BPOW[4] = {1, 10, 100, 1000};
// Non-negative big integer stored in radix BASE.
struct bignum{
  // Limbs, least significant first; a normalized value keeps every limb
  // below BASE and has no leading zero limb except for the value 0 itself.
  std::vector<unsigned> bits;

  // Appends v as the next more significant limb (on an empty number this
  // makes the value v).
  inline void init(int v){
    bits.push_back(v);
  }

  // Number of limbs currently stored.
  inline int length() const {return static_cast<int>(bits.size());}

  // Removes leading zero limbs while always keeping at least one limb.
  void trim(){
    while(bits.size() > 1 && bits.back() == 0)bits.pop_back();
    if(bits.empty())bits.push_back(0);
  }

  // Adds o to this number. o is left untouched, and o may be *this.
  void addw(const bignum &o){
    const size_t n = o.bits.size();
    if(bits.size() < n)bits.resize(n);
    unsigned carry = 0;
    for(size_t i = 0; i < bits.size(); i++){
      unsigned sum = bits[i] + carry;
      if(i < n)sum += o.bits[i];
      bits[i] = sum % BASE;
      carry = sum / BASE;
    }
    // Carries are resolved first and only then is the number extended, so
    // the top limb can never be left at or above BASE.
    while(carry){
      bits.push_back(carry % BASE);
      carry /= BASE;
    }
    trim();
  }

  // True when the value is zero. An empty limb vector and leading zero
  // limbs are handled; the scan starts at the most significant limb so it
  // stops immediately for normalized non-zero values.
  inline bool is_zero() const {
    for(size_t i = bits.size(); i-- > 0; )
      if(bits[i] != 0)return false;
    return true;
  }

  // Multiplies this number by o, using mul() for the limb convolution and
  // then propagating carries. o may be *this. The parameter is a non-const
  // reference only because mul() takes its operands that way.
  void mulw(bignum &o){
    if(is_zero() || o.is_zero()){
      bits.clear();
      init(0);
      return;
    }
    mul(bits, o.bits);
    // The carry is kept in 64 bits so adding it to a large raw coefficient
    // cannot wrap around.
    unsigned long long carry = 0;
    for(size_t i = 0; i < bits.size(); i++){
      const unsigned long long cur = bits[i] + carry;
      bits[i] = static_cast<unsigned>(cur % BASE);
      carry = cur / BASE;
    }
    while(carry){
      bits.push_back(static_cast<unsigned>(carry % BASE));
      carry /= BASE;
    }
    trim();
  }

  // Writes the decimal value followed by a newline to stdout: the top limb
  // unpadded, every lower limb zero-padded to four digits. Leading zero
  // limbs are skipped and an empty number prints as 0.
  void print() const {
    size_t top = bits.size();
    while(top > 1 && bits[top-1] == 0)top--;
    if(top == 0){
      printf("0\n");
      return;
    }
    printf("%u", bits[top-1]);
    for(size_t i = top - 1; i-- > 0; )printf("%04u", bits[i]);
    putchar('\n');
  }

  // Reads one whitespace-delimited token from stdin and stores it as the new
  // value, replacing any previous contents. At most 150001 characters are
  // consumed (the buffer size); characters other than '0'..'9' are ignored.
  // If nothing can be read the value becomes 0.
  void get(){
    static char buf[150002];
    bits.clear();
    if(scanf("%150001s", buf) != 1){
      bits.push_back(0);
      return;
    }
    int cnt = 0;  // position of the next digit inside the current limb
    for(int i = static_cast<int>(strlen(buf)) - 1; i >= 0; i--){
      const char c = buf[i];
      if(c < '0' || c > '9')continue;
      if(cnt == 0)bits.push_back(0);
      bits.back() += BPOW[cnt] * (c - '0');
      if(++cnt == 4)cnt = 0;
    }
    trim();
  }
};
// Raises a to the power x in place by binary exponentiation.
//   a - base on entry, a^x on return.
//   x - exponent; x == 0 yields 1. Negative exponents have no integer
//       result here and are treated like 0 instead of looping forever.
void fast_pow(bignum &a, int x){
  bignum b = a;  // running square: a^(2^k)
  a.bits.clear(); a.init(1);
  while(x > 0){
    if(x&1)a.mulw(b);
    x >>= 1;
    // Square only while bits remain; the squaring after the last bit would
    // be a full FFT multiplication whose result is never used.
    if(x > 0)b.mulw(b);
  }
}
// Reads two non-negative integers from mul.in and writes their product to
// mul.out. Returns 1 if either file cannot be opened.
int main(){
  //freopen("test_data.out", "r", stdin);
  if(freopen("mul.in", "r", stdin) == NULL){
    perror("mul.in");
    return 1;
  }
  if(freopen("mul.out", "w", stdout) == NULL){
    perror("mul.out");
    return 1;
  }
  bignum a, b;
  a.get();
  b.get();
  a.mulw(b);
  a.print();

  return 0;
}