[BOJ 20539] Xorshift64

2024. 1. 31. 00:22PS/백준

새로운 유형을 도전해서 그런가 좀 어려운 문제였다. 

우리는 1부터 $2^{64}-1$까지의 수를 모두 2진수로 표현할 수 있다. 그리고 진수로 표현된 자릿수를 이용해서 모든 수를 64-dimensional vector에 대응시키자. 우리는 각 bit의 자릿수에 대해서 xor 연산을 생각하기 때문에 base field가 $\mathbb{Z}/2\mathbb{Z}$인 64-dimensional vector space를 고려한다. 그렇게 된다면 덧셈이 곧 xor연산이 된다. 

이제 Xorshift64 함수는 linear operator로써 작동하고 이에 대응되는 행렬 $M\in M_{64\times 64}(\mathbb{Z}/2\mathbb{Z})$를 생각할 수 있다. 이 행렬을 구하는 방법은 vector space의 basis인 $e_{i}$ 들을 순차적으로 Xorshift64 함수에 대입해서 나온 결과를 column vector로 사용하면 된다. 

문제에서 주어진 시작 값을 $s$, 몇 번째인지 알아내야할 값을 $t$라고 두고 각각이 vector 로써 표현되었다고 하자. 그럼 우리가 구해야 하는 것은 

\[M^{x}s=t\]

를 만족하는 자연수 $x$이다. ($x$의 존재성은 문제의 설명에서 보장되어 있다.) 조금의 관찰을 해보면 $M$이 invertible하다는 사실을 알 수 있다.($M$의 order는 $2^{64}-1$로 finite하다.) 따라서 다음이 성립한다. 

\[M^{x}(M^{i}s)=M^{i}t\]

우리는 $M^{x}$가 정확히 어떤 행렬인지는 모른다. 그것을 알아내기 위해서 $i$를 0부터 63까지 변화시키자. 그럼 총 64개의 행렬에 대한 방정식을 얻고, 각 $M^{i}s$와 $M^{i}t$를 적절히 행과 열에 대응시킨 행렬을 $X$, $Y$라고 할 때, 다음을 얻는다. 

\[XM^{x}=Y\]

그럼 $X,Y$로 agumentation matrix를 만든 뒤 가우스 소거법으로 $X$를 단위 행렬로 만들어서 나온 결과, $Y$가 $M^{x}$가 된다. 여기서 제기할 수 있는 의문은 $X$를 단위 행렬로 만들 수 있느냐 인데, 이것은 실제로 가능하다. 그 이유는 $M^{x}$의 존재성이 보장되었기 때문이다. $M^{x}$가 존재하기에 $XA=Y$에 대한 solution $A$가 존재하고 이것이 존재하기 때문에 $X$가 invertible하다. $M^{x}$를 쉽게 $M_{x}$라고 쓰자. 그럼 우리는 다음을 만족하는 $x$를 찾는 문제가 된다.

\[M^{x}=M_{x}\]

$M$의 order를 $n$이라고 두고 $n$의 소인수 중 하나를 $p$라고 하자. 그럼

\[M^{n/p\cdot x}=(M^{n/p})^{x}= M_{x} ^{n/p}\]

가 성립한다. 이것에 대해서 이산 로그 문제를 풀어서 $x$를 얻었다고 하자. 그럼 $x=pk+q(0\le q<p)$이고 $n/p x=nk+n/p \cdot q$이기 때문에, 방금 얻은 $x$는 사실 $x \pmod{p}$인 것이다. $n=2^{64}-1$의 소인수들은 모두 $n$을 1번씩만 나눈다는 것을 생각하자. 그럼 각각의 $n$의 소인수 $p$에 대해서 $x$를 계산한 후 CRT로 합치면 된다. 

이것이 Pohlig-Hellman 알고리즘이다. 

 

이것들을 구현하는데 시간이 꽤나 걸린다 ㅜㅜ. 게다가, 행렬 곱셈이 상당히 많이 사용된다. 

행렬 곱셈에서 이루어지는 덧셈이나 곱셈이 $\mathbb{Z}/2\mathbb{Z}$에서 이루어진다는 것에 착안하면

\[a+b \Leftrightarrow a ^{\wedge} b\]

\[a*b \Leftrightarrow a \& b\]

라는 것을 알 수 있다. 이 것을 통해서 행렬 곱을 최적화 시키도록 하자. 

코드는 아래에 있다. 

 

더보기

# 코드 보기

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#include <bits/stdc++.h>
typedef unsigned long long llu;
typedef long long ll;
using namespace std;
typedef vector<vector<int> > matrix;
llu p[10= {3517257641655376700417};
llu mod = 18446744073709551615ULL;
int bit_s[64], bit_t[64];
matrix M(64vector<int>(64)), Mx(64vector<int>(64));
 
llu xorshift64(llu x) {
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    return x;
}
 
void copy(matrix &A, matrix &B){
    int i, j;
    for(i = 0;i < 64;i++){
        for(j = 0;j < 64;j++) A[i][j] = B[i][j];
    }
}
 
void mat_mul(matrix &A, const matrix &B)
{
    int tmp;
    matrix C(64vector<int>(64));
    for (int k = 0; k < 64; k++)
    {
        for (int i = 0; i < 64; i++)
        {
            tmp = A[i][k];
            for (int j = 0; j < 64; j++){
                C[i][j] ^= (tmp & B[k][j]);
            }
        }
    }
    copy(A, C);
}
 
void vec_mul(matrix &A, int B[64]){
    int i, j;
    int C[64];
    for(i = 0;i < 64;i++){
        C[i] = 0;
        for(j = 0;j < 64;j++){
            C[i] += A[i][j] * B[j];
        }
    }
    for(i = 0;i < 64;i++){
        B[i] = (C[i] % 2 + 2) % 2;
    }
}
 
void mat_pow(matrix &A, llu k){
    int i, j;
    matrix S(64vector<int>(64));
    for(i = 0;i < 64;i++){
        for(j = 0;j < 64;j++) S[i][j] = 0;
        S[i][i] = 1;
    }
    for(;k;k >>= 1){
        if(k & 1) mat_mul(S, A);
        mat_mul(A, A);
    }
    copy(A, S);
}
 
void gauss_elem(matrix &A, matrix &B){
    int i, j, k;
    for(i = 0;i < 64;i++){
        if(A[i][i] == 0){
            for(j = i + 1;j < 64;j++if(A[j][i]) break;
            if(j == 64continue;
            for(k = 0;k < 64;k++){
                int tmp = A[i][k];
                A[i][k] = A[j][k];
                A[j][k] = tmp;
                
                tmp = B[i][k];
                B[i][k] = B[j][k];
                B[j][k] = tmp;
            }
            i--;
            continue;
        }
        for(j = 0;j < 64;j++){
            if(i == j) continue;
            if(A[j][i] == 0continue;
            for(k = 0;k < 64;k++){
                A[j][k] = (A[j][k] - A[i][k] + 2) % 2;
                
                B[j][k] = (B[j][k] - B[i][k] + 2) % 2;
            }
        }
    }
}
 
void transpose(matrix &A){
    int i, j;
    int C[64][64];
    for(i = 0;i < 64;i++){
        for(j = 0;j < 64;j++){
            C[i][j] = A[j][i];
        }
    }
    for(i = 0;i < 64;i++){
        for(j = 0;j < 64;j++) A[i][j] = C[i][j];
    }
}
 
void inverse(matrix &A, matrix &B){
    matrix tmp(64vector<int>(64));
    copy(tmp , A);
    int i, j;
    for(i = 0;i < 64;i++){
        for(j = 0;j < 64;j++){
            B[i][j] = 0;
        }
        B[i][i] = 1;
    }
    gauss_elem(tmp, B);
}
 
llu p_solve(matrix &A, matrix &B, llu p){
    llu S = ceil(sqrt(p));
    map<matrix, llu> mp;
    matrix C(64vector<int>(64));
    matrix inv_A(64vector<int>(64));
    inverse(A, inv_A); // A 보존  
    mat_pow(A, S);
    copy(C, A);
    int i, j;
    for(i = 0;i < 64;i++){
        for(j = 0;j < 64;j++) A[i][j] = 0;
        A[i][i] = 1;
    }
    for(i = 0;i <= S;i++){
        mp[B] = i;
        mat_mul(B, inv_A);
    }
    for(i = 0;i <= S;i++){
        if(mp.find(A) != mp.end()){
            return i * S + mp[A];
        }
        mat_mul(A, C);
    }
    return 0;
}
 
llu fpow(llu n, llu k){
    llu s = 1;
    for(;k;k >>= 1){
        if(k & 1) s = (__uint128_t) s * n % mod;
        n = (__uint128_t) n * n % mod; 
    }
    return s;
}
 
llu solve(){
    int i, j, k;
    llu a[10];
    matrix A(64vector<int>(64)), Ax(64vector<int>(64));
    for(i = 0;i < 7;i++){
        copy(A, M); copy(Ax, Mx);
        mat_pow(A, mod / p[i]);
        mat_pow(Ax, mod / p[i]);
        a[i] = p_solve(A, Ax, p[i]);
    }
    llu ret = 0;
    for(i = 0;i < 7;i++){
        ret = (ret + (__uint128_t) a[i] * fpow(mod / p[i], p[i] - 1) % mod) % mod;
    }
    return ret;
}
 
void init(llu s, llu t){
    ll i, j;
    for(i = 0;i < 64;i++){
        llu x = xorshift64(1ULL << i);
        for(j = 0;j < 64;j++){
            M[j][i] = x % 2;
            x >>= 1;
        }
    }
 
    llu es = s, et = t;
    for(i = 0;i < 64;i++){
        bit_s[i] = es % 2;
        es >>= 1;
        bit_t[i] = et % 2;
        et >>= 1;
    }
    
    int ret_s[64], ret_t[64];
    matrix X(64vector<int>(64)), Y(64vector<int>(64));
    for(i = 0;i < 64;i++){
        ret_s[i] = bit_s[i];
        ret_t[i] = bit_t[i];
    }
    for(i = 0;i < 64;i++){
        for(j = 0;j < 64;j++){
            X[i][j] = ret_s[j];
            Y[i][j] = ret_t[j];
        }
        vec_mul(M, ret_s);
        vec_mul(M, ret_t);
    }
    gauss_elem(X, Y);
    transpose(Y);
    copy(Mx, Y);
}
 
int main() {
    llu s, t;
    scanf("%llu %llu"&s, &t);
    init(s, t);
    printf("%llu"1 + solve());
}
cs

# 코드 닫기