PS와 개발을 공부하자

FFT(고속 푸리에 변환) 예제 - Hackerrank) Best spot 본문

Algorithm/Problem Solving

FFT(고속 푸리에 변환) 예제 - Hackerrank) Best spot

sgc109 2017. 7. 28. 13:32

https://www.hackerrank.com/contests/w6/challenges/best-spot/problem



우선 FFT 코드는 명우님의 블로그의 코드를 참고했음을 알립니다. http://blog.myungwoo.kr/54 


이 문제는 FFT를 모르면 절대 풀 수 없어 보이고, 안다면 단순한 구현문제로 전락해 버린다.

우선 naive 하게 생각해 보았을 때 모든 위치에서 답을 계산해 보면 O((RC)^2) 의 시간이 걸린다.

R, C가 최대 500 이기 때문에 시간안에 절대 돌 수 없다.

FFT를 사용하면 O(RClgRC) 의 시간이 걸리게 되어 시간안에 돌게 된다.


우선 squared difference (x-y)^2 의 식을 풀어보면 x^2 - 2xy + y^2 이라는 것을 알 수가 있다.

어차피 작은 그리드의 위치를 옮겨도 y는 일정하므로 y^2은 미리 구해놓을 수가 있다.

그리고 사실 x^2 도 O(RC)에 미리 구해놓을 수가 있다. 그런데 문제는 xy 를 어떻게 빠르게 구하냐 이다.

이 때 FFT가 사용된다. FFT에 대한 자세한 설명은 생략하고, 중요한건 2차원 그리드를 1차원으로 펴는것이다.

작은 그리드의 한 행, 한 행 사이에 적절히 0을 넣어주어야 한다. 그리고 나머지는 그냥 다 구현인듯..





1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <bits/stdc++.h>
#define pb push_back
#define sz(v) ((int)(v).size())
#define all(v) (v).begin(), (v).end()
#define fastio() ios_base::sync_with_stdio(0),cin.tie(0)
using namespace std;
typedef long long ll;
typedef complex<double> base;
const ll INFL = 0x3c3c3c3c3c3c3c3c;
const double PI = 2.0 * acos(0.0);
 
void fft(vector <base> &a, bool invert)
{
    int n = sz(a);
    for (int i=1,j=0;i<n;i++){
        int bit = n >> 1;
        for (;j>=bit;bit>>=1) j -= bit;
        j += bit;
        if (i < j) swap(a[i],a[j]);
    }
    for (int len=2;len<=n;len<<=1){
        double ang = 2*PI/len*(invert?-1:1);
        base wlen(cos(ang),sin(ang));
        for (int i=0;i<n;i+=len){
            base w(1);
            for (int j=0;j<len/2;j++){
                base u = a[i+j], v = a[i+j+len/2]*w;
                a[i+j] = u+v;
                a[i+j+len/2= u-v;
                w *= wlen;
            }
        }
    }
    if (invert){
        for (int i=0;i<n;i++) a[i] /= n;
    }
}
 
vector<int> operator*(const vector<int> &a,const vector<int> &b)
{
    vector<int> res;
    vector <base> fa(all(a)), fb(all(b));
    int n = 1;
    while (n < max(sz(a),sz(b))) n <<= 1;
    fa.resize(n); fb.resize(n);
    fft(fa,false); fft(fb,false);
    for (int i=0;i<n;i++) fa[i] *= fb[i];
    fft(fa,true);
    res.resize(n);
    for (int i=0;i<n;i++) res[i] = int(fa[i].real()+(fa[i].real()>0?0.5:-0.5));
    return res;
}
 
int poww(int x){return x * x;}
int H, W, h, w;
int P[503][503], p[503][503];
ll dp[503][503];
vector<int> A, B, C, D;
int main() {
    fastio();
    cin >> H >> W;
    A.resize(2 * H * W);
    for(int i = 0 ; i < H; i++) {
        for(int j = 0 ; j < W; j++){
            cin >> P[i][j];
            A[i * W + j] = P[i][j];
        }
    }
 
    cin >> h >> w;
    B.resize(2 * H * W);
    for(int i = 0 ; i < h; i++){
        for(int j = 0 ; j < w; j++){
            cin >> p[i][j];
            B[i * W + j] = p[i][j];
        }
        for(int j = w; j < W; j++) B[i * W + j] = 0;
    }
 
    for(int i = h ; i < H; i++){
        for(int j = 0 ; j < W; j++){
            B[i * W + j] = 0;
        }
    }
    for(int i = H * W ; i < 2 * H * W; i++) B[i] = B[i - H * W];
    reverse(all(B));
 
    C = A * B;
    
    D = vector<int>(C.begin() + H * W - 1, C.begin() + H * W - 1 + (W + w - 1* (H - h + 1));
    C.clear();
    int pos = 0;
    while(pos < sz(D)){
        for(int i = 0 ; i < W - w + 1; i++, pos++) C.emplace_back(D[pos]);
        for(int i = 0 ; i < w - 1; i++, pos++);
    }
 
    for(int i = 1 ; i <= H; i++){
        for(int j = 1 ; j <= W; j++){
            dp[i][j] = dp[i - 1][j] + dp[i][j - 1- dp[i - 1][j - 1+ poww(P[i - 1][j - 1]);
        }
    }
 
    ll ans = INFL;
    int ansI = 0, ansJ = 0;
    pos = 0;
    for(int i = h; i <= H; i++){
        for(int j = w; j <= W; j++){
            ll val = -2 * C[pos] + dp[i][j] - dp[i - h][j] - dp[i][j - w] + dp[i - h][j - w];
            if(ans > val) {
                ans = val;
                ansI = i - h;
                ansJ = j - w;
            }
            pos++;
        }
    }
 
    ll add = 0;
    for(int i = 0 ; i < h; i++){
        for(int j = 0 ; j < w; j++){
            add += poww(p[i][j]);
        }
    }
 
    cout << ans + add << "\n";
    cout << ansI + 1 << " " << ansJ + 1 << "\n";
 
    return 0;
}
cs


Tag
공유하기 링크
0 Comments
댓글쓰기 폼