블로그 옮겼습니다
Hackerrank) Tree Pruning 본문
\(2\leq{N}\leq{10^5}\)
\(1\leq{K}\leq{200}\)
이 문제는 N개의 노드로 이루어진 트리가 주어진다. 각 노드에는 가중치가 있다.
K가 주어지는데 최대 K번 가지를 칠 수 있다는 것이다. 어떤 노드 u를 가지친다는 말의 정의는
노드 u를 루트로 하는 서브트리를 제거한다는 것이다. 최대 K번 가지를 쳐서 남아있는 트리의 모든 노드의
가중치의 합을 최대로 하고싶을때 이 최대 가중치합을 구하는 문제이다.
우선 처음에 일반적인 트리dp 문제처럼 이렇게 풀어 보려고했었다.
dp[u][v][k] : 현재 노드 u를 보고있고 v개의 자식을 이미 봐주었으며 앞으로 최대 k번 가지칠수있을 때 최대 가중치 합
이렇게 하면 부분문제의 수는 최대 N*K 개이다. 왜냐하면 (u,v) 의 쌍은 간선의 수와 같으며, 간선의 수는 O(N)개이기
때문이다. 그리고 각각의 부분 문제를 구하는 방법은
1 | for(int i = 0 ; i <= k; i++) max(ans, go(v,0,i) + go(u,v+1,k-i)); | cs |
와 같은 식으로 구할 수가 있다. 그니까 앞으로 가지를 칠 수 있는 최대 횟수 k를 현재 보고있는 자식과
그 자식을 제외한 나머지 덩어리가 나누어 가져야 하기 때문에 k가지의 나눠가지는 경우의수가있기 때문이다.
그럼 하나의 부분문제를 구하는 데에 O(K) 가 걸리기 때문에 전체 시간복잡도는 O(NK^2) 이다.
그러면 시간안에 구할 수가 없게된다.
이 문제를 풀기위한 key idea는 바로 트리를 일렬로 펼치는 것이다. 그러면 처음에 전처리로
우선 dfs를 돌리면 순서를 매겨주고 각 노드의 순서를 인덱스로하여 이 노드를 루트로 하는
서브트리의 총 노드수와 가중치합을 구해놓는다. 그러면 부분문제는 이렇게 바뀐다.
dp[u][k] : 현재 노드 u를 보고있고 앞으로 가지칠 수 있는 최대 횟수가 k일 때의 최대 가중치 합
이렇게 부분문제를 바꿔주면 하나의 부분문제를 구하는 데에 걸리는 시간은 O(1)가 된다.
왜냐하면 매 순간 현재 노드를 가지 칠지, 말지만 결정하면 되기 때문이다.
가지를 친다면 현재 노드를 루트로하는 서브트리의 노드들은 한번에 삭제되기 때문에
이 노드들의 수만큼 u를 건너뛸 수가 있기 때문에 go(u+size[u], k-1) 를 해주면되고
만약 현재 노드를 가지치지 않는다면 자식들을 가지칠지 결정해야 하기 때문에 dfs 순서대로 이미
순서를 매겨주었기 때문에 go(u+1,k) 를 해주고 지금 노드는 그럼 가지치지않은 것이니 무조건
마지막에 트리에 남아있게 되므로 u의 가중치를 더해주면 되는 것이다. 그냥 이게 다다..
오히려 처음에 트리를 펴기위한 전처리를 더 신경써줘야한다. 그리고
앞으로 가지칠 수 있는 수가 0이라면 더이상 가지를 칠 수 없으므로 무조건 현재 노드를
루트로 하는 서브트리의 가중치합을 누적시키고 다음으로 건너뛰어야하니까 go(u+size[u],0) 를 해준다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | #include <bits/stdc++.h> using namespace std; typedef long long ll; const int mod = 1e9+7; const int inf = 0x3c3c3c3c; const long long infl = 0x3c3c3c3c3c3c3c3c; int w[100003]; int ww[100003]; int size[100003]; ll sum[100003]; int id[100003]; ll dp[100003][203]; vector<int> G[100003]; int N, K, cnt; void dfs(int u, int dad){ ww[cnt] = w[u]; id[u] = cnt++; for(int v : G[u]){ if(v == dad) continue; dfs(v,u); size[id[u]] += size[id[v]]; sum[id[u]] += sum[id[v]]; } size[id[u]]++; sum[id[u]]+=w[u]; } ll go(int pos, int k){ if(pos == cnt) return 0; ll& cache = dp[pos][k]; if(cache != -1) return cache; if(!k) return cache = go(pos+size[pos], k) + sum[pos]; return cache = max(go(pos+size[pos], k-1), go(pos+1,k)+ww[pos]); } int main() { memset(dp,-1,sizeof(dp)); scanf("%d%d",&N,&K); for(int i = 0 ; i < N; i++) scanf("%d",&w[i]); for(int i = 0 ; i < N-1; i++) { int a,b; scanf("%d%d",&a,&b); a--,b--; G[a].push_back(b); G[b].push_back(a); } dfs(0,-1); printf("%lld",go(0,K)); return 0; } | cs |
'Algorithm > Problem Solving' 카테고리의 다른 글
Topcoder SRM 522 Div1(Medium) CorrectMultiplication (0) | 2017.05.20 |
---|---|
BOJ 14276번 도로 건설 (0) | 2017.05.11 |
CSacademy Round #29 (Div. 2 only) D. Water Bottles (0) | 2017.05.11 |
Topcoder SRM 514 Div1(Medium) MagicalGirlLevelTwoDivOne (0) | 2017.05.10 |
Codeforces Round #412 (Div. 2) E. Prairie Partition (0) | 2017.05.09 |