FFT
系数表示法:
设A是一个很大的数
A = $a_{0}$ + $a_{1}*x$ + $a_{2}*x^{2}$ + ... $a_{n-1}*x^{n-1}$
$a_{0}$ $a_{1}$ $a_{2}$ ... $a_{n-1}$ 分别表示A的第一位、第二位 ... 第n位,这样,向量($a_{0}$, $a_{1}$, $a_{2}$, ... $a_{n-1}$)就表示了A
------------
点值表示法:
令A ( x )= $a_{0}$ + $a_{1}*x$ + $a_{2}*x^{2}$ + ... $a_{n-1}*x^{n-1}$
我们取n个点:
{ (0,A(0)) , (1,A(1)) , ... , (n−1,A(n−1)) }
这n个点的集合就表示了A
设C = A * B
那么C的点值表达式就是
{ (0,A(0)∗B(0)) , (1,A(1)∗B(1)) ,..., (n−1,A(n−1)∗B(n−1)) }
考虑到两个n位数相乘的结果长度会加倍,所以一般都要取2n个点进行计算
所以,为了求出C,需要三个步骤:
1.求值:求出A,B的点值
2.点乘:A,B点值对应相乘得到C
3.插值:把C的点值表示法转换为系数表示法
------------
为了更快的算出点值,每个点 $x_{i}$ 的取值很特殊,它是复数
为了更快的算出点值,每个点 $x_{i}$ 的取值很特殊,它是复数
关于复数:
在复平面上任何一个复数 z 都能表示成为一个向量,即:z = r ( cosθ + i * sinθ )
(其中r是z的模长,θ是向量与x轴的夹角,称之为幅角)
欧拉公式:$e^{i\theta }$ = ( cosθ + i * sinθ )
由此可知 z = r * $e^{i\theta }$
棣莫弗公式:$(cosθ+i*sinθ)^{n}$= cos(nθ) + i * sin(nθ)
------------
在复数集下满足方程$x^{n}$=1的解一共有n个,这n个解构成1的n次单位根,这n个解就是要选取的 $x_{i}$
令A(x)=$\sum _{i=0}^{n-1}$ $a_{i}$ * $x^{i}$
$w_{n}^{0}$ $w_{n}^{1}$ .... $w_{n}^{n-1}$ 表示$w^{n}$= 1 的n个解
那么,$y_{k}=A(w_{n}^{k})=\sum _{i=0}^{n-1}a_{i} * w_{n}^{ki}$
如果把 A 的系数用向量表示 a = ($a_{0}$, $a_{1}$, $a_{2}$, ... $a_{n-1}$),把求出的所有 y 值也用向量表示 y = ($y_{0}$, $y_{1}$, $y_{2}$, ... $y_{n-1}$),那我们可以称向量 y 为 向量 a 的离散傅立叶变换(DFT)
为了较快的求出DFT,就需要用到快速傅里叶变换算法(FFT)
------------
n的取值:n必须为2的幂
举个栗子:A 有 5 个系数,不是 2 的幂,我们应该把它先扩大到比5大且离 5 最近的那个 2 的幂,即 2^3 = 8。然后再把 8 扩大两倍变成 16;这才是 A B 两个多项式相乘的最终的 n
原因在于FFT算法运用了分治的策略,它用A( x )偶数下标的系数和奇数下标的系数构造了两个全新的次数界为n/2的多项式:
$A^{[0]}(x)=a_{0}+a_{2}x+a_{4}x^{2}+...+a_{n-2}x^{n/2-1}$
$A^{[1]}(x)=a_{1}+a_{3}x+a_{5}x^{2}+...+a_{n-1}x^{n/2-1}$
不难看出:$A(k)=A^{[0]}(k^{2})+k*A^{[1]}(k^{2})$
------------
相消引理:
对任何整数 n >= 0,k >= 0,d > 0,$w^{dk}_{dn}=w^{k}_{n}$
推论:
$w^{n+k}_{n}=w^{n}_{n}*w^{k}_{n}=w^{k}_{n}$
$w^{n/2+k}_{n}=w^{n/2}_{n}*w^{k}_{n}=-w^{k}_{n}$
因此:$A(w^{k}_{n})=A^{[0]}(w^{k}_{n/2})+w^{k}_{n}*A^{[1]}(w^{k}_{n/2})$
这个过程就是快速傅里叶变换
A 和 B 分别转化为点值表示后,y 值对应相乘就得到了多项式 C 的点值表示,接下来把 C 的点值表示转换为系数表示就完成了乘法的运算,即离散傅里叶反变换(IDFT)
------------
构造一个范德蒙德矩阵 V 满足:
$\begin{bmatrix}y_{0}\\y_{1}\\y_{2}\\...\\y_{n-1}\end{bmatrix}=\begin{bmatrix}1&1&1&...&1\\1&w^{1}_{n}&w^{2}_{n}&...&w^{n-1}_{n}\\...&... &... &... &...\\1&w^{n-1}_{n} &w^{2(n-1)}_{n}&... &w^{(n-1)(n-1)}_{n}\end{bmatrix}\begin{bmatrix}a_{0}\\a_{1}\\a_{2}\\...\\a_{n-1}\end{bmatrix}$
不妨记做 $y = Va$
那么 $a =V^{-1}y$
对于第j行第k列,0<=j,k<=n-1,$V$处的值为$w^{kj}_{n}$,
$V^{-1}$出的值为$w^{-kj}_{n}/n$
由此推出:$a_{k}=\frac{1}{n}*\sum _{i=0}^{n-1}y_{i} * w_{n}^{-ki}$
$V^{-1}$出的值为$w^{-kj}_{n}/n$
由此推出:$a_{k}=\frac{1}{n}*\sum _{i=0}^{n-1}y_{i} * w_{n}^{-ki}$
所以把 C 的点值表示转换为系数表示只需要对 C 的 y 套用一遍FFT,根据棣莫弗公式,当指数取反时,对应的幅角也取反,实部的cos符号不变,虚部的sin取反,然后算完之后对整个a除以n即可,即逆快速傅里叶变换
裸题:HDU - 1402
#include <bits/stdc++.h> using namespace std; #define mset(var,val) memset(var,val,sizeof(var)) const double pi=acos(-1.0); const int M=5e5+10; char s1[M],s2[M]; int ans[M]; double rea[M],ina[M],reb[M],inb[M],rec[M],inc[M],Retmp[M],Intmp[M]; void FFT(double reA[], double inA[], int n, int flag) { if(n == 1) return; int k,u,i; double reWm = cos(2*pi/n), inWm = sin(2*pi/n); if(flag) inWm = -inWm; double reW = 1.0, inW = 0.0; for(k = 1,u = 0; k < n; k += 2,u++) { Retmp[u] = reA[k]; Intmp[u] = inA[k]; } for(k = 2; k < n; k += 2) { reA[k/2] = reA[k]; inA[k/2] = inA[k]; } for(k = u,i = 0; k < n && i < u; k++, i++) { reA[k] = Retmp[i]; inA[k] = Intmp[i]; } FFT(reA, inA, n/2, flag); FFT(reA + n/2, inA + n/2, n/2, flag); for(k = 0; k < n/2; k++) { int tag = k+n/2; double reT = reW * reA[tag] - inW * inA[tag]; double inT = reW * inA[tag] + inW * reA[tag]; double reU = reA[k], inU = inA[k]; reA[k] = reU + reT; inA[k] = inU + inT; reA[tag] = reU - reT; inA[tag] = inU - inT; double rew_t = reW * reWm - inW * inWm; double inw_t = reW * inWm + inW * reWm; reW = rew_t; inW = inw_t; } } int main(int argc, char const *argv[]) { while(~scanf("%s%s",s1,s2)) { mset(ans,0);mset(rea,0);mset(ina,0);mset(reb,0);mset(inb,0); int len1=strlen(s1),len2=strlen(s2),len=1; int m=max(len1,len2); while(len<m)len*=2; len*=2; for(int i=0;i<len;i++) { if(i<len1)rea[i]=s1[len1-1-i]-'0'; if(i<len2)reb[i]=s2[len2-1-i]-'0'; } FFT(rea,ina,len,0); FFT(reb,inb,len,0); for(int i=0;i<len;i++) { rec[i]=rea[i]*reb[i]-ina[i]*inb[i]; inc[i]=rea[i]*inb[i]+ina[i]*reb[i]; } FFT(rec,inc,len,1); for(int i=0;i<len;i++) { rec[i]/=len; inc[i]/=len; ans[i]+=(int)(rec[i]+0.5); ans[i+1]+=ans[i]/10; ans[i]%=10; } int lent=len1+len2+2; while(ans[lent]==0&&lent>0)lent--; for(int i=lent;i>=0;i--) printf("%d",ans[i]); printf("\n"); } return 0; }