可以先了解下用分治法求多项式的值 链接。
求 $A(x)$ 在 $w_j,j=0,1,2...,n-1$ 处的值:
设多项式
$$A(x)=a_0+a_1*x+a_2*{x^2}+a_3*{x^3}+a_4*{x^4}+a_5*{x^5}+ \dots+a_{n-2}*x^{n-2}+a_{n-1}*x^{n-1}$$
按下标的奇偶性分类,设
$$A_1(x)=a_0+a_2*{x}+a_4*{x^2}+\dots+a_{n-2}*x^{\frac{n}{2}-1}$$
$$A_2(x)=a_1+a_3*{x}+a_5*{x^2}+ \dots+a_{n-1}*x^{\frac{n}{2}-1}$$
那么不难得到 $A(x)=A_1(x^2)+xA_2(x^2)$.
我们将 $\omega_n^k (k<\frac{n}{2})$ 代入得
$$\begin{aligned}
A(\omega_n^k) &=A_1(\omega_n^{2k})+\omega_n^kA_2(\omega_n^{2k}) \\
&=A_1(\omega_{\frac{n}{2}}^{k})+\omega_n^kA_2(\omega_{\frac{n}{2}}^{k})
\end{aligned}$$
同理,将 $\omega_n^{k+\frac{n}{2}}$ 代入得
$$\begin{aligned}
A(\omega_n^{k+\frac{n}{2}}) &=A_1(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}(\omega_n^{2k+n}) \\
&=A_1(\omega_n^{2k}*\omega_n^n)-\omega_n^kA_2(\omega_n^{2k}*\omega_n^n) \\
&=A_1(\omega_n^{2k})-\omega_n^kA_2(\omega_n^{2k})
\end{aligned}$$
两个式子只有符号不同,也就是说,算出在前 $n$ 个点的值就能得到后 $n$ 个点的值,相当于问题规模减半了。
所以可以递归的实现,直到多项式仅剩一个常数项,这时候我们直接返回就好啦!
这样时间复杂度为 $O(nlogn)$.
FFT算法的伪代码:
1.求值 $A(w_j), B(w_j)$,$j=0,1,..2n-1$
2. 计算 $C(w_j)$,$j=0,1,...,2n-1$
3. 构造多项式
$D(x)=C(w_0) + C(w_1)x+...+C(w_{2n-1})x^{2n-1}$
4. 计算所有的 $D(w_j)$,$j=0,1,...2n-1$
5. 利用下式计算 $C(x)$ 的系数 $c_j$
$D(w_j) = 2nc_{2n-j},\ j=1,...,2n-1$
$D(w_0) = 2nc_0$
递归版
fft函数 $type=1$ 进去的时候 $a$ 数组存的是系数,返回时存的是计算出来的值,$a[i]=A(w_i)$;
$type=-1$ 时相反。
这里预处理了sin和cos值,大概能快2、3倍
// luogu-judger-enable-o2 #include<iostream> #include<cstdio> #include<cmath> using namespace std; const int MAXN = 4 * 60000 + 10; 、、开4倍空间 inline int read() { char c = getchar(); int x = 0, f = 1; while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();} while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();} return x * f; } const double Pi = acos(-1.0); const double Eps = 1e-8; double ccos[MAXN], ssin[MAXN]; struct complex { double x, y; complex (double xx = 0, double yy = 0) {x = xx, y = yy;} } a[MAXN], b[MAXN]; complex operator + (complex a, complex b) { return complex(a.x + b.x , a.y + b.y);} complex operator - (complex a, complex b) { return complex(a.x - b.x , a.y - b.y);} complex operator * (complex a, complex b) { return complex(a.x * b.x - a.y * b.y , a.x * b.y + a.y * b.x);} //不懂的看复数的运算那部分 void fast_fast_tle(int limit, complex *a, int type) { if (limit == 1) return ; //只有一个常数项 complex a1[limit >> 1], a2[limit >> 1]; for (int i = 0; i < limit; i += 2) //根据下标的奇偶性分类 a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1]; fast_fast_tle(limit >> 1, a1, type); fast_fast_tle(limit >> 1, a2, type); complex Wn = complex(ccos[limit] , type * ssin[limit]), w = complex(1, 0); //complex Wn = complex(cos(2.0 * Pi / limit) , type * sin(2.0 * Pi / limit)), w = complex(1, 0); //Wn为单位根,w表示幂 for (int i = 0; i < (limit >> 1); i++, w = w * Wn) //这里的w相当于公式中的k a[i] = a1[i] + w * a2[i], a[i + (limit >> 1)] = a1[i] - w * a2[i]; //利用单位根的性质,O(1)得到另一部分 } char s[MAXN]; int res[MAXN]; int main() { int N = read(); scanf("%s", s); for (int i = 0; i < N; i++) a[i].x = s[N-1-i]-'0'; scanf("%s", s); for (int i = 0; i < N; i++) b[i].x = s[N-1-i]-'0'; //for(int i = 0;i < N;i++) printf("%f ", a[i]); int limit = 1; while (limit <= 2*N) limit <<= 1; for(int i = 1;i <= limit;i++) { ccos[i] = cos(2.0 * Pi / i); ssin[i] = sin(2.0 * Pi / i); } fast_fast_tle(limit, a, 1); fast_fast_tle(limit, b, 1); //后面的1表示要进行的变换是什么类型 //1表示从系数变为点值 //-1表示从点值变为系数 //至于为什么这样是对的,可以参考一下c向量的推导过程, for (int i = 0; i <= limit; i++) a[i] = a[i] * b[i]; fast_fast_tle(limit, a, -1); for(int i = 0;i <= 2*N;i++) res[i] = int(a[i].x/limit+0.5); int tmp = 0; //进位 for(int i = 0;i <= 2*N;i++) { res[i] += tmp; tmp = res[i] / 10; res[i] = res[i] % 10; } bool flag = false; for (int i = 2*N; i >= 0; i--) { //printf("%f ", a[i].x); if(res[i]) flag = true; if(flag) printf("%d", res[i]); //按照我们推倒的公式,这里还要除以n } return 0; }
非递归版
这个很容易发现点什么吧?
- 每个位置分治后的最终位置为其二进制翻转后得到的位置
这样的话我们可以先把原序列变换好,把每个数放在最终的位置上,再一步一步向上合并。
一句话就可以 $O(n)$ 预处理出位置 $i$ 最终的位置 $rev[i]$:
//原理也很简单,将高bit-1位(也就是i/2)反转,再将第一位补到最高位。
fo(i,0,n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
// luogu-judger-enable-o2 #include<iostream> #include<cstdio> #include<cmath> using namespace std; const int MAXN = 4e6 + 10; inline int read() { char c = getchar(); int x = 0, f = 1; while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();} while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();} return x * f; } const double Pi = acos(-1.0); struct complex { double x, y; complex (double xx = 0, double yy = 0) {x = xx, y = yy;} } a[MAXN], b[MAXN]; complex operator + (complex a, complex b) { return complex(a.x + b.x , a.y + b.y);} complex operator - (complex a, complex b) { return complex(a.x - b.x , a.y - b.y);} complex operator * (complex a, complex b) { return complex(a.x * b.x - a.y * b.y , a.x * b.y + a.y * b.x);} //不懂的看复数的运算那部分 int N, M; int bit, r[MAXN]; int limit = 1; void fast_fast_tle(complex *A, int type) { for (int i = 0; i < limit; i++) if (i < r[i]) swap(A[i], A[r[i]]); //求出要迭代的序列 for (int mid = 1; mid < limit; mid <<= 1) { //待合并区间的长度的一半 complex Wn( cos(Pi / mid) , type * sin(Pi / mid) ); //单位根 for (int R = mid << 1, j = 0; j < limit; j += R) { //R是区间的长度,j表示前已经到哪个位置了 complex w(1, 0); //幂 for (int k = 0; k < mid; k++, w = w * Wn) { //枚举左半部分 complex x = A[j + k], y = w * A[j + mid + k]; //蝴蝶效应 A[j + k] = x + y; A[j + mid + k] = x - y; } } } } int main() { int N = read(), M = read(); for (int i = 0; i <= N; i++) a[i].x = read(); for (int i = 0; i <= M; i++) b[i].x = read(); while (limit <= N + M) limit <<= 1, bit++; for (int i = 0; i < limit; i++) r[i] = ( r[i >> 1] >> 1 ) | ( (i & 1) << (bit - 1) ) ; fast_fast_tle(a, 1); fast_fast_tle(b, 1); for (int i = 0; i <= limit; i++) a[i] = a[i] * b[i]; fast_fast_tle(a, -1); for (int i = 0; i <= N + M; i++) printf("%d ", (int)(a[i].x / limit + 0.5)); return 0; }
参考链接:
(建议直接看大佬的,我都是copy过来的,整理一下思路而已)
1 https://www.cnblogs.com/zwfymqz/p/8244902.html
2. https://blog.csdn.net/enjoy_pascal/article/details/81478582