[Codeup 2534] 무리수 거듭제곱

2021. 8. 15. 12:01PS/코드업

무리수를 직접 거듭제곱 하기에는 부동소수점 오차가 존재하기 때문에 부적절하다. 우리는 새로운 방법을 찾아야 한다.

문제에서 주어진 무리수가 $a+b\sqrt{c}$라고 하자. 또 이 무리수의 켤래 무리수를 $a-b\sqrt{c}$라고 잡을 수 있다.

두 무리수는 다음 방정식을 만족하는 두 근이다.

$x^{2}-2ax+a^{2}-b^{2}c=0$

또 양변에 $x$를 많이 곱하면

$x^{n}-2ax^{n-1}+(a^{2}-b^{2}c)x^{n-2}=0$

이 된다. 각각의 근을 $\alpha, \beta$라고 하면

$\alpha^{n}-2a\alpha^{n-1}+(a^{2}-b^{2}c)\alpha^{n-2}=0$ (1)

$\beta^{n}-2a\beta^{n-1}+(a^{2}-b^{2}c)\beta^{n-2}=0$ (2)

이 된다. 수열 $s_{n}$을 다음과 같이 정의하자.

$s_{n}=\alpha^{n}+\beta^{n}$

(1)식과 (2)식을 더하면

$s_{n}-2as_{n-1}+(a^2-b^{2}c)s_{n-2}=0$

$s_{n}=2as_{n-1}-(a^2-b^{2}c)s_{n-2}$

이렇게 만들어진 점화식을 이용할 수 있다.

반면 $n$제곱 연산을 천만번이나 해야하므로 행렬을 사용해야한다.

$\begin{bmatrix} s_{n+1} \\ s_{n} \end{bmatrix} = \begin{bmatrix} 2a & b^{2}c-a^{2} \\ 1 & 0 \end{bmatrix} \begin{bmatrix} s_{n} \\ s_{n-1} \end{bmatrix}$

$\begin{bmatrix} s_{n+1} \\ s_{n} \end{bmatrix} = \begin{bmatrix} 2a & b^{2}c-a^{2} \\ 1 & 0 \end{bmatrix}^{n} \ \begin{bmatrix} s_{1} \\ s_{0} \end{bmatrix}$

결국 우리는 $s_{n}$을 계산해냈고 우리에게 필요한 것은 $\alpha^{n}$이다.

문제의 조건을 보면 $a(a-2)<b^{2}c \le a(a+2)$이다.

$-1<a-b\sqrt{c}<1$을 생각해보자.

$a-1<b\sqrt{c}<a+1$

$a^{2}-2a+1<b^{2}c<a^{2}+2a+1$ ($a \ge 1$이므로)

$a(a-2) < b \le a(a+2)$

이고 

$-1<\beta^{n}<1$

이다. 만약 $n$이 짝수이거나  $\beta>0$이면 $\beta^{n}>0$일 것이므로 $s_{n}-1$을 계산하면 된다.

그 외의 경우는 $s_{n}$만 계산하면 된다.

 

더보기

# 코드 보기

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include<stdio.h>
typedef long long int lo;
lo t[2][2];
lo h[2][2];
lo y=1e9+7;
lo abs(lo a){
    if(a<0return -a;
    return a;
}
lo mul(lo n,lo k){
    lo s=0,h=n%y,p=1;
    if(k<0){
        p=-1;
        k=abs(k);
    }
    for(;k;k>>=1){
        if(k&1) s=(s+h)%y;
        h=(h+h)%y;
    }
    return s*p;
}
void gop(lo (*p)[2],lo (*q)[2]){
    lo k,i,j,n[2][2]={0,0,0,0};
    for(k=0;k<2;k++){
        for(i=0;i<2;i++){
            for(j=0;j<2;j++){
                n[k][i]+=mul(p[k][j],q[j][i])%y;
            }
            n[k][i]%=y;
        }
    }
    for(i=0;i<2;i++){
        for(j=0;j<2;j++){
            p[i][j]=n[i][j];
        }
    }
}
int f(int k){
    k--;
    for(;k;k>>=1){
        if(k&1) gop(t,h);
        gop(h,h);
    }
}
 
int main(){
    lo n,m,a,b,c,mi=0;
    scanf("%lld %lld %lld %lld",&a,&b,&c,&n);
    t[0][0]=h[0][0]=2*a;
    m=b*b*c-a*a;
    if(n%2==0) mi=-1;
    if(m<0) mi=-1;
    t[0][1]=h[0][1]=m;
    t[1][0]=h[1][0]=1;
    f(n);
 
    n=2*a*t[1][0]+2*t[1][1]+mi;
    
    n%=y;
    if(n<0) n+=y;
    printf("%lld",n%y);
}
 
 
cs

# 닫기