블로그 옮겼습니다

K번째 원소구하기 본문

Algorithm/Memo &Tips

K번째 원소구하기

sgc109 2017. 5. 9. 11:20

우선 N개의 원소를 가진 어떤 임의의 배열 X가 있고,

이 안에 크기 M의 구간(subarray) A가 있다고 하자. 물론 값들은 정렬되지않은 상태이다.


이 때 A의 K번째 원소를 구하는 방법은 여러가지가 있다. 


우선 내가 아는 K번째 원소를 찾는 방법들을 구체적으로 써보자면

(사실 내가 파라메트릭서치라고 표현하는것도 역시 바이너리 서치인데, 구분을 위해 다르게씀)

(그리고 값의 범위가 큰 경우에는 좌표압축은 이미 했다고 가정한다.)

1. 벡터+sort

2. k번째 원소 세그먼트트리

3. k번째 원소 BIT

4. BIT + 파라메트릭서치 + 바이너리서치

5. 머지소트+세그먼트트리+파라메트릭서치+바이너리서치

6. 2D 세그먼트트리


이것 이외에도 버킷으로 어찌저찌하는 것도있는것같은데 잘모르겠다.


이것들 중에서 쿼리가 들어와서 구간이 계속해서 변한다면 1,2,3,4 번은 쓸 수가 없다. 왜냐하면 

구간마다 새로 트리를 구성해 주어야 하는데 인치웜으로 구간이 한칸씩 움직이는 것이 아니라면

아예 싹다 갈아엎어야 하기 때문에 엄청 느리다.


하지만 쿼리에 의해 구간이 아예 완전히 바뀌는 것이 아니라 인치웜으로 구간이 한칸씩 옆으로 이동하는 문제라면

2,3,4 번은 쓸 수가 있다. 왜냐하면 구간에 새로 추가되는 원소하나를 추가하고 잘려나가는 원소하나를 삭제하면 되는데

트리에서는 O(lgN) 에 할 수가 있기 때문이다.


그리고 쿼리로 인해 구간이 완전 변할 때 5,6번을 쓸 수가있는데 여기서 원소가 변경되는 쿼리까지 있다면

5번도 쓸 수가 없고 6번만 쓸 수가 있다.


1번은 일단 단순히 M개의 원소를 벡터에 넣고 정렬하여 A[k-1]로 구하는 방법이다.

이 방법은 딱 이 경우에는 다른 방법들과 시간복잡도상에서 차이가 없을지 몰라도

딱 이 경우만 쓸 수 있는 방법이다.


6번은 내가 안짜봤고 5번은 BOJ 7469번 가 대표적인 문제이다. 여기서는 2,3,4만 논할 것이다.


BOJ 9426번을 기준으로 설명을 하겠다.

4번으로 이 문제를 풀면

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
 
int t[1000003];
int A[1000003];
int N,K;
void update(int pos, int v){
    while(pos < (1<<16)){
        t[pos] += v;
        pos += pos&-pos;
    }
}
int query(int pos){
    int ret = 0;
    while(pos > 0){
        ret += t[pos];
        pos -= pos&-pos;
    }
    return ret;
}
 
int main() {
    scanf("%d%d",&N,&K);
    for(int i = 0 ; i < N; i++scanf("%d",&A[i]);
    for(int i = 0 ; i < K; i++) update(A[i]+1,1);
    long long sum = 0;
 
    for(int i = K; i <= N; i++){
        int lo = 0, hi = 1<<16;
        while(lo<hi){
            int mid = (lo+hi)/2;
            int s = query(mid);
            if(s >= (K+1)/2) hi = mid;
            else lo = mid+1;
        }
        sum += lo-1;
        if(i==N) break;
        update(A[i]+1,1);
        update(A[i-K]+1,-1);
    }
    printf("%lld",sum);
    return 0;
}
 
cs


이런식이고, 시간복잡도는 O(nlg^2n)


2번으로 풀면 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
 
int t[1000003];
int A[1000003];
int N,K;
int ans;
void update(int nl, int nr, int node, int pos, int v){
    if(nr < pos || pos < nl) return;
    if(nl == nr) {
        t[node] += v;
        return;
    }
    int nm = (nl+nr)/2;
    update(nl, nm, 2*node, pos, v);
    update(nm+1, nr, 2*node+1, pos, v);
    t[node] = t[2*node] + t[2*node+1];
}
void update(int pos, int v){
    update(0,1<<16,1,pos,v);
}
int search(int nl, int nr, int node, int k){
    if(nl == nr) return nl;
    int nm = (nl + nr)/2;
    if(k <= t[2*node]) return search(nl,nm,2*node,k);
    return search(nm+1,nr,2*node+1,k - t[2*node]);
}
int search(int k){
    return search(01<<161, k);
}
int main() {
    scanf("%d%d",&N,&K);
    for(int i = 0 ; i < N; i++scanf("%d",&A[i]);
    for(int i = 0 ; i < K; i++) update(A[i],1);
    long long sum = 0;
 
    for(int i = K; i <= N; i++){
        sum += search((K+1)/2);
        if(i==N) break;
        update(A[i],1);
        update(A[i-K],-1);
    }
    printf("%lld",sum);
    return 0;
}
 
cs

이런식이다. 그런데 update 함수를 비재귀(인덱스트리)로 짜면 더 빨라진다. 코드는 이렇다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
 
int t[1000003];
int A[1000003];
int N,K;
int ans;
 
void update(int pos, int v){
    t[pos += (1<<16)] += v;
    for(;pos>1; pos/=2) t[pos/2= t[pos] + t[pos^1];
}
int search(int nl, int nr, int node, int k){
    if(nl == nr) return nl;
    int nm = (nl + nr)/2;
    if(k <= t[2*node]) return search(nl,nm,2*node,k);
    return search(nm+1,nr,2*node+1,k - t[2*node]);
}
int search(int k){
    return search(0, (1<<16)-11, k);
}
int main() {
    scanf("%d%d",&N,&K);
    for(int i = 0 ; i < N; i++scanf("%d",&A[i]);
    for(int i = 0 ; i < K; i++) update(A[i],1);
    long long sum = 0;
    for(int i = K; i <= N; i++){
        sum += search((K+1)/2);
        if(i==N) break;
        update(A[i],1);
        update(A[i-K],-1);
    }
    printf("%lld",sum);
    return 0;
}
 
cs

이런 식이다. 그리고 사실 search 도 비재귀로 구현할 수가 있다. 아래와 같다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
 
int t[262145];
int A[250003];
int N,K;
int ans;
 
void update(int pos, int v){
    t[pos += (1<<16)] += v;
    for(;pos>1; pos/=2) t[pos/2= t[pos] + t[pos^1];
}
int search(int k){
    int pos = 1;
    while(pos < (1<<16)){
        if(k <= t[2*pos]) pos*=2;
        else k-=t[2*pos], pos=2*pos+1;
    }
    return pos-(1<<16);
}
 
int main() {
    scanf("%d%d",&N,&K);
    for(int i = 0 ; i < N; i++scanf("%d",&A[i]);
    for(int i = 0 ; i < K; i++) update(A[i],1);
    long long sum = 0;
    for(int i = K; i <= N; i++){
        sum += search((K+1)/2);
        if(i==N) break;
        update(A[i],1);
        update(A[i-K],-1);
    }
    printf("%lld",sum);
    return 0;
}
 
cs


위 셋의 시간복잡도는 모두 O(nlgn) 이다. 그런데 비재귀로 짜는 함수의 수가 늘어날 수록 속도가 빨라진다.


3번으로 풀면

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
 
int t[1000003];
int A[1000003];
int N,K;
void update(int pos, int v){
    while(pos < (1<<16)){
        t[pos] += v;
        pos += pos&-pos;
    }
}
int search(int k){
    int idx=0;
    for(int i=16;i>=0;i--){
        if(idx+(1<<i)<=(1<<16)-1 && t[idx+(1<<i)] <k){
            k-=t[idx+(1<<i)];
            idx=idx+(1<<i);
        }
    }
    if(!k) return idx;
    return idx+1;
}
int main() {
    scanf("%d%d",&N,&K);
    for(int i = 0 ; i < N; i++scanf("%d",&A[i]);
    for(int i = 0 ; i < K; i++) update(A[i]+1,1);
    long long sum = 0;
 
    for(int i = K; i <= N; i++){
        sum += search((K+1)/2)-1;
        if(i==N) break;
        update(A[i]+1,1);
        update(A[i-K]+1,-1);
    }
    printf("%lld",sum);
    return 0;
}
 
cs


이런식이다. 시간복잡도는 O(nlgn) 이다.


2,3번은 한번의 쿼리로 k 번째 수를 찾아 내려가는 것인데, 매 노드마다 왼쪽 구간에 존재하는 수와

오른쪽 구간에 존재하는 수를 검사하여 어느방향으로 가야할지를 결정하면서 리프노드까지 가는 것이다.

(실제 원소들의 값이 매우 클때 에는 좌표압축을 하기 때문에 수의 수의 가짓수가 O(n) 개라고 하겠다)

사실 시간복잡도는 2,3번은 O(nlgn)이며, 4번이 O(nlg^2n) 로 가장 느려야되는데

BIT가 워낙 빨라서 그런지 3번보다 4번이 더 빠르다. 하지만 update 함수를 비재귀(인덱스트리)로 작성하면

3번이 4번보다 빠르다.


+추가적으로 이 문제를 5번 방법으로 풀 때의 코드를 올린다. 하지만 이 문제의 제약에서는 시간안에 돌지않는다.

복잡도가 무려 O(nlgn + nlg^3n) 이기 때문이다. 하지만 구간의 정보에 대한 쿼리가 주어지는 문제에서는

앞에서 말했듯 2,3,4번은 못쓰고 이 방법을 써야 할 것이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
 
int N,K;
vector<int> d[1000000];
int A[1000003];
vector<int> init(int l, int r, int node){
    if(l == r) return d[node] = vector<int>(1,A[l]);
    int m = (l+r)/2;
    
    vector<int> L = init(l,m,2*node);
    vector<int> R = init(m+1,r,2*node+1);
 
    int pos1 = 0 , pos2 = 0;
    vector<int> ret;
    while(pos1 < L.size() && pos2 < R.size()) {
        if(L[pos1] < R[pos2]) ret.push_back(L[pos1++]);
        else ret.push_back(R[pos2++]);
    }
    while(pos1 < L.size()) ret.push_back(L[pos1++]);
    while(pos2 < R.size()) ret.push_back(R[pos2++]);
 
    return d[node] = ret;
}
 
int query(int l, int r, int nl, int nr, int node, int k){
    if(l <= nl && nr <= r) return upper_bound(d[node].begin(), d[node].end(), k) - d[node].begin();
    if(nr < l || r < nl) return 0;
    int nm = (nl+nr)/2;
    return query(l,r,nl,nm,2*node,k) + query(l,r,nm+1,nr,2*node+1,k);
}
 
int main() {
    scanf("%d%d",&N,&K);
    for(int i = 0 ; i < N; i++scanf("%d",&A[i]);
    init(0,N-1,1);
    long long sum = 0;
    for(int i = 0 ; i < N - K + 1 ; i++){
        int lo = 0, hi = 1<<17;
        while(lo<hi){
            int mid = (lo+hi)/2;
            if(query(i,i+K-1,0,N-1,1,mid) >= (K+1)/2) hi = mid;
            else lo = mid+1;
        }
        sum += lo;
    }
 
    printf("%lld",sum);
    return 0;
}
 
cs


Comments