2023. 2. 25. 13:22ㆍ수학
우리는 임의의 두 다항식을 좀 더 빠르게 곱하고 싶다.
각 $n-1$차 다항식을 $A(z), B(z)$라고 하자. 또 $C(z):=A(z)B(z)$라고 하자.
FFT의 핵심 아이디어는 Lagrange Interpolation이다.
Lemma. (Lagrange Interpolation) 임의의 $m-1$차 다항식(in field)은 $m$개의 서로 다른 점으로 유일하게 결정된다.
Proof. 임의의 $m-1$차 다항식을 $f(x):=a_{m-1}x^{m-1}+a_{m-2}x^{m-2}+\cdots+a_{1}x+a_{0}$라고 두자. 이 때 우리가 서로 다른 $m$개의 점 $\{b_{i}\}$들에 대하여 그 함숫값들 $\{f(b_{i})\}$를 알고 있다고 가정하자. 그럼 임의의 $i$에 대해서
\[a_{m-1}b_{i}^{m-1}+a_{m-2}b_{i}^{m-2}+\cdots+a_{1}b_{i}+a_{0}=f(b_{i})\]
위 식을 행렬로 다타내면
\[\begin{bmatrix} b_{1}^{m-1} & b_{1}^{m-2} & \cdots & b_{1} & 1 \\ b_{2}^{m-1} & b_{2}^{m-2} & \cdots & b_{2} & 1 \\ & & \vdots & & \\ b_{m-1}^{m-1} & b_{m-1}^{m-2} & \cdots & b_{m-1} & 1 \\ b_{m}^{m-1} & b_{m}^{m-2} & \cdots & b_{m} & 1 \end{bmatrix} \begin{bmatrix} a_{m-1} \\ a_{m-2} \\ \vdots \\ a_{1} \\ a_{0} \end{bmatrix} = \begin{bmatrix} f(b_{1}) \\ f(b_{2}) \\ \vdots \\ f(b_{m-1}) \\ f(b_{m}) \end{bmatrix} \]
가장 왼쪽의 행렬은 Vandermonde Matrix로 행렬을 $V$라 했을 때 행렬식은 \[\det(V)=\prod_{1\le i<j\le m} (b_{j}-b_{i})\]
이고 $b_{i}\ne b_{j}$ if $i\ne j$이므로 $\det(V)\ne 0$이다. $V$가 invertible하므로 $\{a_{i}\}$ 들은 유일하게 결정된다.
이 사실을 이용해 $A(z_{0}),\cdots,A(z_{2n-1})$과 $B(z_{0}),\cdots,B(z_{2n-1})$의 값을 구하면 $C(z_{i})=A(z_{i})B(z_{i})$라는 사실을 이용해 $2n-1$개의 서로 다른 점들을 알 수 있고 $\deg(C(z))=2n-2$이므로 $C(z)$를 유일하게 결정할 수 있다. 이렇게 어떤 다항식을 서로 다른 점들로 표현하는 방법을 DFT(Discrete Fourier Transform)라고 한다. 우리는 대입할 값으로 primitive root of unity $w$를 선택한다.(i.e. $w^{n}=1$ and $w^{m}\ne 1$ with $0<m<n$). 우리는 $A(w^{0}), A(w^{1}),\cdots,A(w^{n-1})$의 값을 빠르게 구하고 싶다!! 다항식 $A(x)=\sum_{i=0}^{n-1} a_{i}x^{i}$라고 하자. 그럼...
\[A(x) = (a_{0}+a_{2}x^{2}+\cdots)+x(a_{1}+a_{3}x^{2}+\cdots)=A_{even}(x^{2})+xA_{odd}(x^{2})\]
with
\[A_{even}(x) = \sum_{i=0}^{\lfloor{N/2\rfloor}}a_{2i}x^{i},\ A_{odd}(x)=\sum_{i=0}^{\lfloor{N/2\rfloor}}a_{2i+1}x^{i}\]
이 사실을 이용하면 다음과 같은 방법이 나타난다! $n$이 2의 거듭제곱이라고 하자.
1. $n/2$개의 점 $w^{0},\cdots,w^{n-2}$에 대해 $A_{even}(x),A_{odd}(x)$의 함숫값을 모두 구한다.(DFT 한다)
- 각각에 대해서 또 재귀적으로 구하자.(이 과정이 계속되려면 $n/2$역시 짝수여야 한다.)
2. 다음과 같은 관계 식을 이용해 $A(x)$의 $w^{0},\cdots,w^{n-1}$에 대한 함숫값을 모두 구한다.
\[A(w^{j})=A_{even}(w^{2j})+w^{j}A_{odd}(w^{2j}),\ A(w^{j+n/2})=A_{even}(w^{2j})-w^{j}A_{odd}(w^{2j})\]
with $0\le j < n/2$
시작 복잡도는 $T(n)=2T(n/2)+O(n)$으로 Master's Theorem에 의해 $O(n\log n)$이다.
우리는 꽤나 아름다운 방법으로 $C(z)$의 서로 다른 $2n-1$개의 점을 구했다. 문제는 어떻게 점에서 다항식으로 복구를 하냐는 것이다...!! $n-1$차 다항식 $A(z)$에 대해 $n$개의 점 $\{w_{i}\}$와 그 함숫값 $\{y_{i}\}$를 알고 있다고 하자. 그럼 다음 행렬이 성립함을 앞에서 보았다.
\[\begin{bmatrix} 1 & 1 & \cdots & 1 & 1 \\ 1 & w & \cdots & w^{n-2} & w^{n-1} \\ & & \vdots & & \\ \\ 1 & w^{n-2} & \cdots & w^{(n-2)(n-2)} & w^{(n-2)(n-1)} \\ 1 & w^{n-1} & \cdots & w^{(n-1)(n-2)} & w^{(n-1)(n-1)} \end{bmatrix} \begin{bmatrix} a_{0} \\ a_{1} \\ \vdots \\ a_{n-2} \\ a_{n-1} \end{bmatrix} = \begin{bmatrix} y_{0} \\ y_{1} \\ \vdots \\ y_{n-2} \\ y_{n-1} \end{bmatrix} \]
위 행렬의 역행렬을 구하면 다음을 얻는다.
\[\begin{bmatrix} a_{0} \\ a_{1} \\ \vdots \\ a_{n-2} \\ a_{n-1} \end{bmatrix} = \frac{1}{n} \begin{bmatrix} 1 & 1 & \cdots & 1 & 1 \\ 1 & w^{-1} & \cdots & w^{-(n-2)} & w^{-(n-1)} \\ & & \vdots & & \\ \\ 1 & w^{-(n-2)} & \cdots & w^{-(n-2)(n-2)} & w^{-(n-2)(n-1)} \\ 1 & w^{-(n-1)} & \cdots & w^{-(n-1)(n-2)} & w^{-(n-1)(n-1)} \end{bmatrix}\begin{bmatrix} y_{0} \\ y_{1} \\ \vdots \\ y_{n-2} \\ y_{n-1} \end{bmatrix} \]
이 말은 $w$를 $w^{-1}$로 바꾼 후 $y_{0},\cdots,y_{n-1}$을 계수로 가지는 다항식을 DFT한 뒤 $n$으로 나누면 된다는 말이다. 즉, DFT의 역변환 역시 시간 복잡도 $O(n\log n)$을 가진다.
우리는 FFT의 전 과정을 살펴보았다. 이제 FFT를 좀 더 빠르게!! 개선해보자. 위의 설명에서 재귀적으로 DFT를 구현했는데 재귀를 풀어서 속도를 향상시키자.
$N=8$일 때의 상황을 생각해보자. 처음에 우리에게 주어진 것은 아래와 같은 인덱스를 가진 배열 $a$이다.
\[0,\ 1,\ 2,\ 3,\ 4,\ 5,\ 6,\ 7\]
이 배열을 재귀에서 사용하는 방식대로 다시 정렬해보자. 우리에게는 다음과 같은 식이 있다.
\[A(w^{j})=A_{even}(w^{2j})+w^{j}A_{odd}(w^{2j}),\ A(w^{j+n/2})=A_{even}(w^{2j})-w^{j}A_{odd}(w^{2j})\]
with $0\le j < n/2$
계산을 편하게 하기 위해서 $A_{even}(n)$끼리 모으고 $A_{odd}(n)$끼리 모으자. 그럼 다음과 같다.
\[0,\ 2,\ 4,\ 6,\ 1,\ 3,\ 5,\ 7\]
재귀가 한번 더 들어가므로 다음과 같이 변한다.
\[0,\ 4,\ 2,\ 6,\ 1,\ 5,\ 3,\ 7\]
우리가 필요한 것은 $A_{even}(x),A_{odd}(x)$에 $w^{0},w^{2},w^{4},w^{6}$을 대입한 값이다. 이를 얻기 위해 필요한 것은 $A_{even\ even}(x), A_{even\ odd}(x), A_{odd\ even}(x), A_{odd\ odd}(x)$에 $w^{0}, w^{4}$을 대입한 값이다. 그 이유는 거꾸로 생각해보면 알 수 있다. 거꾸로 다음과 같은 계산을 해보자. 맨 마지막 배열에서 $[0,4],[2,6],[1,5],[3,7]$에서 연산을 해준다.
\[[0,4] \rightarrow [a_{0}+w^{0}a_{4}, a_{0}-w^{0}a_{4}],\ [2,6] \rightarrow [a_{2}+w^{0}a_{6}, a_{2}-w^{0}a_{6}]\]
$-w^{0}=w^{4}$ 라는 사실을 기억하자. 이제 $[0,2],[4,6]$에서 연산을 해주면
\[[0, 2] \rightarrow [(a_{0}+w^{0}a_{4}) + w^{0}(a_{2}+w^{0}a_{6}),\ (a_{0}+w^{0}a_{4}) - w^{0}(a_{2}+w^{0}a_{6})] = [A_{even}(w^{0}),\ A_{even}(w^{4})]\]
\[[4, 6] \rightarrow [(a_{0}-w^{0}a_{4}) + w^{2}(a_{2}-w^{0}a_{6}),\ (a_{0}-w^{0}a_{4}) - w^{2}(a_{2}-w^{0}a_{6})]=[A_{even}(w^{2}),A_{even}(w^{6})] \]
주어진 $A_{even}, A_{odd}$의 관계식에서 $j=0$인 경우를 알면 $w^{0},w^{0+4}$를 알 수 있고 $j=2$인 경우를 알면 $w^{2},w^{2+4}$를 알 수 있다고 했으니 위 결과는 합리적이다. 위와 같이 수들을 정렬해두고 연산을 하면 되겠다! 위의 수들은 놀라운 특성이 있다. 원래의 인덱스 수의 비트를 반대로 돌리면 우리가 구했던 마지막 배열이 나온다. 예를 들어 $3=011_{(2)}$의 비트 반전은 $110_{(2)}=6$이다. 마지막 배열의 $6$의 위치에 $3$이 있다는 것을 알 수 있다.
아래에는 위의 사실들을 이용한 FFT 코드가 있다.
# 코드 보기
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
#include<bits/stdc++.h>
using namespace std;
using base = complex<double>;
const double PI = acos(-1);
void FFT(vector<base> &a, bool inv){
int n = a.size(), j = 0;
for(int i = 1;i < n;i++){ //bit reverse
int bit = (n >> 1);
while(j >= bit){
j -= bit;
bit >>= 1;
}
j += bit;
if(i < j) swap(a[i], a[j]);
}
double ang = 2 * acos(-1) / n * (inv ? -1 : 1);
vector<base> roots(n / 2);
for(int i=0;i < n / 2;i++) roots[i] = base(cos(ang * i), sin(ang * i)); //calculate w^n
for(int i = 2;i <= n;i <<= 1){
int step = n / i;
for(int j = 0;j < n;j += i){
for(int k = 0;k < i / 2;k++){
base u = a[j + k], v = a[j + k + i / 2] * roots[step * k];
a[j + k] = u + v;
a[j + k + i / 2] = u - v;
}
}
}
if(inv) for(int i = 0;i < n;i++) a[i] /= n;
}
|
cs |
# 닫기
우리가 FFT를 사용할 때 주의해야 할 점은 안정성이 떨어진다는 것이다. 그 말은, 배열의 원소 하나 하나의 값이 크면 계산에서 오차가 발생할 수 있다는 것이다. 그러기 위해서는 NFT를 사용할 수 있으나 이곳에서는 실수 FFT로 안정성을 올리는 방법을 소개한다.
정수 벡터 2개 $V,\ W$가 입력으로 들어왔다고 하자. 각 벡터의 길이를 더한 값이 결과 벡터의 길이가 될 것이다. 따라서 두 벡터의 길이의 합보다 큰 가장 작은 2의 거듭제곱을 선택한 뒤 그 길이를 $n$이라고 하자. 정수 벡터 $V, W$를 길이가 $n$인 복소수 벡터 $V_{1}, V_{2}$에 담을 것이다! 다음과 같이 담아보자.
V_{1}[i] = base(V[i] >> 15, V[i] & 32767) V_{2}[i] = base(W[i] >> 15, W[i] & 32767)
이 말은 실수 부분에 앞쪽 비트를, 허수 부분에 뒷쪽 15비트를 담겠다는 것이다. 이것은 다항식의 관점으로 다음과 같이 해석할 수 있다. $V_{F}(z)$를 $V(z)$의 계수들의 앞쪽 비트를 계수로 하는 다항식으로 정의하고 $V_{R}(z)$를 $V(z)$의 계수들의 뒷쪽 비트를 계수로 하는 다항식으로 정의하자. 그럼 $V_{1}(z)=V_{F}(z)+i V_{R}(z)$이고 같은 방법으로 $V_{2}(z)=W_{F}(z)+i W_{R}(z)$가 된다. 각각의 $V_{1}, V_{2}$벡터를 FFT를 돌린 뒤 $conj$ 연산을 이용해 $V_{F}$와 $V_{R}$을 구하자. 그리고 새로운 다항식 $R_{1}(z)=V_{F}(z)V_{2}(z),\ R_{2}(z)=V_{R}(z)V_{2}(z)$를 정의하자. 얻어진 점들을 DFT 역변환을 통해 다항식을 얻어내자. 그럼 다음과 같은 관계가 있다.
\[V(z)W(z) = \Re(R_{1}(z))\cdot 2^{30} + (\Im(R_{1}(z))+\Re(R_{2}(z)))\cdot 2^{15} + \Im(R_{2}(z))\]
그 이유는
\[\Re(R_{1})=\Re(V_{F}W_{F}+iV_{F}W_{R})=V_{F}W_{F}\]
\[\Im(R_{1})=\Im(V_{F}W_{F}+iV_{F}W_{R})=V_{F}W_{R}\]
\[\Re(R_{2})=\Re(V_{R}W_{F}+iV_{R}W_{R})=V_{R}W_{F}\]
\[\Im(R_{2})=\Im(V_{R}W_{F}+iV_{R}W_{R})=V_{R}W_{R}\]
이며
\[VW=(V_{F}\cdot 2^{15}+V_{R})(W_{F}\cdot 2^{15}+W_{R})\]
이기 때문이다. 아래에는 위의 테크닉을 구현한 FFT의 전체적인 코드가 있다!
# 코드 보기
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
#include<bits/stdc++.h>
using namespace std;
using base = complex<double>;
using ll = long long;
const double PI = acos(-1);
void FFT(vector<base> &a, bool inv){
int n = a.size(), j = 0;
for(int i = 1;i < n;i++){ //bit reverse
int bit = (n >> 1);
while(j >= bit){
j -= bit;
bit >>= 1;
}
j += bit;
if(i < j) swap(a[i], a[j]);
}
double ang = 2 * acos(-1) / n * (inv ? -1 : 1);
vector<base> roots(n / 2);
for(int i=0;i < n / 2;i++) roots[i] = base(cos(ang * i), sin(ang * i)); //calculate w^n
for(int i = 2;i <= n;i <<= 1){
int step = n / i;
for(int j = 0;j < n;j += i){
for(int k = 0;k < i / 2;k++){
base u = a[j + k], v = a[j + k + i / 2] * roots[step * k];
a[j + k] = u + v;
a[j + k + i / 2] = u - v;
}
}
}
if(inv) for(int i = 0;i < n;i++) a[i] /= n;
}
vector<ll> multiply(vector<ll> &v, vector<ll> &w, ll mod){
int n = 2; while(n < v.size() + w.size()) n <<= 1;
vector<base> v1(n), v2(n), r1(n), r2(n);
for(int i=0; i<v.size(); i++) //비트를 쪼갬(앞에 비트, 뒤에 비트) V1(Z) = V_{F}(Z) + iV_{R}(Z) (V_{F}=V front 즉, 앞 비트 V_{R}=V rear 즉, 뒷 비트)
v1[i] = base(v[i] >> 15, v[i] & 32767); //(a,b)
for(int i=0; i<w.size(); i++)
v2[i] = base(w[i] >> 15, w[i] & 32767);//(c,d) V2(Z) = W_{F}(Z) + iW_{R}(Z)
FFT(v1, 0);
FFT(v2, 0);
for(int i=0; i<n; i++){
int j = (i ? (n - i) : i); //i=0이면 j=0, i>0이면 j= n-i 왜냐면 (a+bi)w^{i}의 conjugate는 (a-bi)w^{-i}=(a-bi)w^{n-i}이기 때문.
base ans1 = (v1[i] + conj(v1[j])) * base(0.5, 0); //V_{F}(Z)
base ans2 = (v1[i] - conj(v1[j])) * base(0, -0.5); //V_{R}(Z)
base ans3 = (v2[i] + conj(v2[j])) * base(0.5, 0); //W_{F}(Z)
base ans4 = (v2[i] - conj(v2[j])) * base(0, -0.5); //W_{R}(Z)
r1[i] = (ans1 * ans3) + (ans1 * ans4) * base(0, 1); //V_{F}(Z)V2(Z)
r2[i] = (ans2 * ans3) + (ans2 * ans4) * base(0, 1); //V_{R}(Z)V2(Z)
}
FFT(r1, 1);
FFT(r2, 1);
vector<ll> ret(n);
for(int i=0; i<n; i++){
ll av = (ll)round(r1[i].real()); //V_{F}W_{F}(Z) (앞비트 2개가 곱해졌으니 2^30곱하기)
ll bv = (ll)round(r1[i].imag()) + (ll)round(r2[i].real()); //V_{F}W_{R}(Z) + W_{F}V_{R}(Z) (앞비트와 뒷비트가 곱해졌으니 2^15 곱하기)
ll cv = (ll)round(r2[i].imag()); //W_{R}(Z)V_{R}(Z) (뒷비트 끼리의 연산)
av %= mod, bv %= mod, cv %= mod;
ret[i] = ((av << 30) + (bv << 15) + cv) % mod;
ret[i] = (ret[i] + mod) % mod;
}
return ret;
}
|
cs |
# 닫기
'수학' 카테고리의 다른 글
Implementation of Lucy-Hedgehog algorithm (1) | 2024.01.22 |
---|---|
$x^2+xy+y^2$꼴 자연수 (0) | 2023.09.02 |
Funny fact in Linear Algebra (0) | 2022.11.02 |
Lucy-Hedgehog Algorithm (0) | 2022.09.18 |
Extension of Wilson's Theorem and its application (12) | 2022.07.14 |