블로그 옮겼습니다

BOJ 7812번 중앙 트리 본문

Algorithm/Problem Solving

BOJ 7812번 중앙 트리

sgc109 2017. 3. 20. 20:37

이 문제는 N개의 노드를 가지고 간선에 가중치가 있는 트리에서 중앙 정점이라는 것을 찾아서 이 중앙 정점에서 다른 정점까지의 모든 거리의 합을 구하는 문제이다.

중앙 정점의 정의는 어떤 한 점에서 다른 모든 정점까지의 거리의 합이 최소가 되는 점이다. 처음에 트리의 중심과 헷갈려서 트리의 중심으로 풀 뻔 했는데 잘 생각해 보면 트리의 중심은 어떤 한 점에서 다른 모든 정점까지의 거리의 최대값이 최소가 되는 점이기 때문에 확실히 다른 개념이다.


그렇다면 이 중앙 정점이라는 것은 어떻게 구할까? naive 하게는 각 N개의 정점에서부터 BFS나 DFS를돌려서 한 점에서 다른 모든 점까지의 거리의 합을 구하고 이 값을 바탕으로 최소값을 갱신해 나가는 O(N^2) 의 방법이 있을 것이다. 하지만 N이 10000 이기 때문에 이 방법으로는 조금 힘들 것이다.


그렇다면 조금 다른 방법을 생각해 보자. 우선 몇 가지 식과 용어를 정의하겠다.


일단, 'i번 노드가 j번노드 밖에 있다' 라는 말을 앞으로는 'i번 노드가 j번 노드를 루트로 갖는 서브트리에 속해 있지 않다.' 라는 뜻으로 사용하겠다.


size(i) = i번 노드를 루트로 하는 서브트리의 총 노드수

pSum(i) = i번 노드에서 모든 자손들까지의 거리의 총합

pSumR(i) = i번 노드에서 밖에있는 노드들까지의 거리의 총합

S(i) = 전체 트리에서 i번 노드에서 다른 모든 노드까지의 거리의 총합


그러면 S(i) = pSum(i) + pSumR(i) 라는 식을 세울 수가 있다.


그럼 pSum(i) 와 pSumR(i) 만 각 각 구해주면 되는데 dfs로 구할 수가 있다.


dfs를 두 번 돌려야 하는데 우선 첫번째 dfs로 size와 pSum 을 구해준다. 이 것들은 각각의 노드에서 그 노드의 자식노드를 루트로 갖는 서브트리에서의 답을 바탕으로 구해지기때문에 단순하게 재귀적으로 구할 수가 있다.


그리고 두번째 dfs에서는 pSumR(i) 을 구하는데 이 것을 구할 때에 size(i) 와 pSum(i) 가 사용되므로 미리 size(i)와 pSum(i) 를 구해준 것이다.


우리가 지금 here 번 노드를 기준으로 there 번이라는 자식을 보고 있다고 가정해보자.

그리고 here번 노드에서 there 번 노드로 가는 비용을 c 라고 해보자.

그렇다면 pSumR(there) 는 우선 there번 노드 밖에 있는 모든 노드들에서 here까지 가는 경로에 새롭게 here->there 의 간선이 추가되어야 there에 갈 수 있는 것이기 때문에 우선 there 번 노드 밖에 있는 노드의 수 만큼 c를 더해줘야한다. 이 노드 수는 size(root) - size(there) 를 하면 구할 수가 있다. 그 다음에는 기존에 there 를 제외하고 here 과 연결된 모든 노드들 에서 here로의 거리의 총합을 더해주면되는데, 이것은 here 의 밖에 있는 노드들부터 here까지 거리의 총합인 pSumR(here) 과 here의 밖에있지도않고 there의 서브트리에도 속해있지않은 노드들부터 here까지의 거리의 총합을 더해주면되는데, 이 것은 어떻게 구할까.


there를 루트로하는 서브트리에 속해있지도 않으면서 here의 밖에있지도 않은 노드는 결국 there를 제외한 here의 자식들을 루트로 갖는 서브트리들의 노드부터 here까지의 거리의 합이다. 즉 이것을 구하기 위해서는 pSum(here) - (pSum(there) + size(there) * c) 이다. 이 식을 설명하자면 here 를 루트로 하는 서브트리에서 here까지의 거리의 총합에서 there를 루트로 하는 서브트리에서 here 까지 가는 경로의 총합을 뺀건데, there 를 루트로 하는 서브트리에서 here 까지 가는 경로의 총합은 there를 루트로 하는 서브트리의 노드수만큼 there-here 간선을 지나므로 size(there) * c 에다가 recursive 하게 there 까지 올라오는 경로들의 합을 더하면 되기때문이다. (뭔가 더 잘 설명할 수가 있을 것같은데 잘 못하겠다..ㅠ) 그렇기 때문에 결과적으로 정리하면 

pSumR(i) = size(root) - size(there) + pSum(here) - (pSum(there) + size(there) * c) 이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>
#define REP(i,a,b) for(int i=a;i<=b;++i)
#define FOR(i,n) for(int i=0;i<n;++i)
#define pb push_back
#define all(v) (v).begin(),(v).end()
#define sz(v) ((int)(v).size())
#define inp1(a) scanf("%d",&a)
#define inp2(a,b) scanf("%d%d",&a,&b)
#define inp3(a,b,c) scanf("%d%d%d",&a,&b,&c)
#define inp4(a,b,c,d) scanf("%d%d%d%d",&a,&b,&c,&d)
#define inp5(a,b,c,d,e) scanf("%d%d%d%d%d",&a,&b,&c,&d,&e)
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pll;
typedef vector<int> vi;    
typedef vector<ll> vl;
typedef pair<int,int> pii;
typedef vector<pii> vii;
typedef vector<pll> vll;
typedef vector<vector<int> > vvi;
typedef pair<int,pair<int,int> > piii;
typedef vector<piii> viii;
const double EPSILON = 1e-9;
const double PI = acos(-1);
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
const int MAX_N = 102;
 
int N;
int u,v,c;
vii G[10003];
int size[10003];
ll pSum[10003];
ll pSumR[10003];
void dfs1(int here, int dad){
    for(auto p : G[here]){
        int there = p.first;
        int cost = p.second;
        if(there==dad) continue;
        dfs1(there, here);
        size[here] += size[there];
        pSum[here] +=  pSum[there] + size[there] * cost;
    }
    size[here]++;
}
void dfs2(int here, int dad){
    for(auto p : G[here]){
        int there = p.first;
        int cost = p.second;
        if(there==dad) continue;
        pSumR[there] += pSumR[here] + (size[0]-size[there]) * cost + pSum[here] - (pSum[there] + size[there] * cost);
        dfs2(there,here);
    }
}
int main() {
    while(1){
        FOR(i,10003) G[i].clear();
        memset(size,0,sizeof(size));
        memset(pSum,0,sizeof(pSum));
        memset(pSumR,0,sizeof(pSumR));
        inp1(N);
        if(!N) break;
        FOR(i,N-1){
            inp3(u,v,c);
            G[u].pb({v,c});
            G[v].pb({u,c});
        }
        dfs1(0,-1);
        dfs2(0,-1);
        ll ans = INFL;
        FOR(i,N) ans = min(ans, pSum[i] + pSumR[i]);
        printf("%lld\n",ans);
    }
    return 0;
}
cs


Comments