快速傅里叶变换和快速数论变换FFT&NTT
快速傅里叶变换(FFT)
作用:加速多项式乘法
朴素高精度乘法时间O(n^2),但FFT能O(nlog2n)的时间解决
前置知识:
1.点值表示法:
f(x)={( x0,f(x0) ),( x1,f(x1) ) ,( x2, f(x2) ), ( x3, f(x3) ), ( x4, f(x4) ), ... , (xn-1, f(xn-1) )}
g(x)={( x0,g(x0) ),( x1,g(x1) ) ,( x2, g(x2) ), ( x3, g(x3) ), ( x4, g(x4) ), ... , (xn-1, g(xn-1) )}
设它们乘积为h(x),那么
h(x)={( x0,f(x0)g(x0) ),( x1, f(x1)g(x1) ), ( x2, f(x2)g(x2) ), ... , ( xn-1, f(xn-1)g(xn-1) )}
2.复数
(a1,θ1) *(a2,θ2)为(a1a2,θ1+θ2)
快速傅里叶变换的实现:
const double PI = acos(-1);
typedef complex <double> cp;
cp omega(int n, int k) {
return cp(cos(2 * PI * k / n), sin(2 * PI * k / n));
}
void fft(cp *a, int n, bool inv) {
if(n == 1) return ;
static cp buf[N];
int m = n / 2;
for(int i = 0; i < m ; i++){
buf[i] = a[2 * i];
buf[i + m] = a[2 * i + 1];
}
for(int i = 0; i < n; i++)
a[i] = buf[i];
fft(a, m, inv);
fft(a + m, m, inv);
for(int i = 0; i < m; i++) {
cp x = omega(n, i);
if(inv) x = conj(x);
//conj是一个自带的求共轭复数的函数,精度较高。当复数模为1时,共轭复数等于倒数
buf[i] = a[i] + x * a[i + m];
buf[i + m] = a[i] - x * a[i + m];
}
for(int i = 0; i < n; i++)
a[i] = buf[i];
} 注1:a中i属于0-m-1是A1(i)的取值,a中i属于m到n-1是A2(i)的取值,那么递归的分析,对现在A1中的第i位置,其是由0-m/2-1的得来的,然后最后再更新a数组,a(i)就是A(omega(n, i))的取值了。
注2:n是偶数,多项式的项数。
注3:inv表示单位根是否要取倒数,FFT的逆变换即点值表示法转化为系数表示法。
做法:把点值表示法作为系数,用取了倒数的单位根代入求个点值表示法,得到Zi再除以n就是i的系数ai(证明参考:https://www.cnblogs.com/RabbitHu/p/FFT.html)以下提到"博客"均指此篇博客。
注4:在逆FFT的时候,应该用floor(a[i].real() / n + 0.5);来得到res[i](精度问题)
优化fft(非递归)
发现每次都往下其实都是先去递归,使得所有元素到达其应该在的地方,然后再不断往上递归对a赋值的。
规律:参考博客,可以发现,一个位置a上的数,最后所在的位置是"a二进制翻转得到的数"
据此写出非递归版本fft:先把每个数放到最后的位置上,然后不断向上,从而求出最终答案。
#include<iostream>
#include<cstdio>
#include<complex>
using namespace std;
const double PI = acos(-1);
typedef complex <double> cp;
cp a[N], b[N], omg[N], inv[N];
void init() {
for(int i = 0; i < n; i++) {
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
}
}
void fft(cp *a, cp *omg) {
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++) {
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
static cp buf[N];
for(int l = 2; l <= n; l *= 2) { //区间长度
int m = l / 2;
for(int j = 0; j < n; j += l) //区间起点
for(int i = 0; i < m; i++) {
buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
}
for(int j = 0; j < n; j++)
a[j] = buf[j];
}
}
蝴蝶变换:
之前为什么需要buf数组:
如果
a[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m]
会对更新a[j + i + m]造成影响。
而通过蝴蝶变换
cp t = omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] - t
a[j + i] = a[j + i] + t
不就顺序换了一下??
反正就不用buf数组就是了。
FFT最终模板:
const double PI = acos(-1);
typedef complex <double> cp;
cp a[N], b[N], omg[N], inv[N];
void init() {
for(int i = 0; i < n; i++) {
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
}
}
void fft(cp *a, cp *omg) {
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++) {
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
for(int l = 2; l <= n; l *= 2) { //区间长度
int m = l / 2;
for(cp *p = a; p != a + n; p += l)
for(int i = 0; i < m; i++) {
cp t = omg[n / l * i] * p[i + m];
p[i + m] = p[i] - t;
p[i] += t;
}
}
} 题1:a*bIII http://www.acmicpc.sdnu.edu.cn/problem/show/1531
注意三点:
1.0与其他数相乘只输出一个0.
2.memset是可以对复数初始化的
3.FFT中的n需要是2的倍数
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1000005;
const double PI = acos(-1);
typedef complex <double> cp;
char sa[N], sb[N];
int n = 1, lena, lenb, ans[N];
cp a[N], b[N], omg[N], inv[N];
void init(){
for(int i = 0; i < n; i++){
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
}
}
void fft(cp *a, cp *omg) {
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++) {
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
for(int l = 2; l <= n; l *= 2) { //区间长度
int m = l / 2;
for(cp *p = a; p != a + n; p += l)
for(int i = 0; i < m; i++) {
cp t = omg[n / l * i] * p[i + m];
p[i + m] = p[i] - t;
p[i] += t;
}
}
}
signed main()
{
while(~scanf("%s%s",sa,sb)) {
memset(ans, 0, sizeof(ans));
memset(a, 0, sizeof(a));
memset(b, 0, sizeof(b));
n = 1;
lena = strlen(sa), lenb = strlen(sb);
if(lena == 1 && sa[0] == '0' || lenb == 1 && sb[0] == '0') {
printf("0\n");
continue;
}
while(n < lena + lenb) n *= 2;
for(int i = 0; i < lena; i++){
a[i].real(sa[lena - 1 - i] - '0');
}
for(int i = 0; i < lenb; i++){
b[i].real(sb[lenb - 1 - i] - '0');
}
init();
fft(a, omg);
fft(b, omg);
for(int i = 0; i < n; i++) {
a[i] *= b[i];
}
fft(a, inv);
for(int i = 0; i < n; i++) {
ans[i] += floor(a[i].real() / n + 0.5);
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
}
int beg;
for(int i = n-1; i >= 0; i--) {
if(ans[i] != 0){
beg = i;
break;
}
}
for(int i = beg; i >= 0; i--){
printf("%lld",ans[i]);
}
putchar('\n');
}
return 0;
}
FFT的缺点是它的复数运算double精度问题导致它实际上是k*nlongn的,会比NTT的常数大很多。
NTT
前置知识:
原根:对于g,p属于Z, 如果g^i mod p ( 1<=i<=p-1)的值互不相同,则称g为p的原根。
或者说对于任意i,j(1<=i<j <= p-1) g^i mod p /= g^j mod p,那么g为p的原根。
常见模数有:998244353,1004535809,469762049,这几个的原根都为3
在NTT中,我们拿原根来代替FFT的单位根
#define g 3
const int mod = 998244353;
const int N = 300000;
inline get_rev()
{
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++) {
rev[i] = (rev[i >> 1] >> 1 | ((i & 1) << (lim - 1)));
}
}
inline void ntt(int *a, int inv) {
for(int i = 0; i < n; i++) {
if(i < rev[i]) swap(a[i], a[rev[i]]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
for(int l = 2; l <= n; l *= 2) { //区间长度
int m = l / 2;
int tmp = q_pow(g, (mod-1)/l);
if(inv == -1) tmp = q_pow(tmp, mod-2);
for(int i = 0; i < n; i += l) {
int omega = 1;
for(int j = 0; j < m; j++, omega = omega*tmp%mod) {
int x = a[i + j], y = omega * a[i + j + m] % mod;
a[i + j] = (x + y) % mod, a[i + j + m] = (x - y + mod) % mod;
}
}
}
if(inv == -1)
{
int nI = q_pow(n, mod-2);
for(int i = 0; i < n; i++) {
a[i] = a[i] * nI % mod;
}
}
} 例1:http://www.acmicpc.sdnu.edu.cn/problem/show/1532
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define g 3
const int mod = 998244353;
const int N = 300000;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
inline void write(int x)
{
if(x<0)x=-x,putchar('-');
if(x>9)write(x/10);
putchar(x%10+'0');
}
char sa[N], sb[N];
int n = 1, a[N], b[N], rev[N], lena, lenb;
inline int q_pow(int a, int b){
int ans = 1;
while(b > 0){
if(b & 1){
ans = ans * a % mod;
}
a = a * a % mod;
b >>= 1;
}
return ans;
}
inline get_rev()
{
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++) {
rev[i] = (rev[i >> 1] >> 1 | ((i & 1) << (lim - 1)));
}
}
inline void ntt(int *a, int inv) {
for(int i = 0; i < n; i++) {
if(i < rev[i]) swap(a[i], a[rev[i]]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
for(int l = 2; l <= n; l *= 2) { //区间长度
int m = l / 2;
int tmp = q_pow(g, (mod-1)/l);
if(inv == -1) tmp = q_pow(tmp, mod-2);
for(int i = 0; i < n; i += l) {
int omega = 1;
for(int j = 0; j < m; j++, omega = omega*tmp%mod) {
int x = a[i + j], y = omega * a[i + j + m] % mod;
a[i + j] = (x + y) % mod, a[i + j + m] = (x - y + mod) % mod;
}
}
}
if(inv == -1)
{
int nI = q_pow(n, mod-2);
for(int i = 0; i < n; i++) {
a[i] = a[i] * nI % mod;
}
}
}
signed main()
{
while(~scanf("%s%s",sa,sb))
{
n = 1;
memset(a, 0, sizeof(a));
memset(b, 0, sizeof(b));
lena = strlen(sa), lenb = strlen(sb);
for(int i = 0; i < lena; i++) {
a[i] = sa[lena - 1 - i] - '0';
}
for(int i = 0; i < lenb; i++) {
b[i] = sb[lenb - 1 - i] - '0';
}
while(n < lena + lenb) n <<= 1;
get_rev();
ntt(a, 1);
ntt(b, 1);
for(int i = 0; i < n; i++) {
a[i] = a[i] * b[i] % mod;
}
ntt(a, -1);
for(int i = 0; i < n; i++) {
a[i + 1] += a[i] / 10;
a[i] %= 10;
}
int cnt = n;
while(cnt >= 0 && a[cnt] == 0) cnt--;
if(cnt == -1) {
printf("0");
}
else {
for(int i = cnt; i >= 0; i--){
write(a[i]);
}
}
putchar('\n');
}
return 0;
}