PS와 개발을 공부하자

Hackerrank) Tree Pruning 본문

Algorithm/Problem Solving

Hackerrank) Tree Pruning

sgc109 2017.05.11 17:48
문제 링크


\(2\leq{N}\leq{10^5}\)

\(1\leq{K}\leq{200}\)


이 문제는 N개의 노드로 이루어진 트리가 주어진다. 각 노드에는 가중치가 있다.

K가 주어지는데 최대 K번 가지를 칠 수 있다는 것이다. 어떤 노드 u를 가지친다는 말의 정의는

노드 u를 루트로 하는 서브트리를 제거한다는 것이다. 최대 K번 가지를 쳐서 남아있는 트리의 모든 노드의

가중치의 합을 최대로 하고싶을때 이 최대 가중치합을 구하는 문제이다.


우선 처음에 일반적인 트리dp 문제처럼 이렇게 풀어 보려고했었다.

dp[u][v][k] : 현재 노드 u를 보고있고 v개의 자식을 이미 봐주었으며 앞으로 최대 k번 가지칠수있을 때 최대 가중치 합


이렇게 하면 부분문제의 수는 최대 N*K 개이다. 왜냐하면 (u,v) 의 쌍은 간선의 수와 같으며, 간선의 수는 O(N)개이기

때문이다. 그리고 각각의 부분 문제를 구하는 방법은

1
for(int i = 0 ; i <= k; i++) max(ans, go(v,0,i) + go(u,v+1,k-i));
cs

와 같은 식으로 구할 수가 있다. 그니까 앞으로 가지를 칠 수 있는 최대 횟수 k를 현재 보고있는 자식과

그 자식을 제외한 나머지 덩어리가 나누어 가져야 하기 때문에 k가지의 나눠가지는 경우의수가있기 때문이다.

그럼 하나의 부분문제를 구하는 데에 O(K) 가 걸리기 때문에 전체 시간복잡도는 O(NK^2) 이다.

그러면 시간안에 구할 수가 없게된다.


이 문제를 풀기위한 key idea는 바로 트리를 일렬로 펼치는 것이다. 그러면 처음에 전처리로

우선 dfs를 돌리면 순서를 매겨주고 각 노드의 순서를 인덱스로하여 이 노드를 루트로 하는 

서브트리의 총 노드수와 가중치합을 구해놓는다. 그러면 부분문제는 이렇게 바뀐다.

dp[u][k] : 현재 노드 u를 보고있고 앞으로 가지칠 수 있는 최대 횟수가 k일 때의 최대 가중치 합


이렇게 부분문제를 바꿔주면 하나의 부분문제를 구하는 데에 걸리는 시간은 O(1)가 된다.

왜냐하면 매 순간 현재 노드를 가지 칠지, 말지만 결정하면 되기 때문이다.

가지를 친다면 현재 노드를 루트로하는 서브트리의 노드들은 한번에 삭제되기 때문에

이 노드들의 수만큼 u를 건너뛸 수가 있기 때문에 go(u+size[u], k-1) 를 해주면되고

만약 현재 노드를 가지치지 않는다면 자식들을 가지칠지 결정해야 하기 때문에 dfs 순서대로 이미

순서를 매겨주었기 때문에 go(u+1,k) 를 해주고 지금 노드는 그럼 가지치지않은 것이니 무조건

마지막에 트리에 남아있게 되므로 u의 가중치를 더해주면 되는 것이다. 그냥 이게 다다..

오히려 처음에 트리를 펴기위한 전처리를 더 신경써줘야한다. 그리고 

앞으로 가지칠 수 있는 수가 0이라면 더이상 가지를 칠 수 없으므로 무조건 현재 노드를

루트로 하는 서브트리의 가중치합을 누적시키고 다음으로 건너뛰어야하니까 go(u+size[u],0) 를 해준다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9+7;
const int inf = 0x3c3c3c3c;
const long long infl = 0x3c3c3c3c3c3c3c3c;
 
int w[100003];
int ww[100003];
int size[100003];
ll sum[100003];
int id[100003];
ll dp[100003][203];
vector<int> G[100003];
int N, K, cnt;
 
void dfs(int u, int dad){
    ww[cnt] = w[u];
    id[u] = cnt++;
    for(int v : G[u]){
        if(v == dad) continue;
        dfs(v,u);
        size[id[u]] += size[id[v]];
        sum[id[u]] += sum[id[v]];
    }
    size[id[u]]++;
    sum[id[u]]+=w[u];
}
 
ll go(int pos, int k){
    if(pos == cnt) return 0;
    ll& cache = dp[pos][k];
    if(cache != -1return cache;
    if(!k) return cache = go(pos+size[pos], k) + sum[pos];
    return cache = max(go(pos+size[pos], k-1), go(pos+1,k)+ww[pos]);
}
 
int main() {
    memset(dp,-1,sizeof(dp));
    scanf("%d%d",&N,&K);
    for(int i = 0 ; i < N; i++scanf("%d",&w[i]);
    for(int i = 0 ; i < N-1; i++) {
        int a,b;
        scanf("%d%d",&a,&b);
        a--,b--;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    dfs(0,-1);
    printf("%lld",go(0,K));
    return 0;
}
 
cs


0 Comments
댓글쓰기 폼