[BOJ 17105] 골드바흐 트리플

2022. 3. 4. 20:18PS/백준

어떤 홀수가 주어질 때 그 홀수를 세 개의 소수의 합으로 나타내는 방법이 얼마나 많은지 구하는 문제이다.

주어진 홀수를 $n$이라 하자. 그럼 경우는 3가지가 있다. 각 경우마다의 경우의 수를 $a, b, c$라고 두자.

1. $n=p+q+r$

2. $n=2p+q$

3. $n=3p$

 

가장 쉽게 생각할 수 있는 것은 소수 차의 항만을 가지며 계수가 1인 다항식 $P(x)$를 생각할 수 있다. 이 $P(x)^{3}$에서 $x^n$의 계수가 우리가 구하고자 하는 답과 가장 유사할 것이다. 하지만 그 계수는 $6a+3b+c$가 된다. $c$는 쉽게 구할 수 있으나 $b$를 구하는 것이 어렵다. $b$는 $P(x)$에서 각 항의 차수만 2배가 된 식 $Q(x)$와 $P(x)$를 곱하게 된다면 $x^n$의 계수로 구할 수 있다. 

 

그럼 2가지 케이스로 나누어 답을 구하자. ($F(x)$에서 $x^n$의 계수는 $F[n]$이라 쓰겠다.)

- $R(x)=P(x)^{3}$

- $S(x)=Q(x)P(x)$

  • $c>0$

$(R[n]-3S[n]+2)/6+S[n]$ ($S(x)$에 이미 $c$의 경우가 포함되어 있다.)

  • $c=0$

$(R[n]-3S[n])/6+S[n]$

 

문제의 제한에서 $n$은 최대 $10^6$이다. 이런 점에서 위 상태 그대로 FFT를 이용하는 것은 시간초과를 일으키기 쉽다. $n=(2a+1)+(2b+1)+(2c+1)$임을 고려하면 $(n-3)/2=a+b+c$로 크기를 반으로 줄일 수 있다. 

여기서 한가지 더 고려하야할 것은 소수 2의 존재이다. 소수 2로 생각할 수 있는 것은

1. $n=2+p+q$

2. $n=4+p$

이다. 1번의 경우 $p, q$모두 홀수이므로 불가능하다. 따라서 2번만 고려하자. $n-4$가 소수인 경우 1을 더해주면 답이 나온다. 

 

이 문제에서 발견한 것은 1번 코드와 2번 코드가 같은 의미를 가진다는 것이다.

반면 2번의 코드가 더 빠르다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
//1 code
vector <ll> multiply(const vector<ll> &a, const vector<ll> &b)
{
    vector <base> fa(all(a)), fb(all(b));
    vector <ll> res(sz(a));
    fft(fa,false); fft(fb,false);
    for (int i=0;i<SZ;i++) fa[i] *= fb[i];
    fft(fa,true);
    for (int i=0;i<SZ;i++) res[i] = ll(fa[i].real()+(fa[i].real()>0?0.5:-0.5));
    return res;
}
 
int main(){
    R = multiply(a, multiply(a, a));
}
cs

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
//2 code
vector <ll> multiply(const vector<ll> &a)
{
    vector <base> fa(all(a));
    vector <ll> res(sz(a));
    fft(fa,false);
    for (int i=0;i<SZ;i++) fa[i] *= fa[i] * fa[i];
    fft(fa,true);
    for (int i=0;i<SZ;i++) res[i] = ll(fa[i].real()+(fa[i].real()>0?0.5:-0.5));
    return res;
}
 
int main(){
    R = multiply(a);
}
cs

 

이렇게 좋은 방법을 알아버린 이상 많이 써먹을 것이다. 완성된 코드는 아래에 있다.

 

더보기

# 코드 보기

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#define _USE_MATH_DEFINES
#include <math.h>
#include <complex>
#include <vector>
#include <algorithm>
using namespace std;
 
#define sz(v) ((int)(v).size())
#define all(v) (v).begin(),(v).end()
typedef complex<double> base;
typedef long long ll;
const ll MAX = 1000001;
const ll SZ = 1<<21;
char num[MAX]={1,1};
 
void fft(vector <base> &a, bool invert)
{
    int n = sz(a);
    for (int i=1,j=0;i<n;i++){
        int bit = n >> 1;
        for (;j>=bit;bit>>=1) j -= bit;
        j += bit;
        if (i < j) swap(a[i],a[j]);
    }
    for (int len=2;len<=n;len<<=1){
        double ang = 2*M_PI/len*(invert?-1:1);
        base wlen(cos(ang),sin(ang));
        for (int i=0;i<n;i+=len){
            base w(1);
            for (int j=0;j<len/2;j++){
                base u = a[i+j], v = a[i+j+len/2]*w;
                a[i+j] = u+v;
                a[i+j+len/2= u-v;
                w *= wlen;
            }
        }
    }
    if (invert){
        for (int i=0;i<n;i++) a[i] /= n;
    }
}
 
void multiply(const vector<ll> &a,const vector<ll> &b,vector<ll> &res, vector<ll> &res2)
{
    vector <base> fa(all(a)), fb(all(b));
    fft(fa,false); fft(fb,false);
       vector <base> faa(all(fa));
    for (int i=0;i<SZ;i++) fa[i] *= fb[i];
    for (int i=0;i<SZ;i++) faa[i] *= faa[i] * faa[i];
    fft(fa,true);
    fft(faa,true);
    for (int i=0;i<SZ;i++) res2[i] = ll(fa[i].real()+(fa[i].real()>0?0.5:-0.5));
    for (int i=0;i<SZ;i++) res[i] = ll(faa[i].real()+(faa[i].real()>0?0.5:-0.5));
}
vector<ll> a(SZ),res(SZ),b(SZ),res2(SZ);
 
ll f(ll n){
    if(n%3==0 && !num[n/3]){
        ll ret=(res[(n-3)>>1]-res2[(n-3)>>1]*3+2)/6+res2[(n-3)>>1];
        return ret+!num[n-4];
    }
    else{
        ll ret=(res[(n-3)>>1]-res2[(n-3)>>1]*3)/6+res2[(n-3)>>1];
        return ret+!num[n-4];
    }
}
 
int main(){
    int i,j,T;
    ll t,n;
    for(i=2;i<=1000;i+=2){
        if(num[i]) continue;
        for(j=i*i;j<MAX;j+=i){
            num[j]=1;
        }
        if(i==2) i--;
    }
    for(i=3;i<MAX;i++){
        if(!num[i]){
            a[i>>1]=1//(p-1)
            if(i-1<MAX/2) b[i-1]=1//(p-1)/2
        }
    }
    multiply(a,b,res,res2); //n=2p+q 개수 구하기 & 세제곱  
    scanf("%d",&T);
    for(i=0;i<T;i++){
        scanf("%lld",&n);
        printf("%lld\n",f(n));
    }
}
cs

# 닫기

 

'PS > 백준' 카테고리의 다른 글

[BOJ 5051] 피타고라스의 정리  (0) 2022.03.05
다이아몬드 달성!!!!!!  (0) 2022.03.05
[BOJ 16808] Identity Function  (0) 2022.02.14
[BOJ 13714] 약수의 개수  (0) 2022.02.08
[BOJ 18496] Euclid's Algorithm  (0) 2022.01.16